import os

import joblib
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split


def _encode_label(raw_label) -> int:
    """Map a raw label to 0 for BENIGN (case/whitespace-insensitive), else 1."""
    return 0 if str(raw_label).strip().upper() == 'BENIGN' else 1


def train_and_evaluate_random_forest(
    input_path: str,
    model_output_path: str,
    metrics_output_path: str = "model_metrics.txt",
) -> None:
    """Train and evaluate a Random Forest on the preprocessed CICIDS2017 dataset.

    The CSV at ``input_path`` is loaded, the ``Label`` column is binarised
    (BENIGN -> 0, anything else -> 1), and the data is split 70/30 into
    stratified train/test sets. The fitted model is persisted with joblib,
    and the evaluation metrics are printed and written to a text file.

    Args:
        input_path: Path of the preprocessed CSV file; must contain a
            ``Label`` column. All remaining columns are used as features.
        model_output_path: Destination path for the serialized model.
        metrics_output_path: Destination path for the metrics report.
            Defaults to ``model_metrics.txt`` in the current directory,
            which is where the report was always written before this
            parameter existed.

    Raises:
        ValueError: If the dataset has no ``Label`` column.
    """
    print(f"Carregando o dataset pré-processado de: {input_path}")
    df = pd.read_csv(input_path)

    if 'Label' not in df.columns:
        raise ValueError("A coluna 'Label' não foi encontrada no dataset.")

    X = df.drop('Label', axis=1)
    # Encode labels: BENIGN = 0, any other value = 1.
    y = df['Label'].apply(_encode_label)

    # Split into train and test sets; stratify keeps the class ratio
    # identical in both partitions.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )

    print("Treinando o modelo Random Forest...")
    model = RandomForestClassifier(
        n_estimators=100,
        random_state=42,
        n_jobs=-1,
        class_weight='balanced',
    )
    model.fit(X_train, y_train)

    # Persist the model before evaluating so that a failure in the
    # evaluation step does not discard the trained estimator.
    joblib.dump(model, model_output_path)
    print(f"Modelo salvo em: {model_output_path}")

    print("Avaliando o modelo...")
    y_pred = model.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    prec = precision_score(y_test, y_pred)
    rec = recall_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred)

    print(f"Acurácia: {acc:.4f}")
    print(f"Precisão: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"F1-score: {f1:.4f}")
    print("Matriz de confusão:")
    print(cm)

    # Save the metrics. The report contains non-ASCII characters, so the
    # encoding is pinned instead of depending on the platform locale.
    with open(metrics_output_path, "w", encoding="utf-8") as f:
        f.write(f"Acurácia: {acc:.4f}\n")
        f.write(f"Precisão: {prec:.4f}\n")
        f.write(f"Recall: {rec:.4f}\n")
        f.write(f"F1-score: {f1:.4f}\n")
        f.write(f"Matriz de Confusão:\n{cm}\n")


if __name__ == '__main__':
    input_file = 'cicids2017_preprocessed.csv'
    model_file = 'random_forest_model.joblib'

    if not os.path.exists(input_file):
        print(f"❌ Arquivo '{input_file}' não encontrado. Execute o preprocess_data.py antes.")
    else:
        train_and_evaluate_random_forest(input_file, model_file)