本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
激活检查点
激活检查点(或渐变检查点) 是一种通过清除某些层的激活并在向后过程中重新计算它们来减少内存使用的技术。实际上,这会交换额外的计算时间以减少内存使用量。如果模块处于选中点状态,则在向前传结束时,模块的输入和输出将保留在内存中。任何本来是该模块内计算的一部分的中间张量都会在向前传递期间释放出来。在检查点模块的向后传递期间,这些张量被重新计算。此时,超出此检查点模块的层已经完成了向后传递,因此带检查点的峰值内存使用率可能会降低。
如何使用激活检查点
与smdistributed.modelparallel,你可以在模块的粒度上使用激活检查点。对于所有torch.nn模块除外torch.nn.Sequential,只有从管道并行度的角度来看,如果模块树位于一个分区内,才能检查模块树。如果是torch.nn.Sequential模块,顺序模块内的每个模块树都需要完全位于一个分区内才能启用激活检查点才能正常工作。当您使用手动分区时,您需要了解这些限制。
当您使用自动模型分区,你可以找到分区分配日志Partition assignments:在培训作业日志中。如果一个模块被分成多个等级(例如,一个子体在一个等级,另一个子体处于不同等级),则库将忽略检查模块的尝试,并引发一条警告消息,说明该模块不会被检查点。
这些区域有: SageMaker 模型并行库同时支持重叠和非重叠allreduce结合检查点。
PyTorch 的原生检查点 API 不兼容smdistributed.modelparallel.
示例 1:以下示例代码显示了在脚本中有模型定义时如何使用激活检查点。
import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)
示例 2:以下示例代码显示了在脚本中有顺序模型时如何使用激活检查点。
import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)
示例 3:以下示例代码显示了从库导入预构建模型时如何使用激活检查点,例如 PyTorch 和拥抱脸变形金刚。无论您是否检查点顺序模块,您都需要执行以下操作:
-
将模型包裹起来
smp.DistributedModel(). -
为顺序图层定义对象。
-
按顺序图层对象包裹
smp.set_activation_checkpointig().
import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')