为生产中的模型创建 SHAP 基线 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

为生产中的模型创建 SHAP 基线

解释通常是对比的,也就是说,它们解释了偏离基线的原因。有关可解释性基准的信息,请参阅SHAP 基准的可解释性.

除了对每个实例的推断提供解释外, SageMaker 澄清还支持对机器学习模型的全局解释,以帮助您从特性角度了解模型作为一个整体的行为。 SageMaker 澄清通过在多个实例上聚合 Shapley 值来生成机器学习模型的全局解释。 SageMaker 澄清支持以下不同的聚合方式,您可以使用这些方式来定义基线:

  • mean_abs— 所有实例的绝对 SHAP 值的均值。

  • median— 所有实例的 SHAP 值的中位数。

  • mean_sq— 所有实例的 SHAP 平方值的均值。

在将应用程序配置为捕获实时推理数据之后,监控要素归因偏移的第一项任务是创建一个基线进行比较。这涉及配置数据输入、哪些组敏感、如何捕获预测、模型及其训练后偏差指标。然后你需要开始基线工作。模型可解释性监视器可以解释已部署模型的预测,该模型正在生成推论并定期检测功能归因偏移。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在此示例中,可解释性基线作业将测试数据集与偏差基线作业共享,因此它使用相同的DataConfig,唯一的区别是作业输出 URI。

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前 SageMaker 澄清解释器提供了 SHAP 的可扩展和高效实施,因此可解释性配置是 shapConfig,包括以下内容:

  • baseline— 要在内核 SHAP 算法中用作基线数据集的行(至少一个)或 S3 对象 URI 的列表。格式应与数据集格式相同。每行应只包含要素列/值,并忽略标注列/值。

  • num_samples— 内核 SHAP 算法中要使用的样本数量。此数字决定了用于计算 SHAP 值的生成合成数据集的大小。

  • agg_method — 全局 SHAP 值的聚合方法。有效值如下所示:

    • mean_abs— 所有实例的绝对 SHAP 值的均值。

    • median— 所有实例的 SHAP 值的中位数。

    • mean_sq— 所有实例的 SHAP 平方值的均值。

  • use_logit— 是否将 logit 函数应用于模型预测的指标。默认为 False。如果use_logitTrue,SHAP 值将具有对数赔率单位。

  • save_local_shap_values(bool) — 指示是否将本地 SHAP 值保存在输出位置。默认为 False

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

开始基线工作。相同的model_config是必需的,因为可解释性基线作业需要创建阴影端点才能获得对生成的合成数据集的预测。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")