Files
cicids2017-visualization/train_model.py
2025-07-28 22:40:31 -03:00

73 lines
2.3 KiB
Python

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import joblib
import os
def train_and_evaluate_random_forest(
    input_path, model_output_path, metrics_output_path="model_metrics.txt"
):
    """Train and evaluate a binary Random Forest on the preprocessed CICIDS2017 CSV.

    The CSV at ``input_path`` is loaded, the ``Label`` column is collapsed to
    a binary target (``BENIGN`` -> 0, anything else -> 1), and every other
    column is used as a feature. The data is split 70/30 (stratified on the
    binary target), a class-balanced ``RandomForestClassifier`` is fitted on
    the training part, persisted with ``joblib``, and scored on the test part.
    The metrics are printed and also written to ``metrics_output_path``.

    Args:
        input_path: Path of the preprocessed CSV; must contain a ``Label``
            column and at least one row.
        model_output_path: Destination passed to ``joblib.dump`` for the
            fitted model.
        metrics_output_path: Text file that receives the metrics report.
            Defaults to ``"model_metrics.txt"`` in the current directory.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``
        (floats) and ``confusion_matrix`` (2x2 array, rows/columns ordered
        as class 0 then class 1).

    Raises:
        ValueError: If the ``Label`` column is missing or the dataset has
            no rows.
    """
    print(f"Carregando o dataset pré-processado de: {input_path}")
    df = pd.read_csv(input_path)

    # Validate up front so a bad input fails fast with a clear message
    # instead of an obscure error deep inside the split/fit calls.
    if 'Label' not in df.columns:
        raise ValueError("A coluna 'Label' não foi encontrada no dataset.")
    if df.empty:
        raise ValueError("O dataset está vazio; não há amostras para treinar.")

    X = df.drop('Label', axis=1)
    # Encode labels: BENIGN = 0, any other value = 1.
    # NOTE(review): a Label column that is already numeric would be mapped
    # entirely to 1 by this rule -- confirm the preprocessing keeps text labels.
    y = df['Label'].apply(
        lambda x: 0 if str(x).strip().upper() == 'BENIGN' else 1
    )

    # Stratified train/test split keeps the class ratio in both partitions.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )

    print("Treinando o modelo Random Forest...")
    model = RandomForestClassifier(
        n_estimators=100,
        random_state=42,
        n_jobs=-1,
        class_weight='balanced',
    )
    model.fit(X_train, y_train)

    # Persist the model before evaluating, so it is kept even if the
    # evaluation/reporting step fails afterwards.
    joblib.dump(model, model_output_path)
    print(f"Modelo salvo em: {model_output_path}")

    print("Avaliando o modelo...")
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    # zero_division=0 yields the same 0.0 the default produces when a class
    # is never predicted/present, but without emitting a warning.
    prec = precision_score(y_test, y_pred, zero_division=0)
    rec = recall_score(y_test, y_pred, zero_division=0)
    f1 = f1_score(y_test, y_pred, zero_division=0)
    # Fix the label order so the matrix is always 2x2, even when only one
    # class shows up in the test split.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

    report_lines = [
        f"Acurácia: {acc:.4f}",
        f"Precisão: {prec:.4f}",
        f"Recall: {rec:.4f}",
        f"F1-score: {f1:.4f}",
    ]
    for line in report_lines:
        print(line)
    print("Matriz de confusão:")
    print(cm)

    # Save the metrics. The report contains non-ASCII characters, so the
    # encoding is set explicitly rather than left to the platform locale.
    with open(metrics_output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(report_lines) + "\n")
        f.write(f"Matriz de Confusão:\n{cm}\n")

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "confusion_matrix": cm,
    }
if __name__ == '__main__':
    # Script entry point: train on the default preprocessed CSV and store the
    # model next to it, both relative to the current working directory.
    input_file = 'cicids2017_preprocessed.csv'
    model_file = 'random_forest_model.joblib'

    if not os.path.exists(input_file):
        print(f"❌ Arquivo '{input_file}' não encontrado. Execute o preprocess_data.py antes.")
        # Exit with a non-zero status so shells and pipelines can detect that
        # training did not run (previously the script exited with status 0).
        raise SystemExit(1)

    train_and_evaluate_random_forest(input_file, model_file)