Support 拥抱脸变压器模型 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

Support 拥抱脸变压器模型

这些区域有: SageMaker 模型 parallel 库的张量并行性提供 out-of-the-box 支持以下拥抱脸变压器型号:

  • GPT-2、BERT 和 Roberta(可在 SageMaker 模型 parallel 库 v1.7.0 及更高版本)

  • GPT-J(可在 SageMaker 模型 parallel 库 v1.8.0 及更高版本)

注意

对于任何其他变形金刚型号,你需要使用smp.tp_注册er_Wit_module ()API 来应用张量并行性。

注意

要使用 tensor 并行性来训练 Huging Face 变压器模型,你应该添加transformers==4.4.2requirements.txt用于在 PyTorch v1.8.1 及更高版本的深度学习容器中安装变形金刚库。

如果您使用其中一个拥抱 Face 变压器模型,则无需手动实现挂钩即可将变压器 API 转换为smdistributed变压器层。你也不需要使用smp.tp_registerAPI。你可以通过使用上下文管理器激活张量并行smp.tensor_parallelism()然后通过包装模型smp.DistributedModel(). 例如,请参阅以下代码:

from transformers import AutoModelForCausalLM with smp.tensor_parallelism(): model = AutoModelForCausalLM.from_config(hf_gpt2_config) model = smp.DistributedModel(model)

此外,鉴于state_dict来自 的DistributedModel对象,你可以将权重加载到原始 HuggingFace GPT-2 模型使用translate_state_dict_to_hf_gpt2API:

from smdistributed.modelparallel.torch.nn.huggingface.gpt2 \ import translate_state_dict_to_hf_gpt2 max_seq_len = 1024 with smp.tensor_parallelism(): model = AutoModelForCausalLM.from_config(hf_gpt2_config) model = smp.DistributedModel(model) # [...run training...] if smp.rdp_rank() == 0: state_dict = dist_model.state_dict() hf_state_dict = translate_state_dict_to_hf_gpt2(state_dict, max_seq_len) # can now call model.load_state_dict(hf_state_dict) to the original HF model

同样,给定受支持 HuggingFace 模型state_dict,您可以使用translate_hf_state_dict_to_smdistributedAPI 将其转换为可读的格式smp.DistributedModel. 这在转移学习用例中非常有用,在这种情况下,预训练的模型被加载到smp.DistributedModel对于模型并行微调:

from smdistributed.modelparallel.torch.nn.huggingface.roberta \ import translate_state_dict_to_smdistributed model = AutoModelForMaskedLM.from_config(roberta_config) model = smp.DistributedModel(model) pretrained_model = AutoModelForMaskedLM.from_pretrained("roberta-large") translated_state_dict = translate_state_dict_to_smdistributed(pretrained_model.state_dict()) # load the translated pretrained weights into the smp.DistributedModel model.load_state_dict(translated_state_dict) # start fine-tuning...

相关state_dict拥抱 Face 和之间的翻译功能smp可以按照如下方式访问。

  • from smdistributed.modelparallel.torch.nn.huggingface.gpt2 import translate_state_dict_to_hf_gpt2

  • from smdistributed.modelparallel.torch.nn.huggingface.gpt2 import translate_hf_state_dict_to_smdistributed

  • from smdistributed.modelparallel.torch.nn.huggingface.bert import translate_state_dict_to_hf_bert

  • from smdistributed.modelparallel.torch.nn.huggingface.bert import translate_hf_state_dict_to_smdistributed

  • from smdistributed.modelparallel.torch.nn.huggingface.roberta import translate_state_dict_to_hf_roberta

  • from smdistributed.modelparallel.torch.nn.huggingface.roberta import translate_hf_state_dict_to_smdistributed

  • from smdistributed.modelparallel.torch.nn.huggingface.gptj import translate_hf_gptj_state_dict_to_smdistributed(在中可用 SageMaker 模型 parallel 库 v1.8.0 及更高版本)

  • from smdistributed.modelparallel.torch.nn.huggingface.gptj import translate_state_dict_to_hf_gptj(在中可用 SageMaker 模型 parallel 库 v1.8.0 及更高版本)