add custom class weights based on sklearn balance
This commit is contained in:
parent
b35f23e518
commit
6b787792db
@ -61,6 +61,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
||||
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
||||
#
|
||||
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
|
||||
parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
|
||||
parser.add_argument("--gpu", action="store_true", dest="gpu")
|
||||
|
||||
|
||||
|
37
main.py
37
main.py
@ -7,6 +7,7 @@ import pandas as pd
|
||||
import tensorflow as tf
|
||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||
from keras.models import load_model
|
||||
from sklearn.utils import class_weight
|
||||
|
||||
import arguments
|
||||
import dataset
|
||||
@ -101,6 +102,17 @@ def main_hyperband():
|
||||
json.dump(results, open("hyperband.json"))
|
||||
|
||||
|
||||
def get_custom_class_weights(client_tr, server_tr):
|
||||
client = client_tr.value.argmax(1)
|
||||
server = server_tr.value.argmax(1)
|
||||
client_class_weight = class_weight.compute_class_weight('balanced', np.unique(client), client)
|
||||
server_class_weight = class_weight.compute_class_weight('balanced', np.unique(server), server)
|
||||
return {
|
||||
"client": client_class_weight,
|
||||
"server": server_class_weight
|
||||
}
|
||||
|
||||
|
||||
def main_train(param=None):
|
||||
exists_or_make_path(args.model_path)
|
||||
|
||||
@ -151,6 +163,14 @@ def main_train(param=None):
|
||||
model.compile(optimizer='adam',
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
|
||||
if args.class_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_class_weights = get_custom_class_weights(client_tr, server_tr)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
logger.info("start training")
|
||||
model.fit([domain_tr, flow_tr],
|
||||
[client_tr, server_tr],
|
||||
@ -158,7 +178,8 @@ def main_train(param=None):
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
shuffle=True,
|
||||
validation_split=0.2)
|
||||
validation_split=0.2,
|
||||
class_weight=custom_class_weights)
|
||||
logger.info("save embedding")
|
||||
embedding.save(args.embedding_model)
|
||||
|
||||
@ -167,11 +188,9 @@ def main_test():
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||
args.domain_length, args.window)
|
||||
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
stats = clf.evaluate([domain_val, flow_val],
|
||||
[client_val, server_val],
|
||||
batch_size=args.batch_size)
|
||||
# logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
||||
logger.info(stats)
|
||||
# stats = clf.evaluate([domain_val, flow_val],
|
||||
# [client_val, server_val],
|
||||
# batch_size=args.batch_size)
|
||||
y_pred = clf.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size)
|
||||
np.save(args.future_prediction, y_pred)
|
||||
@ -197,6 +216,12 @@ def main_visualization():
|
||||
logger.info("plot roc curve")
|
||||
visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path))
|
||||
visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||
visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1),
|
||||
"{}/client_cov.png".format(args.model_path),
|
||||
normalize=False, title="Client Confusion Matrix")
|
||||
visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1),
|
||||
"{}/server_cov.png".format(args.model_path),
|
||||
normalize=False, title="Server Confusion Matrix")
|
||||
|
||||
|
||||
def main_score():
|
||||
|
@ -90,10 +90,10 @@ def plot_roc_curve(mask, prediction, path):
|
||||
print("roc_auc", roc_auc)
|
||||
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred,
|
||||
def plot_confusion_matrix(y_true, y_pred, path,
|
||||
normalize=False,
|
||||
title='Confusion matrix',
|
||||
cmap="Blues"):
|
||||
cmap="Blues", dpi=600):
|
||||
"""
|
||||
This function prints and plots the confusion matrix.
|
||||
Normalization can be applied by setting `normalize=True`.
|
||||
@ -125,6 +125,8 @@ def plot_confusion_matrix(y_true, y_pred,
|
||||
plt.tight_layout()
|
||||
plt.ylabel('True label')
|
||||
plt.xlabel('Predicted label')
|
||||
plt.savefig(path, dpi=dpi)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_training_curve(logs, key, path, dpi=600):
|
||||
|
Loading…
x
Reference in New Issue
Block a user