Files
cicids2017-visualization/interpret_model_with_shap.py
2025-07-28 22:40:31 -03:00

88 lines
3.4 KiB
Python

import pandas as pd
import joblib
import shap
import matplotlib.pyplot as plt
import os
import numpy as np
def interpret_model_with_shap(model_path, data_path, output_dir="shap_outputs", sample_size=10000):
"""
Gera explicações SHAP para o modelo treinado e os dados, com paralelização para acelerar o processo.
Args:
model_path (str): Caminho para o modelo salvo.
data_path (str): Caminho para o dataset pré-processado.
output_dir (str): Diretório onde os gráficos serão salvos.
sample_size (int): Número de amostras a ser utilizado para gerar as explicações SHAP.
"""
print(f"Carregando modelo de: {model_path}")
model = joblib.load(model_path)
print(f"Carregando dados de: {data_path}")
df = pd.read_csv(data_path)
if "Label" not in df.columns:
raise ValueError("A coluna 'Label' não foi encontrada no dataset.")
X = df.drop("Label", axis=1)
y = df["Label"].apply(lambda x: 0 if str(x).strip().upper() == "BENIGN" else 1)
# Se o dataset tiver mais de 10.000 amostras, fazemos uma amostragem aleatória
if len(X) > sample_size:
X_sample = X.sample(sample_size, random_state=42)
else:
X_sample = X
print(f"Gerando explicações SHAP com {len(X_sample)} amostras...")
# Usando o TreeExplainer corretamente sem n_jobs
explainer = shap.TreeExplainer(model) # Remove n_jobs=-1
shap_values_list = explainer.shap_values(X_sample)
# Se for uma lista (para modelos multiclasse), escolhemos os valores de SHAP para a classe 1 (ataque)
if isinstance(shap_values_list, list):
shap_values = shap_values_list[1] # Pegamos os valores SHAP da classe 1 (ataque)
else:
shap_values = shap_values_list
# Garante que os shap_values estão no formato correto
shap_values = np.array(shap_values)
if shap_values.ndim == 3:
shap_values = shap_values[:, :, 1] # Se for 3D: (n_samples, n_features, n_classes)
# Cria o diretório para salvar os gráficos, se não existir
os.makedirs(output_dir, exist_ok=True)
# Gráfico de importância das features (tipo barra)
plt.figure()
shap.summary_plot(shap_values, X_sample, plot_type="bar", show=False)
plt.tight_layout()
plt.savefig(f"{output_dir}/shap_bar_plot.png", dpi=300)
plt.close()
print("✅ Gráfico de importância (bar) salvo como shap_bar_plot.png")
# Gráfico de importância das features (sumário)
plt.figure()
shap.summary_plot(shap_values, X_sample, show=False)
plt.tight_layout()
plt.savefig(f"{output_dir}/shap_summary_plot.png", dpi=300)
plt.close()
print("✅ Gráfico de importância (summary) salvo como shap_summary_plot.png")
# Gráfico de dependência para a feature mais importante
most_important_feature = X_sample.columns[abs(shap_values).mean(0).argmax()]
print(f"📌 Feature mais importante: {most_important_feature}")
plt.figure()
shap.dependence_plot(most_important_feature, shap_values, X_sample, show=False)
plt.tight_layout()
plt.savefig(f"{output_dir}/shap_dependence_plot.png", dpi=300)
plt.close()
print(f"✅ Gráfico de dependência salvo como shap_dependence_plot.png")
if __name__ == "__main__":
model_file = "random_forest_model.joblib" # Caminho para o modelo salvo
data_file = "cicids2017_preprocessed.csv" # Caminho para o dataset pré-processado
interpret_model_with_shap(model_file, data_file)