修改 PyTorch 训练脚本 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

修改 PyTorch 训练脚本

在 SageMaker 数据 parallel 库 v1.4.0 及更高版本,该库可作为后端选项PyTorch 分布式程序. 只需在训练脚本顶部导入库一次,然后将其设置为 PyTorch 初始化期间分布式后端。使用单行后端规范,您可以保留 PyTorch 训练脚本不变并直接使用 PyTorch 分布式模块。要查找该库的最新 API 文档,请参阅SageMaker 为 PyTorch 分布式数据 parallel API中的SageMaker Python 开发工具包文. 了解有关的更多信息 PyTorch 分布式包和后端选项,请参阅分布式通信包-torch.分布式/分布.

重要

由于 SageMaker 分布式数据并行度库 v1.4.0 及更高版本可作为 PyTorch 已分发,以下sm分布式API(对于 ) PyTorch 已弃用分布式软件包。

如果需要使用库的早期版本(v1.3.0 或更早版本),请参阅archived SageMaker 分布式数据 parallel 库文档中的SageMaker Python 开发工具包文.

使用 SageMaker 作为后端的分布式数据并行库torch.distributed

使用 SageMaker 分布式数据 parallel 库,您唯一需要做的是导入 SageMaker 分布式数据 parallel 库 PyTorch 客户端 (smdistributed.dataparallel.torch.torch_smddp)。客户注册smddp作为 PyTorch 的后端。当你初始化 PyTorch 分布式进程组使用torch.distributed.init_process_groupAPI,请务必指定'smddp'backend参数。

import smdistributed.dataparallel.torch.torch_smddp import torch.distributed as dist dist.init_process_group(backend='smddp')
注意

这些区域有:smddp后端目前不支持使用torch.distributed.new_group()API。您无法使用smddp后端与其他进程组后端(例如 NCCL 和 Gloo)同时进行。

如果您已有工作方式 PyTorch 脚本,只需添加后端规范,就可以继续使用 SageMaker 框架估算器对于 PyTorch TensorFlow中的第 2 步:启动 SageMaker 使用分布式培训 Job SageMaker Python 开发工具包主题。

如果您仍需修改训练脚本才能正确使用 PyTorch 分布式软件包,请按照本页面上的其余步骤进行操作。

准备 PyTorch 分布式训练的训练脚本

以下步骤提供了有关如何准备训练脚本以使用 PyTorch 成功运行分布式训练作业的其他提示。

注意

在 v1.4.0 中, SageMaker 分布式数据 parallel 库支持以下集体原始数据类型torch.分布式/分布接口:all_reducebroadcastreduceall_gather, 和barrier.

  1. 导入 PyTorch 分布式模块。

    import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP
  2. 在解析参数并定义批处理大小参数之后(例如,batch_size=args.batch_size) 中,添加两行代码以调整每个工作人员的批量大小 (GPU)。PyTorch 的 DataLoader 操作不会自动处理分布式训练的批量调整大小。

    batch_size //= dist.get_world_size() batch_size = max(batch_size, 1)
  3. 将每个 GPU 固定到单个 GPU SageMaker 数据 parallel 库进程local_rank— 这是指给定节点内进程的相对排名。

    您可以从LOCAL_RANK环境变量。

    import os local_rank = os.environ["LOCAL_RANK"] torch.cuda.set_device(local_rank)
  4. 定义模型后,用 PyTorch 包装模型DistributedDataParallelAPI。

    model = ... # Wrap the model with the PyTorch DistributedDataParallel API model = DDP(model)
  5. 当你打电话给torch.utils.data.distributed.DistributedSamplerAPI 中,指定在集群中所有节点参与训练的进程 (GPU) 总数。这叫world_size,你可以从torch.distributed.get_world_size()API。此外,还可以使用torch.distributed.get_rank()API。

    from torch.utils.data.distributed import DistributedSampler train_sampler = DistributedSampler( train_dataset, num_replicas = dist.get_world_size(), rank = dist.get_rank() )
  6. 修改脚本,以便仅在领导进程中保存检查点(等级 0)。领导者进程具有同步的模型。这也避免了其他过程覆盖检查点,并可能会损坏检查点。

    if dist.get_rank() == 0: torch.save(...)

以下示例代码显示了 PyTorch 训练脚本smddp作为后端。

import os import torch # SageMaker data parallel: Import the library PyTorch API import smdistributed.dataparallel.torch.torch_smddp # SageMaker data parallel: Import PyTorch's distributed API import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP # SageMaker data parallel: Initialize the process group dist.init_process_group(backend='smddp') class Net(nn.Module): ... # Define model def train(...): ... # Model training def test(...): ... # Model evaluation def main(): # SageMaker data parallel: Scale batch size by world size batch_size //= dist.get_world_size() batch_size = max(batch_size, 1) # Prepare dataset train_dataset = torchvision.datasets.MNIST(...) # SageMaker data parallel: Set num_replicas and rank in DistributedSampler train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) train_loader = torch.utils.data.DataLoader(..) # SageMaker data parallel: Wrap the PyTorch model with the library's DDP model = DDP(Net().to(device)) # SageMaker data parallel: Pin each GPU to a single library process. local_rank = os.environ["LOCAL_RANK"] torch.cuda.set_device(local_rank) model.cuda(local_rank) # Train optimizer = optim.Adadelta(...) scheduler = StepLR(...) for epoch in range(1, args.epochs + 1): train(...) if rank == 0: test(...) scheduler.step() # SageMaker data parallel: Save model on master node. if dist.get_rank() == 0: torch.save(...) if __name__ == '__main__': main()

完成训练脚本调整后,请继续第 2 步:启动 SageMaker 使用分布式培训 Job SageMaker Python 开发工具包.