88 lines
3.4 KiB
Python
88 lines
3.4 KiB
Python
import pandas as pd
|
|
import joblib
|
|
import shap
|
|
import matplotlib.pyplot as plt
|
|
import os
|
|
import numpy as np
|
|
|
|
def interpret_model_with_shap(model_path, data_path, output_dir="shap_outputs", sample_size=10000):
|
|
"""
|
|
Gera explicações SHAP para o modelo treinado e os dados, com paralelização para acelerar o processo.
|
|
|
|
Args:
|
|
model_path (str): Caminho para o modelo salvo.
|
|
data_path (str): Caminho para o dataset pré-processado.
|
|
output_dir (str): Diretório onde os gráficos serão salvos.
|
|
sample_size (int): Número de amostras a ser utilizado para gerar as explicações SHAP.
|
|
"""
|
|
print(f"Carregando modelo de: {model_path}")
|
|
model = joblib.load(model_path)
|
|
|
|
print(f"Carregando dados de: {data_path}")
|
|
df = pd.read_csv(data_path)
|
|
|
|
if "Label" not in df.columns:
|
|
raise ValueError("A coluna 'Label' não foi encontrada no dataset.")
|
|
|
|
X = df.drop("Label", axis=1)
|
|
y = df["Label"].apply(lambda x: 0 if str(x).strip().upper() == "BENIGN" else 1)
|
|
|
|
# Se o dataset tiver mais de 10.000 amostras, fazemos uma amostragem aleatória
|
|
if len(X) > sample_size:
|
|
X_sample = X.sample(sample_size, random_state=42)
|
|
else:
|
|
X_sample = X
|
|
|
|
print(f"Gerando explicações SHAP com {len(X_sample)} amostras...")
|
|
|
|
# Usando o TreeExplainer corretamente sem n_jobs
|
|
explainer = shap.TreeExplainer(model) # Remove n_jobs=-1
|
|
shap_values_list = explainer.shap_values(X_sample)
|
|
|
|
# Se for uma lista (para modelos multiclasse), escolhemos os valores de SHAP para a classe 1 (ataque)
|
|
if isinstance(shap_values_list, list):
|
|
shap_values = shap_values_list[1] # Pegamos os valores SHAP da classe 1 (ataque)
|
|
else:
|
|
shap_values = shap_values_list
|
|
|
|
# Garante que os shap_values estão no formato correto
|
|
shap_values = np.array(shap_values)
|
|
if shap_values.ndim == 3:
|
|
shap_values = shap_values[:, :, 1] # Se for 3D: (n_samples, n_features, n_classes)
|
|
|
|
# Cria o diretório para salvar os gráficos, se não existir
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
# Gráfico de importância das features (tipo barra)
|
|
plt.figure()
|
|
shap.summary_plot(shap_values, X_sample, plot_type="bar", show=False)
|
|
plt.tight_layout()
|
|
plt.savefig(f"{output_dir}/shap_bar_plot.png", dpi=300)
|
|
plt.close()
|
|
print("✅ Gráfico de importância (bar) salvo como shap_bar_plot.png")
|
|
|
|
# Gráfico de importância das features (sumário)
|
|
plt.figure()
|
|
shap.summary_plot(shap_values, X_sample, show=False)
|
|
plt.tight_layout()
|
|
plt.savefig(f"{output_dir}/shap_summary_plot.png", dpi=300)
|
|
plt.close()
|
|
print("✅ Gráfico de importância (summary) salvo como shap_summary_plot.png")
|
|
|
|
# Gráfico de dependência para a feature mais importante
|
|
most_important_feature = X_sample.columns[abs(shap_values).mean(0).argmax()]
|
|
print(f"📌 Feature mais importante: {most_important_feature}")
|
|
|
|
plt.figure()
|
|
shap.dependence_plot(most_important_feature, shap_values, X_sample, show=False)
|
|
plt.tight_layout()
|
|
plt.savefig(f"{output_dir}/shap_dependence_plot.png", dpi=300)
|
|
plt.close()
|
|
print(f"✅ Gráfico de dependência salvo como shap_dependence_plot.png")
|
|
|
|
if __name__ == "__main__":
|
|
model_file = "random_forest_model.joblib" # Caminho para o modelo salvo
|
|
data_file = "cicids2017_preprocessed.csv" # Caminho para o dataset pré-processado
|
|
interpret_model_with_shap(model_file, data_file)
|
|
|