From f4da1476881d2f4ea2357c4d718adc47d9b7a882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Sun, 30 Jul 2017 15:49:37 +0200 Subject: [PATCH] refactor cmd argument to have single value for mode --- Makefile | 11 +++++-- arguments.py | 10 +++++-- dataset.py | 12 ++++++++ main.py | 84 ++++++++++++++++++---------------------------------- 4 files changed, 55 insertions(+), 62 deletions(-) diff --git a/Makefile b/Makefile index 1a23113..cf52d24 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,16 @@ run: - python3 main.py --modes train --batch 128 --model results/test --train data/rk_mini.csv.gz --epochs 10 + python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test --epochs 10 \ + --hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights + +run_new: + python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \ + --hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model test: - python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz + python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz fancy: - python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz + python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz hyper: python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz diff --git a/arguments.py b/arguments.py index 6559fab..9fab93f 100644 --- a/arguments.py +++ b/arguments.py @@ -3,15 +3,19 @@ import os parser = argparse.ArgumentParser() -parser.add_argument("--modes", action="store", dest="modes", nargs="+", - default=[]) +parser.add_argument("--mode", action="store", dest="mode", + default="") parser.add_argument("--train", action="store", dest="train_data", default="data/full_dataset.csv.tar.gz") +parser.add_argument("--data", action="store", dest="train_data", + default="data/full_dataset.csv.tar.gz") + parser.add_argument("--test", action="store", dest="test_data", default="data/full_future_dataset.csv.tar.gz") + parser.add_argument("--model", action="store", dest="model_path", default="results/model_x") @@ -74,5 +78,5 @@ def parse(): args.train_log = os.path.join(args.model_path, "train.log.csv") args.train_h5data = args.train_data + ".h5" args.test_h5data = args.test_data + ".h5" - args.future_prediction = os.path.join(args.model_path, "future_predict.npy") + args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5") return args diff --git a/dataset.py b/dataset.py index 2cc2ffd..ccebf93 100644 --- a/dataset.py +++ b/dataset.py @@ -246,3 +246,15 @@ def load_or_generate_domains(train_data, domain_length): user_flow_df.groupby(user_flow_df.domain).mean() return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix() + + +def save_predictions(path, c_pred, s_pred): + f = h5py.File(path, "w") + f.create_dataset("client", data=c_pred) + f.create_dataset("server", data=s_pred) + f.close() + + +def load_predictions(path): + f = h5py.File(path, "r") + return f["client"], f["server"] diff --git a/main.py b/main.py index 8ccb6e8..8ec4d45 100644 --- a/main.py +++ b/main.py @@ -175,41 +175,16 @@ def main_test(): domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, args.domain_length, args.window) clf = load_model(args.clf_model, custom_objects=models.get_metrics()) - y_pred = clf.predict([domain_val, flow_val], - batch_size=args.batch_size, - verbose=1) - np.save(args.future_prediction, y_pred) - - # char_dict = dataset.get_character_dict() - # user_flow_df = dataset.get_user_flow_data(args.test_data) - # domains = user_flow_df.domain.unique()[:-1] - # - # def get_domain_features_reduced(d): - # return dataset.get_domain_features(d[0], char_dict, args.domain_length) - # - # domain_features = [] - # for ds in domains: - # domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds))) - # - # model = load_model(args.embedding_model) - # domain_features = np.stack(domain_features).reshape((-1, 40)) - # pred = model.predict(domain_features, batch_size=args.batch_size, verbose=1) - # - # np.save("/tmp/rk/domains.npy", domains) - # np.save("/tmp/rk/domain_features.npy", domain_features) - # np.save("/tmp/rk/domain_embd.npy", pred) - - -def main_embedding(): - model = load_model(args.embedding_model) - domain_encs, labels = dataset.load_or_generate_domains(args.train_data, args.domain_length) - domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) - visualize.plot_embedding(domain_embedding, labels, path="results/pp3/embd.png") + c_pred, s_pred = clf.predict([domain_val, flow_val], + batch_size=args.batch_size, + verbose=1) + dataset.save_predictions(args.future_prediction, c_pred, s_pred) def main_visualization(): domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, args.domain_length, args.window) + client_val, server_val = client_val.value, server_val.value logger.info("plot model") model = load_model(args.clf_model, custom_objects=models.get_metrics()) visualize.plot_model(model, os.path.join(args.model_path, "model.png")) @@ -221,28 +196,27 @@ def main_visualization(): except Exception as e: logger.warning(f"could not generate training curves: {e}") - client_pred, server_pred = np.load(args.future_prediction) + client_pred, server_pred = dataset.load_predictions(args.future_prediction) + client_pred, server_pred = client_pred.value, server_pred.value logger.info("plot pr curve") - visualize.plot_precision_recall(client_val.value, client_pred, "{}/client_prc.png".format(args.model_path)) - visualize.plot_precision_recall(server_val.value, server_pred, "{}/server_prc.png".format(args.model_path)) - visualize.plot_precision_recall_curves(client_val.value, client_pred, "{}/client_prc2.png".format(args.model_path)) - visualize.plot_precision_recall_curves(server_val.value, server_pred, "{}/server_prc2.png".format(args.model_path)) + visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path)) + # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path)) + # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path)) + # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path)) logger.info("plot roc curve") - visualize.plot_roc_curve(client_val.value, client_pred, "{}/client_roc.png".format(args.model_path)) - visualize.plot_roc_curve(server_val.value, server_pred, "{}/server_roc.png".format(args.model_path)) - visualize.plot_confusion_matrix(client_val.value.argmax(1), client_pred.argmax(1), + visualize.plot_roc_curve(client_val, client_pred.flatten(), "{}/client_roc.png".format(args.model_path)) + # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path)) + visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(), "{}/client_cov.png".format(args.model_path), normalize=False, title="Client Confusion Matrix") - visualize.plot_confusion_matrix(server_val.value.argmax(1), server_pred.argmax(1), - "{}/server_cov.png".format(args.model_path), - normalize=False, title="Server Confusion Matrix") - - -def main_score(): - # mask = dataset.load_mask_eval(args.data, args.test_image) - # pred = np.load(args.pred) - # visualize.score_model(mask, pred) - pass + # visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1), + # "{}/server_cov.png".format(args.model_path), + # normalize=False, title="Server Confusion Matrix") + logger.info("visualize embedding") + model = load_model(args.embedding_model) + domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) + domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) + visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) def main_data(): @@ -259,19 +233,17 @@ def main_data(): def main(): - if "train" in args.modes: + if "train" == args.mode: main_train() - if "hyperband" in args.modes: + if "hyperband" == args.mode: main_hyperband() - if "test" in args.modes: + if "test" == args.mode: main_test() - if "fancy" in args.modes: + if "fancy" == args.mode: main_visualization() - if "score" in args.modes: - main_score() - if "paul" in args.modes: + if "paul" == args.mode: main_paul_best() - if "data" in args.modes: + if "data" == args.mode: main_data()