import pandas as pd import joblib import shap import matplotlib.pyplot as plt import os import numpy as np def interpret_model_with_shap(model_path, data_path, output_dir="shap_outputs", sample_size=10000): """ Gera explicações SHAP para o modelo treinado e os dados, com paralelização para acelerar o processo. Args: model_path (str): Caminho para o modelo salvo. data_path (str): Caminho para o dataset pré-processado. output_dir (str): Diretório onde os gráficos serão salvos. sample_size (int): Número de amostras a ser utilizado para gerar as explicações SHAP. """ print(f"Carregando modelo de: {model_path}") model = joblib.load(model_path) print(f"Carregando dados de: {data_path}") df = pd.read_csv(data_path) if "Label" not in df.columns: raise ValueError("A coluna 'Label' não foi encontrada no dataset.") X = df.drop("Label", axis=1) y = df["Label"].apply(lambda x: 0 if str(x).strip().upper() == "BENIGN" else 1) # Se o dataset tiver mais de 10.000 amostras, fazemos uma amostragem aleatória if len(X) > sample_size: X_sample = X.sample(sample_size, random_state=42) else: X_sample = X print(f"Gerando explicações SHAP com {len(X_sample)} amostras...") # Usando o TreeExplainer corretamente sem n_jobs explainer = shap.TreeExplainer(model) # Remove n_jobs=-1 shap_values_list = explainer.shap_values(X_sample) # Se for uma lista (para modelos multiclasse), escolhemos os valores de SHAP para a classe 1 (ataque) if isinstance(shap_values_list, list): shap_values = shap_values_list[1] # Pegamos os valores SHAP da classe 1 (ataque) else: shap_values = shap_values_list # Garante que os shap_values estão no formato correto shap_values = np.array(shap_values) if shap_values.ndim == 3: shap_values = shap_values[:, :, 1] # Se for 3D: (n_samples, n_features, n_classes) # Cria o diretório para salvar os gráficos, se não existir os.makedirs(output_dir, exist_ok=True) # Gráfico de importância das features (tipo barra) plt.figure() shap.summary_plot(shap_values, X_sample, plot_type="bar", show=False) plt.tight_layout() plt.savefig(f"{output_dir}/shap_bar_plot.png", dpi=300) plt.close() print("✅ Gráfico de importância (bar) salvo como shap_bar_plot.png") # Gráfico de importância das features (sumário) plt.figure() shap.summary_plot(shap_values, X_sample, show=False) plt.tight_layout() plt.savefig(f"{output_dir}/shap_summary_plot.png", dpi=300) plt.close() print("✅ Gráfico de importância (summary) salvo como shap_summary_plot.png") # Gráfico de dependência para a feature mais importante most_important_feature = X_sample.columns[abs(shap_values).mean(0).argmax()] print(f"📌 Feature mais importante: {most_important_feature}") plt.figure() shap.dependence_plot(most_important_feature, shap_values, X_sample, show=False) plt.tight_layout() plt.savefig(f"{output_dir}/shap_dependence_plot.png", dpi=300) plt.close() print(f"✅ Gráfico de dependência salvo como shap_dependence_plot.png") if __name__ == "__main__": model_file = "random_forest_model.joblib" # Caminho para o modelo salvo data_file = "cicids2017_preprocessed.csv" # Caminho para o dataset pré-processado interpret_model_with_shap(model_file, data_file)