Artigo Mestrado
This commit is contained in:
87
interpret_model_with_shap.py
Normal file
87
interpret_model_with_shap.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pandas as pd
|
||||
import joblib
|
||||
import shap
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
def interpret_model_with_shap(model_path, data_path, output_dir="shap_outputs", sample_size=10000):
|
||||
"""
|
||||
Gera explicações SHAP para o modelo treinado e os dados, com paralelização para acelerar o processo.
|
||||
|
||||
Args:
|
||||
model_path (str): Caminho para o modelo salvo.
|
||||
data_path (str): Caminho para o dataset pré-processado.
|
||||
output_dir (str): Diretório onde os gráficos serão salvos.
|
||||
sample_size (int): Número de amostras a ser utilizado para gerar as explicações SHAP.
|
||||
"""
|
||||
print(f"Carregando modelo de: {model_path}")
|
||||
model = joblib.load(model_path)
|
||||
|
||||
print(f"Carregando dados de: {data_path}")
|
||||
df = pd.read_csv(data_path)
|
||||
|
||||
if "Label" not in df.columns:
|
||||
raise ValueError("A coluna 'Label' não foi encontrada no dataset.")
|
||||
|
||||
X = df.drop("Label", axis=1)
|
||||
y = df["Label"].apply(lambda x: 0 if str(x).strip().upper() == "BENIGN" else 1)
|
||||
|
||||
# Se o dataset tiver mais de 10.000 amostras, fazemos uma amostragem aleatória
|
||||
if len(X) > sample_size:
|
||||
X_sample = X.sample(sample_size, random_state=42)
|
||||
else:
|
||||
X_sample = X
|
||||
|
||||
print(f"Gerando explicações SHAP com {len(X_sample)} amostras...")
|
||||
|
||||
# Usando o TreeExplainer corretamente sem n_jobs
|
||||
explainer = shap.TreeExplainer(model) # Remove n_jobs=-1
|
||||
shap_values_list = explainer.shap_values(X_sample)
|
||||
|
||||
# Se for uma lista (para modelos multiclasse), escolhemos os valores de SHAP para a classe 1 (ataque)
|
||||
if isinstance(shap_values_list, list):
|
||||
shap_values = shap_values_list[1] # Pegamos os valores SHAP da classe 1 (ataque)
|
||||
else:
|
||||
shap_values = shap_values_list
|
||||
|
||||
# Garante que os shap_values estão no formato correto
|
||||
shap_values = np.array(shap_values)
|
||||
if shap_values.ndim == 3:
|
||||
shap_values = shap_values[:, :, 1] # Se for 3D: (n_samples, n_features, n_classes)
|
||||
|
||||
# Cria o diretório para salvar os gráficos, se não existir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Gráfico de importância das features (tipo barra)
|
||||
plt.figure()
|
||||
shap.summary_plot(shap_values, X_sample, plot_type="bar", show=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{output_dir}/shap_bar_plot.png", dpi=300)
|
||||
plt.close()
|
||||
print("✅ Gráfico de importância (bar) salvo como shap_bar_plot.png")
|
||||
|
||||
# Gráfico de importância das features (sumário)
|
||||
plt.figure()
|
||||
shap.summary_plot(shap_values, X_sample, show=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{output_dir}/shap_summary_plot.png", dpi=300)
|
||||
plt.close()
|
||||
print("✅ Gráfico de importância (summary) salvo como shap_summary_plot.png")
|
||||
|
||||
# Gráfico de dependência para a feature mais importante
|
||||
most_important_feature = X_sample.columns[abs(shap_values).mean(0).argmax()]
|
||||
print(f"📌 Feature mais importante: {most_important_feature}")
|
||||
|
||||
plt.figure()
|
||||
shap.dependence_plot(most_important_feature, shap_values, X_sample, show=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{output_dir}/shap_dependence_plot.png", dpi=300)
|
||||
plt.close()
|
||||
print(f"✅ Gráfico de dependência salvo como shap_dependence_plot.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_file = "random_forest_model.joblib" # Caminho para o modelo salvo
|
||||
data_file = "cicids2017_preprocessed.csv" # Caminho para o dataset pré-processado
|
||||
interpret_model_with_shap(model_file, data_file)
|
||||
|
||||
Reference in New Issue
Block a user