SHAP值的解读
对于信贷问题,我们除了希望知道是否存在风险,还希望知道每个特征贡献了多少,比如年收入+0.15,收入高,加分;负债率-0.30负债太高,减分;工作年限+0.05工作稳定,小加分;信用评分-0.25 信用不好,减分;年龄+0.02影响很小,把模型的决策分解到每个特征上
到这里你可能心想,这样是不是很类似于线性回归的特征前的系数 y=ax1+bx2+cx3,那既然如此,我直接选择用线性回归的系数作解释,那岂不是更好?
线性回归的系数:y=0.15x收入 +(-0.30)x负债率 +0.05x工作年限 +(-0.25)x信用评分+0.02x年龄
- 系数是固定的:不管是谁,"收入"的系数永远是 0.15
- 全局解释:一个系数解释所有样本
- 简单,但假设特征和目标是线性关系
比如线性回归会说收入每增加1万,贡献固定增加 x。但现实中:收入从 5万→30万,影响很大,收入从 100 万→500 万,影响就没那么大了(边际效应递减),SHAP能捕捉这种非线性关系
核心差异就在于SHAP值是因人而异的:张三的“收入”贡献可能是+0.15,李四的可能是+0.08
- 局部解释:每个样本有自己的一组SHAP值
- 复杂,但能捕捉非线性关系,同一个特征,不同样本贡献不同
这种非线性如何呈现:
- 特征本身的非线性关系:比如边际效应
- 特征之间存在交互效应:男性(性别)+年龄(25)发生质变
Shapley 值的核心就是当特征之间有交互作用时,如何公平地把"功劳"分给每个特征
SHAP 的原理来自博弈论,但我们用一个更简单的例子来理解:想象三个人合伙开了一家奶茶店,年底赚了 100 万。问题来了:这 100 万怎么分?
小王负责研发配方、小李负责营销推广、小张负责店面运营
直接三等分?不公平!因为每个人的贡献不一样。经济学家 Shapley 提出了一个方法:
数学家 做了数学假设:
博弈论基础上有4条规则,满足这4个客观规则的只有 shap 值,很自洽
一般来说基准值不是-500 +500
100 个样本,对这个样本的预测取平均(训练)=基准值
shap 值 本质上是解释模型在训练集学习的东西 加入什么都没学 直接取平均 最好的
SHAP 在机器学习中的应用
开店=机器学习
合伙人 = 特征
总收入 = 预测值 - 基准值
每人贡献 = 每个特征的 SHAP 值
基准值(Base Value):模型在所有样本上的平均预测值
三人合伙前,收入是 0,三人合伙后,收入是 100万,要分配的"蛋糕"就是 100万-0=100 万
基准值 = 没有任何特征信息时的"默认"预测(相当于"0”的起点),这个值一般就是平均值,把训练集的所有样本都输入模型,得到所有预测值取平均值,在没有关于这个特定样本的区分性信息时,最合理的猜测就是平均值
预测值 = 加入所有特征后的预测(相当于"100 万"的终点)
要分配的"蛋糕"=预测值 -基准值
核心公式:
模型预测值 = 基准值+SHAP(特征 1)+SHAP(特征 2)+..+SHAP(特征 N)
SHAP 值加起来 =预测值与基准值的差!
那么如何实现 shaply 值动态变化呢?上面说的是平均这一家店是一个样本,对于多家店每个店都是样本,所以特征贡献不同那么对于一个样本,如果控制变量计算特征贡献呢?真实在做的时候肯定是没法实现控制其他特征不动,检测单个特征的贡献,其实还是多样本比对了,shap 值本质上也是一个近似值。
虽然不完美,但 SHAP 是目前理论最完善、应用最广泛的解释方法
SHAP 值的计算用训练还是测试集?
先说结论,两者均可,但是为了图好看一般都是选择训练集。
- 做机器学习的专业大多都是交叉学科,本身你的研究多是针对私有数据集,别人不会关注你的泛化性。所以不必因为这个纠结。
- shap 值是每个样本的每个特征都会得到对应类别的值,所以如果你的数据量本身就不大,用训练集来绘制,点多会让图美观很多,或者也可以对测试集插值也可以起到一样的效果
- 补充一个 shap图美观的小技巧,可以绘制出每一个类别 shap 曲线的置信区间,因为机器学习多是点估计,而区间估计会让你的结果更加具有信服力,利用bootstrap 重采样思想可以绘制出置信区间,自己写一下 shap 图函数,不用借助 shap 库的接口。
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.datasets import load_diabetes from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestRegressor from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error import warnings warnings.filterwarnings('ignore') # --- 1. 全局设置 --- plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文字体 plt.rcParams['axes.unicode_minus'] = False # 正常显示负号 sns.set(style="whitegrid", font='SimHei') # ========================================== # 2. 加载糖尿病数据集 # ========================================== print("正在加载糖尿病数据集...") # 加载数据 diabetes = load_diabetes() X, y = diabetes.data, diabetes.target # 转换为DataFrame以便更好地显示 df = pd.DataFrame(X, columns=diabetes.feature_names) df['target'] = y print("="*30 + " 数据集概览 " + "="*30) print(f"数据形状: {df.shape}") print(f"特征数量: {len(diabetes.feature_names)}") print(f"目标变量: 一年后疾病进展的定量测量") print("\n特征说明:") for i, (name, desc) in enumerate(zip(diabetes.feature_names, diabetes.DESCR.split('\n')[10:20])): print(f"{i+1:2d}. {name:15s} - {desc.strip()}") print("\n前5行数据:") print(df.head()) # ========================================== # 3. 数据探索性分析 # ========================================== print("\n" + "="*30 + " 数据探索分析 " + "="*30) # 创建可视化 fig, axes = plt.subplots(2, 3, figsize=(15, 10)) axes = axes.flatten() # 1. 目标变量分布 axes[0].hist(y, bins=30, color='skyblue', edgecolor='black', alpha=0.7) axes[0].set_title('目标变量分布', fontsize=12, fontweight='bold') axes[0].set_xlabel('疾病进展') axes[0].set_ylabel('频数') axes[0].grid(True, alpha=0.3) # 2. 特征相关性热图 corr_matrix = df.corr() sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', center=0, ax=axes[1], cbar_kws={'shrink': 0.8}) axes[1].set_title('特征相关性热图', fontsize=12, fontweight='bold') # 3. 目标变量与重要特征的关系 important_features = ['bmi', 's5', 'bp', 's3'] for i, feature in enumerate(important_features[:3]): axes[i+2].scatter(df[feature], y, alpha=0.5, s=20) axes[i+2].set_title(f'目标 vs {feature}', fontsize=12) axes[i+2].set_xlabel(feature) axes[i+2].set_ylabel('疾病进展') axes[i+2].grid(True, alpha=0.3) plt.suptitle('糖尿病数据集探索性分析', fontsize=14, fontweight='bold', y=1.02) plt.tight_layout() plt.show() # 数据统计 print("\n数据统计信息:") print(df.describe().round(3)) # ========================================== # 4. 训练随机森林模型 # ========================================== print("\n" + "="*30 + " 模型训练 " + "="*30) # 划分数据集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42 ) print(f"训练集形状: {X_train.shape}") print(f"测试集形状: {X_test.shape}") # 训练随机森林模型 rf_model = RandomForestRegressor( n_estimators=100, max_depth=10, random_state=42, n_jobs=-1 ) rf_model.fit(X_train, y_train) y_pred = rf_model.predict(X_test) # 评估模型 mse = mean_squared_error(y_test, y_pred) rmse = np.sqrt(mse) mae = mean_absolute_error(y_test, y_pred) r2 = r2_score(y_test, y_pred) print("\n模型性能评估:") print(f"均方误差 (MSE): {mse:.3f}") print(f"均方根误差 (RMSE): {rmse:.3f}") print(f"平均绝对误差 (MAE): {mae:.3f}") print(f"决定系数 (R²): {r2:.3f}") # 预测结果可视化 plt.figure(figsize=(8, 6)) plt.scatter(y_test, y_pred, alpha=0.5, color='blue') plt.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2) plt.xlabel('真实值') plt.ylabel('预测值') plt.title('随机森林模型预测结果', fontsize=14, fontweight='bold') plt.grid(True, alpha=0.3) plt.show() # ========================================== # 5. SHAP可解释性分析 # ========================================== print("\n" + "="*30 + " SHAP可解释性分析 " + "="*30) # 安装shap库(如果未安装) try: import shap print("SHAP库已安装,开始分析...") except ImportError: print("正在安装SHAP库...") import subprocess import sys subprocess.check_call([sys.executable, "-m", "pip", "install", "shap"]) import shap # 初始化SHAP解释器 explainer = shap.TreeExplainer(rf_model) shap_values = explainer.shap_values(X_test) # 转换为DataFrame以便更好地处理 shap_df = pd.DataFrame(shap_values, columns=diabetes.feature_names) print("\nSHAP分析完成!") print(f"SHAP值形状: {shap_df.shape}") # ========================================== # 6. SHAP可视化分析 # ========================================== print("\n" + "="*30 + " SHAP可视化分析 " + "="*30) # 创建可视化图表 fig, axes = plt.subplots(2, 2, figsize=(16, 12)) # 1. 特征重要性总结图 shap.summary_plot(shap_values, X_test, feature_names=diabetes.feature_names, show=False, plot_size=None, max_display=10) plt.title('SHAP特征重要性总结', fontsize=14, fontweight='bold', y=1.02) fig1 = plt.gcf() fig1.set_size_inches(10, 6) plt.tight_layout() plt.show() # 2. 特征重要性条形图 plt.figure(figsize=(10, 6)) shap.summary_plot(shap_values, X_test, feature_names=diabetes.feature_names, plot_type="bar", show=False, max_display=10) plt.title('SHAP特征重要性(平均绝对影响)', fontsize=14, fontweight='bold') plt.tight_layout() plt.show() # 3. 单个特征的SHAP依赖图(选择最重要的3个特征) important_indices = np.argsort(np.abs(shap_values).mean(0))[-3:][::-1] important_features = [diabetes.feature_names[i] for i in important_indices] print(f"\n最重要的3个特征:") for i, feature in enumerate(important_features, 1): shap_mean_abs = np.abs(shap_df[feature]).mean() print(f"{i}. {feature}: 平均绝对SHAP值 = {shap_mean_abs:.3f}") fig, axes = plt.subplots(1, 3, figsize=(15, 4)) for idx, (ax, feature) in enumerate(zip(axes, important_features)): feature_idx = list(diabetes.feature_names).index(feature) # 使用shap的partial_dependence_plot shap.dependence_plot( feature_idx, shap_values, X_test, feature_names=diabetes.feature_names, ax=ax, show=False ) ax.set_title(f'{feature}的SHAP依赖图', fontsize=12, fontweight='bold') ax.set_xlabel(feature) ax.set_ylabel('SHAP值') ax.grid(True, alpha=0.3) plt.suptitle('重要特征的SHAP依赖图', fontsize=14, fontweight='bold', y=1.05) plt.tight_layout() plt.show() # 4. 单个样本的SHAP解释(选择3个样本) print("\n" + "="*30 + " 单样本解释 " + "="*30) # 选择3个有代表性的样本 sample_indices = [0, 50, 100] # 可以根据需要调整 fig, axes = plt.subplots(1, 3, figsize=(15, 5)) for i, (ax, sample_idx) in enumerate(zip(axes, sample_indices)): # 创建force plot(瀑布图) shap.force_plot( explainer.expected_value, shap_values[sample_idx, :], X_test[sample_idx, :], feature_names=diabetes.feature_names, matplotlib=True, show=False, text_rotation=30 ) # 获取当前图形 temp_fig = plt.gcf() # 手动设置标题和标签 ax.set_title(f'样本 {sample_idx} 的SHAP解释', fontsize=12, fontweight='bold') ax.set_xlabel('特征') ax.set_ylabel('SHAP贡献') # 清理并显示 plt.close(temp_fig) plt.suptitle('单个样本的SHAP解释(Force Plot)', fontsize=14, fontweight='bold', y=1.05) plt.tight_layout() plt.show() # 显示样本的具体数值 for i, sample_idx in enumerate(sample_indices): print(f"\n样本 {sample_idx} 详情:") print(f"真实值: {y_test[sample_idx]:.2f}") print(f"预测值: {y_pred[sample_idx]:.2f}") print(f"预测偏差: {y_pred[sample_idx] - y_test[sample_idx]:.2f}") # 显示特征值和SHAP值 sample_shap = shap_values[sample_idx, :] important_features_idx = np.argsort(np.abs(sample_shap))[-5:][::-1] print("最重要的5个特征贡献:") for j, feat_idx in enumerate(important_features_idx): feat_name = diabetes.feature_names[feat_idx] feat_value = X_test[sample_idx, feat_idx] shap_value = sample_shap[feat_idx] print(f" {feat_name:10s}: 值={feat_value:6.3f}, SHAP={shap_value:7.3f}") # 5. 特征交互分析 print("\n" + "="*30 + " 特征交互分析 " + "="*30) # 寻找最重要的交互特征 interaction_feature = 'bmi' # 选择一个重要特征 feature_idx = list(diabetes.feature_names).index(interaction_feature) plt.figure(figsize=(10, 6)) shap.dependence_plot( feature_idx, shap_values, X_test, feature_names=diabetes.feature_names, interaction_index='auto', # 自动检测交互特征 show=False ) plt.title(f'{interaction_feature}的特征交互分析', fontsize=14, fontweight='bold') plt.xlabel(interaction_feature) plt.ylabel('SHAP值') plt.grid(True, alpha=0.3) plt.tight_layout() plt.show() # ========================================== # 7. 与传统特征重要性对比 # ========================================== print("\n" + "="*30 + " 特征重要性对比 " + "="*30) # 传统特征重要性(基于基尼不纯度) traditional_importance = pd.DataFrame({ '特征': diabetes.feature_names, '传统重要性': rf_model.feature_importances_ }).sort_values('传统重要性', ascending=False) # SHAP特征重要性(基于平均绝对SHAP值) shap_importance = pd.DataFrame({ '特征': diabetes.feature_names, 'SHAP重要性': np.abs(shap_values).mean(0) }).sort_values('SHAP重要性', ascending=False) print("\n传统特征重要性(基尼不纯度):") print(traditional_importance.head(10)) print("\nSHAP特征重要性(平均绝对SHAP值):") print(shap_importance.head(10)) # 可视化对比 fig, axes = plt.subplots(1, 2, figsize=(14, 6)) # 传统特征重要性 axes[0].barh(range(10), traditional_importance['传统重要性'].head(10)[::-1], color='skyblue', alpha=0.7) axes[0].set_yticks(range(10)) axes[0].set_yticklabels(traditional_importance['特征'].head(10)[::-1]) axes[0].set_xlabel('重要性得分') axes[0].set_title('传统特征重要性', fontsize=12, fontweight='bold') axes[0].grid(True, alpha=0.3, axis='x') # SHAP特征重要性 axes[1].barh(range(10), shap_importance['SHAP重要性'].head(10)[::-1], color='lightcoral', alpha=0.7) axes[1].set_yticks(range(10)) axes[1].set_yticklabels(shap_importance['特征'].head(10)[::-1]) axes[1].set_xlabel('平均绝对SHAP值') axes[1].set_title('SHAP特征重要性', fontsize=12, fontweight='bold') axes[1].grid(True, alpha=0.3, axis='x') plt.suptitle('特征重要性方法对比', fontsize=14, fontweight='bold', y=1.02) plt.tight_layout() plt.show() # ========================================== # 8. 模型对比分析 # ========================================== print("\n" + "="*30 + " 模型对比分析 " + "="*30) # 训练线性回归模型进行对比 lr_model = LinearRegression() lr_model.fit(X_train, y_train) y_pred_lr = lr_model.predict(X_test) # 计算线性回归的SHAP值(使用KernelExplainer) lr_explainer = shap.KernelExplainer(lr_model.predict, X_train[:100]) # 使用子集提高速度 lr_shap_values = lr_explainer.shap_values(X_test[:100]) # 使用子集 print("\n模型性能对比:") print(f"随机森林 R²: {r2_score(y_test, y_pred):.3f}") print(f"线性回归 R²: {r2_score(y_test[:100], y_pred_lr[:100]):.3f}") # 对比特征重要性 if lr_shap_values is not None: lr_shap_importance = pd.DataFrame({ '特征': diabetes.feature_names, '线性回归SHAP': np.abs(lr_shap_values).mean(0) }).sort_values('线性回归SHAP', ascending=False) print("\n线性回归SHAP重要性:") print(lr_shap_importance.head(10)) # ========================================== # 9. 总结报告 # ========================================== print("\n" + "="*30 + " SHAP分析总结 " + "="*30) print("\n关键发现:") print("1. 最重要的预测特征:") for i, row in shap_importance.head(3).iterrows(): print(f" {row['特征']}: 平均绝对SHAP值 = {row['SHAP重要性']:.3f}") print("\n2. 模型解释性:") print(" • SHAP提供了局部和全局的解释") print(" • 可以理解每个特征对单个预测的贡献") print(" • 揭示了特征之间的相互作用") print("\n3. 临床应用建议:") print(" • 重点关注BMI、s5和血压等关键指标") print(" • 这些特征对疾病进展预测有最大影响") print(" • 可以为个性化医疗提供数据支持") print("\n4. 技术要点:") print(" • 随机森林模型表现良好 (R² = {:.3f})".format(r2)) print(" • SHAP分析揭示了特征的非线性关系") print(" • 与传统特征重要性方法相比,SHAP更准确") # 保存结果 results = { '模型性能': { 'R2': r2, 'RMSE': rmse, 'MAE': mae }, '特征重要性': shap_importance.to_dict('records')[:5] } import json with open('diabetes_shap_results.json', 'w', encoding='utf-8') as f: json.dump(results, f, ensure_ascii=False, indent=2) print(f"\n分析结果已保存到: diabetes_shap_results.json") print("\n" + "="*30 + " 分析完成 " + "="*30)@浙大疏锦行