From dc9180da10352e8e3996942a2ae8e71828947754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Fri, 1 Sep 2017 10:42:26 +0200 Subject: [PATCH] refactor visualization, change arguments for model type and its depth --- Makefile | 31 +++++++++--- arguments.py | 19 +++++++- dataset.py | 2 +- main.py | 87 ++++++++++++++++++++++++---------- models/__init__.py | 8 +++- run.sh | 114 +++++++++------------------------------------ visualize.py | 45 +++++++++--------- 7 files changed, 156 insertions(+), 150 deletions(-) diff --git a/Makefile b/Makefile index cdf029f..29eb9d9 100644 --- a/Makefile +++ b/Makefile @@ -1,16 +1,33 @@ run: - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test --epochs 10 \ - --hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 10 --depth small \ + --hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final -run_new: - python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \ - --hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 --depth small \ + --hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter + + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 10 --depth medium \ + --hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final + + python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 10 --depth medium \ + --hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter test: - python3 main.py --mode test --batch 128 --model results/test --test data/rk_mini.csv.gz + python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz fancy: - python3 main.py --mode fancy --batch 128 --model results/test --test data/rk_mini.csv.gz + python3 main.py --mode fancy --batch 128 --model results/test1 --test data/rk_mini.csv.gz + + python3 main.py --mode fancy --batch 128 --model results/test2 --test data/rk_mini.csv.gz + + python3 main.py --mode fancy --batch 128 --model results/test3 --test data/rk_mini.csv.gz + + python3 main.py --mode fancy --batch 128 --model results/test4 --test data/rk_mini.csv.gz + +all-fancy: + python3 main.py --mode all_fancy --batch 128 --models results/test* --test data/rk_mini.csv.gz hyper: python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz + +clean: + rm -r results/test* \ No newline at end of file diff --git a/arguments.py b/arguments.py index 094c602..7ecff86 100644 --- a/arguments.py +++ b/arguments.py @@ -19,8 +19,14 @@ parser.add_argument("--test", action="store", dest="test_data", parser.add_argument("--model", action="store", dest="model_path", default="results/model_x") +parser.add_argument("--models", action="store", dest="model_paths", nargs="+", + default=[]) + parser.add_argument("--type", action="store", dest="model_type", - default="paul") + default="final") # inter, final, staggered + +parser.add_argument("--depth", action="store", dest="model_depth", + default="small") # small, medium parser.add_argument("--model_output", action="store", dest="model_output", default="both") @@ -74,6 +80,17 @@ parser.add_argument("--new_model", action="store_true", dest="new_model") +def get_model_args(args): + return [{ + "model_path": model_path, + "embedding_model": os.path.join(model_path, "embd.h5"), + "clf_model": os.path.join(model_path, "clf.h5"), + "train_log": os.path.join(model_path, "train.log.csv"), + "train_h5data": args.train_data + ".h5", + "test_h5data": args.test_data + ".h5", + "future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5") + } for model_path in args.model_paths] + def parse(): args = parser.parse_args() args.embedding_model = os.path.join(args.model_path, "embd.h5") diff --git a/dataset.py b/dataset.py index ccfc8e1..3516ee6 100644 --- a/dataset.py +++ b/dataset.py @@ -233,11 +233,11 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): def load_or_generate_domains(train_data, domain_length): fn = f"{train_data}_domains.gz" + char_dict = get_character_dict() try: user_flow_df = pd.read_csv(fn) except FileNotFoundError: - char_dict = get_character_dict() user_flow_df = get_user_flow_data(train_data) # user_flow_df.reset_index(inplace=True) user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, diff --git a/main.py b/main.py index b1232a6..f062c1b 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ import models import visualize from dataset import load_or_generate_h5data from utils import exists_or_make_path, get_custom_class_weights +from arguments import get_model_args logger = logging.getLogger('logger') logger.setLevel(logging.DEBUG) @@ -56,6 +57,7 @@ if args.gpu: # default parameter PARAMS = { "type": args.model_type, + "depth": args.model_depth, "batch_size": 64, "window_size": args.window, "domain_length": args.domain_length, @@ -149,8 +151,8 @@ def main_train(param=None): logger.info("class weights: set default") custom_class_weights = None - logger.info(f"select model: {'new' if args.new_model else 'old'}") - if args.new_model: + logger.info(f"select model: {args.model_type}") + if args.model_type == "inter": server_tr = np.expand_dims(server_windows_tr, 2) model = new_model logger.info("compile and train model") @@ -181,35 +183,44 @@ def main_train(param=None): def main_test(): + logger.info("start test: load data") domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, args.domain_length, args.window) - clf = load_model(args.clf_model, custom_objects=models.get_metrics()) - pred = clf.predict([domain_val, flow_val], - batch_size=args.batch_size, - verbose=1) - if args.model_output == "both": - c_pred, s_pred = pred - elif args.model_output == "client": - c_pred = pred - s_pred = np.zeros(0) - else: - c_pred = np.zeros(0) - s_pred = pred - dataset.save_predictions(args.future_prediction, c_pred, s_pred) - - model = load_model(args.embedding_model) domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) - domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) - np.save(args.model_path + "/domain_embds.npy", domain_embedding) + + for model_args in get_model_args(args): + logger.info(f"process model {model_args['model_path']}") + clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics()) + + pred = clf_model.predict([domain_val, flow_val], + batch_size=args.batch_size, + verbose=1) + + if args.model_output == "both": + c_pred, s_pred = pred + elif args.model_output == "client": + c_pred = pred + s_pred = np.zeros(0) + else: + c_pred = np.zeros(0) + s_pred = pred + dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred) + + embd_model = load_model(model_args["embedding_model"]) + domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) + np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings) def main_visualization(): domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, args.domain_length, args.window) - client_val, server_val = client_val.value, server_val.value - logger.info("plot model") - model = load_model(args.clf_model, custom_objects=models.get_metrics()) - visualize.plot_model(model, os.path.join(args.model_path, "model.png")) + # client_val, server_val = client_val.value, server_val.value + client_val = client_val.value + + # logger.info("plot model") + # model = load_model(model_args.clf_model, custom_objects=models.get_metrics()) + # visualize.plot_model(model, os.path.join(args.model_path, "model.png")) + try: logger.info("plot training curve") logs = pd.read_csv(args.train_log) @@ -224,12 +235,16 @@ def main_visualization(): client_pred, server_pred = dataset.load_predictions(args.future_prediction) client_pred, server_pred = client_pred.value, server_pred.value logger.info("plot pr curve") - visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path)) + visualize.plot_clf() + visualize.plot_precision_recall(client_val, client_pred) + visualize.plot_save("{}/client_prc.png".format(args.model_path)) # visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path)) # visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path)) # visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path)) logger.info("plot roc curve") - visualize.plot_roc_curve(client_val, client_pred.flatten(), "{}/client_roc.png".format(args.model_path)) + visualize.plot_clf() + visualize.plot_roc_curve(client_val, client_pred) + visualize.plot_save("{}/client_roc.png".format(args.model_path)) # visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path)) visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(), "{}/client_cov.png".format(args.model_path), @@ -243,6 +258,26 @@ def main_visualization(): visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) +def main_visualize_all(): + domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, + args.domain_length, args.window) + logger.info("plot pr curves") + visualize.plot_clf() + for model_args in get_model_args(args): + client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) + visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"]) + visualize.plot_legend() + visualize.plot_save("all_client_prc.png") + + logger.info("plot roc curves") + visualize.plot_clf() + for model_args in get_model_args(args): + client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"]) + visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"]) + visualize.plot_legend() + visualize.plot_save("all_client_roc.png") + + def main_data(): char_dict = dataset.get_character_dict() user_flow_df = dataset.get_user_flow_data(args.train_data) @@ -265,6 +300,8 @@ def main(): main_test() if "fancy" == args.mode: main_visualization() + if "all_fancy" == args.mode: + main_visualize_all() if "paul" == args.mode: main_paul_best() if "data" == args.mode: diff --git a/models/__init__.py b/models/__init__.py index d71715c..d55fd18 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -8,6 +8,7 @@ def get_models_by_params(params: dict): # decomposing param section # mainly embedding model network_type = params.get("type") + network_depth = params.get("depth") embedding_size = params.get("embedding_size") input_length = params.get("input_length") filter_embedding = params.get("filter_embedding") @@ -24,7 +25,12 @@ def get_models_by_params(params: dict): dense_dim = params.get("dense_main") model_output = params.get("model_output", "both") # create models - networks = renes_networks if network_type == "rene" else pauls_networks + if network_depth == "small": + networks = pauls_networks + elif network_depth == "medium": + networks = renes_networks + else: + raise Exception("network not found") embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding, hidden_embedding, dropout) diff --git a/run.sh b/run.sh index 7643b58..c42c0a3 100644 --- a/run.sh +++ b/run.sh @@ -1,98 +1,26 @@ #!/usr/bin/env bash -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/small_both \ - --epochs 25 \ - --embd 64 \ - --hidden_chaar_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output both -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/small_client \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output client +for output in client both +do + for depth in small medium + do + for mtype in inter final staggered + do -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/small_new_both \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output both \ - --new_model + python main.py --mode train \ + --train /tmp/rk/currentData.csv \ + --model /tmp/rk/results/${output}_${depth}_${mtype} \ + --epochs 50 \ + --embd 64 \ + --hidden_chaar_dims 128 \ + --domain_embd 32 \ + --batch 256 \ + --balanced_weights \ + --model_output ${output} \ + --type ${mtype} \ + --depth ${depth} -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/small_new_client \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output client \ - --new_model -## -## -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/medium_both \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output both \ - --type rene - -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/medium_client \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output client \ - --type rene - -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/medium_new_both \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output both \ - --new_model \ - --type rene - -python main.py --mode train \ - --train /tmp/rk/currentData.csv \ - --model /tmp/rk/results/medium_new_client \ - --epochs 25 \ - --embd 64 \ - --hidden_char_dims 128 \ - --domain_embd 32 \ - --batch 256 \ - --balanced_weights \ - --model_output client \ - --new_model \ - --type rene + done + done +done diff --git a/visualize.py b/visualize.py index 0c7b7cd..ba8223a 100644 --- a/visualize.py +++ b/visualize.py @@ -2,7 +2,6 @@ import os import matplotlib.pyplot as plt import numpy as np -from keras.utils.vis_utils import plot_model from sklearn.decomposition import TruncatedSVD from sklearn.metrics import ( auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve, @@ -32,38 +31,43 @@ def scores(y_true, y_pred): print(" f0.5 score:", f05_score) -def plot_precision_recall(mask, prediction, path): - y = mask.flatten() - y_pred = prediction.flatten() +def plot_clf(): + plt.clf() + + +def plot_save(path, dpi=600): + plt.savefig(path, dpi=dpi) + plt.close() + + +def plot_legend(): + plt.legend() + + +def plot_precision_recall(y, y_pred, label=""): + y = y.flatten() + y_pred = y_pred.flatten() precision, recall, thresholds = precision_recall_curve(y, y_pred) decreasing_max_precision = np.maximum.accumulate(precision)[::-1] - plt.clf() # fig, ax = plt.subplots(1, 1) # ax.hold(True) - plt.plot(recall, precision, '--b') + plt.plot(recall, precision, '--', label=label) # ax.step(recall[::-1], decreasing_max_precision, '-r') plt.xlabel('Recall') plt.ylabel('Precision') - plt.savefig(path, dpi=600) - plt.close() - -def plot_precision_recall_curves(mask, prediction, path): - y = mask.flatten() - y_pred = prediction.flatten() +def plot_precision_recall_curves(y, y_pred): + y = y.flatten() + y_pred = y_pred.flatten() precision, recall, thresholds = precision_recall_curve(y, y_pred) - plt.clf() plt.plot(recall, label="Recall") plt.plot(precision, label="Precision") plt.xlabel('Threshold') plt.ylabel('Score') - plt.savefig(path, dpi=600) - plt.close() - def score_model(y, prediction): y = y.flatten() @@ -78,16 +82,12 @@ def score_model(y, prediction): print("F0.5 Score", fbeta_score(y, y_pred.round(), 0.5)) -def plot_roc_curve(mask, prediction, path): +def plot_roc_curve(mask, prediction, label=""): y = mask.flatten() y_pred = prediction.flatten() fpr, tpr, thresholds = roc_curve(y, y_pred) roc_auc = auc(fpr, tpr) - plt.clf() - plt.plot(fpr, tpr) - plt.savefig(path, dpi=600) - plt.close() - + plt.plot(fpr, tpr, label=label) print("roc_auc", roc_auc) @@ -161,4 +161,5 @@ def plot_embedding(domain_embedding, labels, path, dpi=600): def plot_model_as(model, path): + from keras.utils.vis_utils import plot_model plot_model(model, to_file=path, show_shapes=True, show_layer_names=True)