本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
使用 Tensor 并行度进行检查点的说明
这些区域有: SageMaker 模型 parallel 库支持使用张量并行性来保存部分或完整的检查点。以下指南介绍了在使用 tensor 并行机制时如何修改脚本以保存和加载检查点。
-
准备一个模型对象并用库的包装函数包装它
smp.DistributedModel().model = MyModel(...) model = smp.DistributedModel(model) -
为模型准备优化器。一组模型参数是优化器函数所需的可迭代参数。要准备一组模型参数,必须处理
model.parameters()为单个模型参数分配唯一 ID。如果模型参数可迭代对象中存在具有重复 ID 的参数,则加载检查点优化程序状态将失败。要为优化程序创建具有唯一 ID 的模型参数的可迭代对象,请参阅以下内容:
unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...) -
使用库的包装函数包装优化器
smp.DistributedOptimizer().optimizer = smp.DistributedOptimizer(optimizer) -
使用保存模型和优化程序状态
smp.save(). 根据您要保存检查点的方式,请选择以下两个选项之一: -
选项 1:在每个模型上保存部分模型
mp_rank单个MP_GROUP.model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )使用 tensor 并行性,库保存以下格式命名的检查点文件:
checkpoint.pt_{pp_rank}_{tp_rank}.注意 使用张量并行度,请务必将 if 语句设置为
if smp.rdp_rank() == 0而不是if smp.dp_rank() == 0. 当优化程序状态与张量并行分区时,所有减少的数据并 parallel 排名都必须保存自己的优化程序状态分区。使用错误如果检查点声明可能会导致培训工作停滞。有关使用的更多信息if smp.dp_rank() == 0没有张量并行性,请参见保存和加载的一般说明中的SageMaker Python 开发工具包文.
-
选项 2:保存完整模型。
if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )注意 请考虑以下内容进行完整的检查点:
-
如果你设置
gather_to_rank0=True,除以外的所有级别0返回空字典。 -
对于完整的检查点,你只能检查模型。目前不支持优化程序状态的完整检查点。
-
完整模型只需要保存在
smp.rank() == 0.
-
-
-
使用加载检查点
smp.load(). 根据您在上一步中检查点的方式,选择以下两个选项之一: -
选项 1:加载部分检查点。
checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])您可以设置
same_partition_load=True在model.load_state_dict()为了更快地加载,如果你知道分区不会改变。 -
选项 2:加载完整的检查点。
if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])这些区域有:
if smp.rdp_rank() == 0条件不是必需的,但它可以帮助避免不同之间的冗余装载MP_GROUP。张量并行性目前不支持完整的检查点优化器状态 dict。
-