73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
import pandas as pd
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
|
import joblib
|
|
import os
|
|
|
|
def train_and_evaluate_random_forest(input_path, model_output_path,
                                     metrics_output_path="model_metrics.txt"):
    """Train and evaluate a binary Random Forest on the preprocessed CICIDS2017 data.

    The ``Label`` column is collapsed to two classes (``BENIGN`` -> 0, any
    other value -> 1). The data is split 70/30 with stratification, a
    class-weighted ``RandomForestClassifier`` is fitted and persisted with
    ``joblib``, and the metrics computed on the held-out split are printed
    and written to a text file.

    Args:
        input_path: Path of the CSV file to load. It must contain a
            ``Label`` column; every other column is used as a feature.
        model_output_path: Destination path for the serialized model.
        metrics_output_path: Destination path for the plain-text metrics
            report. Defaults to ``"model_metrics.txt"`` in the current
            working directory, which is where the report was always written
            before this parameter existed.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``
        and ``confusion_matrix`` holding the values computed on the test split.

    Raises:
        ValueError: If the loaded dataset has no ``Label`` column.
    """
    print(f"Carregando o dataset pré-processado de: {input_path}")
    df = pd.read_csv(input_path)

    if 'Label' not in df.columns:
        raise ValueError("A coluna 'Label' não foi encontrada no dataset.")

    X = df.drop('Label', axis=1)
    y = df['Label']

    # Encode labels: BENIGN = 0, anything else = 1. The comparison ignores
    # surrounding whitespace and letter case.
    y = y.apply(lambda x: 0 if str(x).strip().upper() == 'BENIGN' else 1)

    # Split into train and test sets; stratify keeps the class ratio equal
    # in both partitions.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )

    print("Treinando o modelo Random Forest...")
    model = RandomForestClassifier(
        n_estimators=100,
        random_state=42,
        n_jobs=-1,
        class_weight='balanced',
    )
    model.fit(X_train, y_train)

    joblib.dump(model, model_output_path)
    print(f"Modelo salvo em: {model_output_path}")

    print("Avaliando o modelo...")
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    prec = precision_score(y_test, y_pred)
    rec = recall_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    # Pin both classes so the matrix is always 2x2, even when the test split
    # and the predictions happen to contain a single class.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

    # The scalar metrics are formatted once and reused for stdout and the file.
    summary_lines = [
        f"Acurácia: {acc:.4f}",
        f"Precisão: {prec:.4f}",
        f"Recall: {rec:.4f}",
        f"F1-score: {f1:.4f}",
    ]
    for line in summary_lines:
        print(line)
    print("Matriz de confusão:")
    print(cm)

    # Save the metrics. The report contains non-ASCII characters, so the
    # encoding is fixed to UTF-8 instead of depending on the platform default.
    with open(metrics_output_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in summary_lines)
        f.write(f"Matriz de Confusão:\n{cm}\n")

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "confusion_matrix": cm,
    }
|
|
|
|
if __name__ == '__main__':
    # Default locations, resolved relative to the current working directory.
    input_file = 'cicids2017_preprocessed.csv'
    model_file = 'random_forest_model.joblib'

    # Train only when the preprocessed dataset is present; otherwise tell the
    # user which step has to be run first.
    if os.path.exists(input_file):
        train_and_evaluate_random_forest(input_file, model_file)
    else:
        print(f"❌ Arquivo '{input_file}' não encontrado. Execute o preprocess_data.py antes.")
|
|
|