本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
Support 拥抱脸变压器模型
这些区域有: SageMaker 模型 parallel 库的张量并行性提供 out-of-the-box 支持以下拥抱脸变压器型号:
-
GPT-2、BERT 和 Roberta(可在 SageMaker 模型 parallel 库 v1.7.0 及更高版本)
-
GPT-J(可在 SageMaker 模型 parallel 库 v1.8.0 及更高版本)
对于任何其他变形金刚型号,你需要使用smp.tp_注册er_Wit_module ()
要使用 tensor 并行性来训练 Huging Face 变压器模型,你应该添加transformers==4.4.2到requirements.txt用于在 PyTorch v1.8.1 及更高版本的深度学习容器中安装变形金刚库。
如果您使用其中一个拥抱 Face 变压器模型,则无需手动实现挂钩即可将变压器 API 转换为smdistributed变压器层。你也不需要使用smp.tp_registerAPI。你可以通过使用上下文管理器激活张量并行smp.tensor_parallelism()然后通过包装模型smp.DistributedModel()
from transformers import AutoModelForCausalLM with smp.tensor_parallelism(): model = AutoModelForCausalLM.from_config(hf_gpt2_config) model = smp.DistributedModel(model)
此外,鉴于state_dict来自 的DistributedModel对象,你可以将权重加载到原始 HuggingFace GPT-2 模型使用translate_state_dict_to_hf_gpt2API:
from smdistributed.modelparallel.torch.nn.huggingface.gpt2 \ import translate_state_dict_to_hf_gpt2 max_seq_len = 1024 with smp.tensor_parallelism(): model = AutoModelForCausalLM.from_config(hf_gpt2_config) model = smp.DistributedModel(model) # [...run training...] if smp.rdp_rank() == 0: state_dict = dist_model.state_dict() hf_state_dict = translate_state_dict_to_hf_gpt2(state_dict, max_seq_len) # can now call model.load_state_dict(hf_state_dict) to the original HF model
同样,给定受支持 HuggingFace 模型state_dict,您可以使用translate_hf_state_dict_to_smdistributedAPI 将其转换为可读的格式smp.DistributedModel. 这在转移学习用例中非常有用,在这种情况下,预训练的模型被加载到smp.DistributedModel对于模型并行微调:
from smdistributed.modelparallel.torch.nn.huggingface.roberta \ import translate_state_dict_to_smdistributed model = AutoModelForMaskedLM.from_config(roberta_config) model = smp.DistributedModel(model) pretrained_model = AutoModelForMaskedLM.from_pretrained("roberta-large") translated_state_dict = translate_state_dict_to_smdistributed(pretrained_model.state_dict()) # load the translated pretrained weights into the smp.DistributedModel model.load_state_dict(translated_state_dict) # start fine-tuning...
相关state_dict拥抱 Face 和之间的翻译功能smp可以按照如下方式访问。
-
from smdistributed.modelparallel.torch.nn.huggingface.gpt2 import translate_state_dict_to_hf_gpt2 -
from smdistributed.modelparallel.torch.nn.huggingface.gpt2 import translate_hf_state_dict_to_smdistributed -
from smdistributed.modelparallel.torch.nn.huggingface.bert import translate_state_dict_to_hf_bert -
from smdistributed.modelparallel.torch.nn.huggingface.bert import translate_hf_state_dict_to_smdistributed -
from smdistributed.modelparallel.torch.nn.huggingface.roberta import translate_state_dict_to_hf_roberta -
from smdistributed.modelparallel.torch.nn.huggingface.roberta import translate_hf_state_dict_to_smdistributed -
from smdistributed.modelparallel.torch.nn.huggingface.gptj import translate_hf_gptj_state_dict_to_smdistributed(在中可用 SageMaker 模型 parallel 库 v1.8.0 及更高版本) -
from smdistributed.modelparallel.torch.nn.huggingface.gptj import translate_state_dict_to_hf_gptj(在中可用 SageMaker 模型 parallel 库 v1.8.0 及更高版本)