Artigo Mestrado
This commit is contained in:
72
train_model.py
Normal file
72
train_model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||||
import joblib
|
||||
import os
|
||||
|
||||
def train_and_evaluate_random_forest(input_path, model_output_path):
|
||||
"""
|
||||
Treina e avalia um modelo Random Forest no dataset pré-processado CICIDS2017.
|
||||
"""
|
||||
print(f"Carregando o dataset pré-processado de: {input_path}")
|
||||
df = pd.read_csv(input_path)
|
||||
|
||||
if 'Label' not in df.columns:
|
||||
raise ValueError("A coluna 'Label' não foi encontrada no dataset.")
|
||||
|
||||
X = df.drop('Label', axis=1)
|
||||
y = df['Label']
|
||||
|
||||
# Codifica rótulos: BENIGN = 0, qualquer outro = 1
|
||||
y = y.apply(lambda x: 0 if str(x).strip().upper() == 'BENIGN' else 1)
|
||||
|
||||
# Divide em treino e teste
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.3, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
print("Treinando o modelo Random Forest...")
|
||||
model = RandomForestClassifier(
|
||||
n_estimators=100,
|
||||
random_state=42,
|
||||
n_jobs=-1,
|
||||
class_weight='balanced'
|
||||
)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
joblib.dump(model, model_output_path)
|
||||
print(f"Modelo salvo em: {model_output_path}")
|
||||
|
||||
print("Avaliando o modelo...")
|
||||
y_pred = model.predict(X_test)
|
||||
acc = accuracy_score(y_test, y_pred)
|
||||
prec = precision_score(y_test, y_pred)
|
||||
rec = recall_score(y_test, y_pred)
|
||||
f1 = f1_score(y_test, y_pred)
|
||||
cm = confusion_matrix(y_test, y_pred)
|
||||
|
||||
print(f"Acurácia: {acc:.4f}")
|
||||
print(f"Precisão: {prec:.4f}")
|
||||
print(f"Recall: {rec:.4f}")
|
||||
print(f"F1-score: {f1:.4f}")
|
||||
print("Matriz de confusão:")
|
||||
print(cm)
|
||||
|
||||
# Salvar métricas
|
||||
with open("model_metrics.txt", "w") as f:
|
||||
f.write(f"Acurácia: {acc:.4f}\n")
|
||||
f.write(f"Precisão: {prec:.4f}\n")
|
||||
f.write(f"Recall: {rec:.4f}\n")
|
||||
f.write(f"F1-score: {f1:.4f}\n")
|
||||
f.write(f"Matriz de Confusão:\n{cm}\n")
|
||||
|
||||
if __name__ == '__main__':
|
||||
input_file = 'cicids2017_preprocessed.csv'
|
||||
model_file = 'random_forest_model.joblib'
|
||||
|
||||
if not os.path.exists(input_file):
|
||||
print(f"❌ Arquivo '{input_file}' não encontrado. Execute o preprocess_data.py antes.")
|
||||
else:
|
||||
train_and_evaluate_random_forest(input_file, model_file)
|
||||
|
||||
Reference in New Issue
Block a user