- TensorFlow
-
要使用 TensorFlow 的深度分析功能,目前您需要指定最新的、使用 CUDA 11 的 Amazon 深度学习容器映像。例如,您必须指定具体的映像 URI,如以下示例代码所示:
# An example of constructing a SageMaker TensorFlow estimator
import boto3
import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker.debugger import ProfilerConfig, DebuggerHookConfig, Rule, ProfilerRule, rule_configs
# The image URI below is Region-specific, so look up the Region of the
# current boto3 session first.
session = boto3.session.Session()
region = session.region_name

# `...` stands in for your own profiling / hook settings in this example.
profiler_config = ProfilerConfig(...)
debugger_hook_config = DebuggerHookConfig(...)

# NOTE(review): `built_in_rule()` / `BuiltInRule()` look like placeholders for
# the specific built-in rules you want to run -- replace before use.
rules = [
    Rule.sagemaker(rule_configs.built_in_rule()),
    ProfilerRule.sagemaker(rule_configs.BuiltInRule()),
]

estimator = TensorFlow(
    entry_point="directory/to/your_training_script.py",
    role=sagemaker.get_execution_role(),
    base_job_name="debugger-demo",
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    # Explicit CUDA 11 deep learning container image for this Region.
    image_uri=f"763104351884.dkr.ecr.{region}.amazonaws.com/tensorflow-training:2.3.1-gpu-py37-cu110-ubuntu18.04",
    # Debugger-specific parameters
    profiler_config=profiler_config,
    debugger_hook_config=debugger_hook_config,
    rules=rules,
)

# Start the training job with wait=False, as in the other examples.
estimator.fit(wait=False)
- PyTorch
-
要使用 PyTorch 的深度分析功能,目前您需要指定最新的、使用 CUDA 11 的 Amazon 深度学习容器映像。例如,您必须指定具体的映像 URI,如以下示例代码所示:
# An example of constructing a SageMaker PyTorch estimator
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.debugger import ProfilerConfig, DebuggerHookConfig, Rule, ProfilerRule, rule_configs
# The container image URI is Region-specific; take the Region from the
# current boto3 session.
session = boto3.session.Session()
region = session.region_name

# `...` stands in for your own profiling / hook settings in this example.
profiler_config = ProfilerConfig(...)
debugger_hook_config = DebuggerHookConfig(...)

# NOTE(review): `built_in_rule()` / `BuiltInRule()` look like placeholders for
# the specific built-in rules you want to run -- replace before use.
rules = [
    Rule.sagemaker(rule_configs.built_in_rule()),
    ProfilerRule.sagemaker(rule_configs.BuiltInRule()),
]

# Values passed to the estimator, named here for readability.
execution_role = sagemaker.get_execution_role()
training_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:1.6.0-gpu-py36-cu110-ubuntu18.04"

estimator = PyTorch(
    entry_point="directory/to/your_training_script.py",
    role=execution_role,
    base_job_name="debugger-demo",
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    image_uri=training_image_uri,
    # Debugger-specific parameters
    profiler_config=profiler_config,
    debugger_hook_config=debugger_hook_config,
    rules=rules,
)

# Start the training job with wait=False.
estimator.fit(wait=False)
- MXNet
-
# An example of constructing a SageMaker MXNet estimator
import sagemaker
from sagemaker.mxnet import MXNet
from sagemaker.debugger import ProfilerConfig, DebuggerHookConfig, Rule, ProfilerRule, rule_configs
# `...` stands in for your own profiling / hook settings in this example.
profiler_config = ProfilerConfig(...)
debugger_hook_config = DebuggerHookConfig(...)

# NOTE(review): `built_in_rule()` / `BuiltInRule()` look like placeholders for
# the specific built-in rules you want to run -- replace before use.
rules = [
    Rule.sagemaker(rule_configs.built_in_rule()),
    ProfilerRule.sagemaker(rule_configs.BuiltInRule()),
]

# IAM role passed to the estimator, named here for readability.
execution_role = sagemaker.get_execution_role()

estimator = MXNet(
    entry_point="directory/to/your_training_script.py",
    role=execution_role,
    base_job_name="debugger-demo",
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    framework_version="1.7.0",
    py_version="py37",
    # Debugger-specific parameters
    profiler_config=profiler_config,
    debugger_hook_config=debugger_hook_config,
    rules=rules,
)

# Start the training job with wait=False.
estimator.fit(wait=False)
对于 MXNet,在配置 profiler_config 参数时,您只能配置系统监控。MXNet 不支持对框架指标进行分析。
- XGBoost
-
# An example of constructing a SageMaker XGBoost estimator
import sagemaker
from sagemaker.xgboost.estimator import XGBoost
from sagemaker.debugger import ProfilerConfig, DebuggerHookConfig, Rule, ProfilerRule, rule_configs
# `...` stands in for your own profiling / hook settings in this example.
profiler_config = ProfilerConfig(...)
debugger_hook_config = DebuggerHookConfig(...)

# NOTE(review): `built_in_rule()` / `BuiltInRule()` look like placeholders for
# the specific built-in rules you want to run -- replace before use.
rules = [
    Rule.sagemaker(rule_configs.built_in_rule()),
    ProfilerRule.sagemaker(rule_configs.BuiltInRule()),
]

# IAM role passed to the estimator, named here for readability.
execution_role = sagemaker.get_execution_role()

estimator = XGBoost(
    entry_point="directory/to/your_training_script.py",
    role=execution_role,
    base_job_name="debugger-demo",
    instance_count=1,
    instance_type="ml.p3.2xlarge",
    framework_version="1.2-1",
    # Debugger-specific parameters
    profiler_config=profiler_config,
    debugger_hook_config=debugger_hook_config,
    rules=rules,
)

# Start the training job with wait=False.
estimator.fit(wait=False)
对于 XGBoost,在配置 profiler_config 参数时,您只能配置系统监控。XGBoost 不支持对框架指标进行分析。
- Generic estimator
-
# An example of constructing a SageMaker generic estimator using the XGBoost algorithm base image
import boto3
import sagemaker
from sagemaker.estimator import Estimator
from sagemaker import image_uris
from sagemaker.debugger import ProfilerConfig, DebuggerHookConfig, Rule, ProfilerRule, rule_configs
# `...` stands in for your own profiling / hook settings in this example.
profiler_config = ProfilerConfig(...)
debugger_hook_config = DebuggerHookConfig(...)

# NOTE(review): `built_in_rule()` / `BuiltInRule()` look like placeholders for
# the specific built-in rules you want to run -- replace before use.
rules = [
    Rule.sagemaker(rule_configs.built_in_rule()),
    ProfilerRule.sagemaker(rule_configs.BuiltInRule()),
]

# Look up the XGBoost 1.2-1 algorithm image for the current Region.
region = boto3.Session().region_name
xgboost_container = sagemaker.image_uris.retrieve("xgboost", region, "1.2-1")

estimator = Estimator(
    role=sagemaker.get_execution_role(),
    image_uri=xgboost_container,
    base_job_name="debugger-demo",
    instance_count=1,
    instance_type="ml.m5.2xlarge",
    # Debugger-specific parameters
    profiler_config=profiler_config,
    debugger_hook_config=debugger_hook_config,
    rules=rules,
)

# Start the training job with wait=False, as in the other examples.
estimator.fit(wait=False)