From 6fef2b8b84159c1173a3368f01d95f4e4ee92a5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Fri, 8 Sep 2017 22:59:55 +0200 Subject: [PATCH] refactor all visualization for pauls changes - evaluate on max windows per users --- Makefile | 3 +- dataset.py | 3 +- main.py | 88 ++++++++++++++++++++++++++---------------------------- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/Makefile b/Makefile index 0b283d3..b7c51e4 100644 --- a/Makefile +++ b/Makefile @@ -59,7 +59,8 @@ fancy: python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz all-fancy: - python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz + python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz \ + --out-prefix results/test/ hyper: python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz diff --git a/dataset.py b/dataset.py index ba6a4aa..d8965a0 100644 --- a/dataset.py +++ b/dataset.py @@ -235,7 +235,7 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): check_h5dataset(h5data) except FileNotFoundError: logger.info("load raw training dataset") - domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data, + domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size) logger.info("filter training dataset") domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value, @@ -256,6 +256,7 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size): + h5data = h5data + "_raw" char_dict = get_character_dict() logger.info(f"check for h5data {h5data}") try: diff --git a/main.py b/main.py index b5b8d41..8c80708 100644 --- a/main.py +++ b/main.py @@ -276,11 +276,18 @@ def main_test(): def main_visualization(): - domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, - args.test_data, - args.domain_length, - args.window) - client_val = client_val.value + _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, + args.test_data, + args.domain_length, + args.window) + + results = dataset.load_predictions(args.model_path) + df = pd.DataFrame(data={ + "names": name_val, "client_pred": results["client_pred"].flatten(), + "hits_vt": hits_vt, "hits_trusted": hits_trusted + }) + df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3) + df_user = df.groupby(df.names).max() logger.info("plot model") model = load_model(args.clf_model, custom_objects=models.get_metrics()) @@ -296,95 +303,86 @@ def main_visualization(): else: logger.warning("Error while plotting training curves") - results = dataset.load_predictions(args.future_prediction) - client_pred = results["client_pred"].flatten() - logger.info("plot pr curve") visualize.plot_clf() - visualize.plot_precision_recall(client_val, client_pred, args.model_path) + visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path) visualize.plot_legend() visualize.plot_save("{}/window_client_prc.png".format(args.model_path)) logger.info("plot roc curve") visualize.plot_clf() - visualize.plot_roc_curve(client_val, client_pred, args.model_path) + visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path) visualize.plot_legend() visualize.plot_save("{}/window_client_roc.png".format(args.model_path)) - print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}") - - df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) - user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) - df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) - user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) - visualize.plot_clf() - visualize.plot_precision_recall(user_vals, user_preds, args.model_path) + visualize.plot_precision_recall(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path) visualize.plot_legend() visualize.plot_save("{}/user_client_prc.png".format(args.model_path)) visualize.plot_clf() - visualize.plot_roc_curve(user_vals, user_preds, args.model_path) + visualize.plot_roc_curve(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path) visualize.plot_legend() visualize.plot_save("{}/user_client_roc.png".format(args.model_path)) - visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(), + visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(), "{}/client_cov.png".format(args.model_path), normalize=False, title="Client Confusion Matrix") - visualize.plot_confusion_matrix(user_vals, user_preds.flatten().round(), + visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(), "{}/user_cov.png".format(args.model_path), normalize=False, title="User Confusion Matrix") logger.info("visualize embedding") domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) - domain_embedding = np.load(args.model_path + "/domain_embds.npy") + domain_embedding = results["domain_embds"] visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) def main_visualize_all(): - domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, - args.test_data, - args.domain_length, - args.window) + _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, + args.test_data, + args.domain_length, + args.window) + + def load_df(path): + res = dataset.load_predictions(path) + res = pd.DataFrame(data={ + "names": name_val, "client_pred": res["client_pred"].flatten(), + "hits_vt": hits_vt, "hits_trusted": hits_trusted + }) + res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) + return res + logger.info("plot pr curves") visualize.plot_clf() for model_args in get_model_args(args): - results = dataset.load_predictions(model_args["future_prediction"]) - client_pred = results["client_pred"].flatten() - visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"]) + df = load_df(model_args["model_path"]) + visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_window_client_prc.png") logger.info("plot roc curves") visualize.plot_clf() for model_args in get_model_args(args): - results = dataset.load_predictions(model_args["future_prediction"]) - client_pred = results["client_pred"].flatten() - visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"]) + df = load_df(model_args["model_path"]) + visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_window_client_roc.png") - df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val}) - user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float) - logger.info("plot user pr curves") visualize.plot_clf() for model_args in get_model_args(args): - results = dataset.load_predictions(model_args["future_prediction"]) - client_pred = results["client_pred"].flatten() - df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) - user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) - visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"]) + df = load_df(model_args["model_path"]) + df = df.groupby(df.names).max() + visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_user_client_prc.png") logger.info("plot user roc curves") visualize.plot_clf() for model_args in get_model_args(args): - results = dataset.load_predictions(model_args["future_prediction"]) - client_pred = results["client_pred"].flatten() - df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) - user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float) - visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"]) + df = load_df(model_args["model_path"]) + df = df.groupby(df.names).max() + visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"]) visualize.plot_legend() visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")