refactor all visualization for pauls changes - evaluate on max windows per users

This commit is contained in:
René Knaebel 2017-09-08 22:59:55 +02:00
parent 9a51b6ea34
commit 6fef2b8b84
3 changed files with 47 additions and 47 deletions

View File

@ -59,7 +59,8 @@ fancy:
python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz
all-fancy: all-fancy:
python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz \
--out-prefix results/test/
hyper: hyper:
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz

View File

@ -235,7 +235,7 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
check_h5dataset(h5data) check_h5dataset(h5data)
except FileNotFoundError: except FileNotFoundError:
logger.info("load raw training dataset") logger.info("load raw training dataset")
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data, domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data,
domain_length, window_size) domain_length, window_size)
logger.info("filter training dataset") logger.info("filter training dataset")
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value, domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
@ -256,6 +256,7 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size): def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
h5data = h5data + "_raw"
char_dict = get_character_dict() char_dict = get_character_dict()
logger.info(f"check for h5data {h5data}") logger.info(f"check for h5data {h5data}")
try: try:

76
main.py
View File

@ -276,11 +276,18 @@ def main_test():
def main_visualization(): def main_visualization():
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
args.test_data, args.test_data,
args.domain_length, args.domain_length,
args.window) args.window)
client_val = client_val.value
results = dataset.load_predictions(args.model_path)
df = pd.DataFrame(data={
"names": name_val, "client_pred": results["client_pred"].flatten(),
"hits_vt": hits_vt, "hits_trusted": hits_trusted
})
df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3)
df_user = df.groupby(df.names).max()
logger.info("plot model") logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics()) model = load_model(args.clf_model, custom_objects=models.get_metrics())
@ -296,95 +303,86 @@ def main_visualization():
else: else:
logger.warning("Error while plotting training curves") logger.warning("Error while plotting training curves")
results = dataset.load_predictions(args.future_prediction)
client_pred = results["client_pred"].flatten()
logger.info("plot pr curve") logger.info("plot pr curve")
visualize.plot_clf() visualize.plot_clf()
visualize.plot_precision_recall(client_val, client_pred, args.model_path) visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save("{}/window_client_prc.png".format(args.model_path)) visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
logger.info("plot roc curve") logger.info("plot roc curve")
visualize.plot_clf() visualize.plot_clf()
visualize.plot_roc_curve(client_val, client_pred, args.model_path) visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save("{}/window_client_roc.png".format(args.model_path)) visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_clf() visualize.plot_clf()
visualize.plot_precision_recall(user_vals, user_preds, args.model_path) visualize.plot_precision_recall(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save("{}/user_client_prc.png".format(args.model_path)) visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
visualize.plot_clf() visualize.plot_clf()
visualize.plot_roc_curve(user_vals, user_preds, args.model_path) visualize.plot_roc_curve(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save("{}/user_client_roc.png".format(args.model_path)) visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(), visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
"{}/client_cov.png".format(args.model_path), "{}/client_cov.png".format(args.model_path),
normalize=False, title="Client Confusion Matrix") normalize=False, title="Client Confusion Matrix")
visualize.plot_confusion_matrix(user_vals, user_preds.flatten().round(), visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
"{}/user_cov.png".format(args.model_path), "{}/user_cov.png".format(args.model_path),
normalize=False, title="User Confusion Matrix") normalize=False, title="User Confusion Matrix")
logger.info("visualize embedding") logger.info("visualize embedding")
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = np.load(args.model_path + "/domain_embds.npy") domain_embedding = results["domain_embds"]
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path)) visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
def main_visualize_all(): def main_visualize_all():
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data, _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
args.test_data, args.test_data,
args.domain_length, args.domain_length,
args.window) args.window)
def load_df(path):
res = dataset.load_predictions(path)
res = pd.DataFrame(data={
"names": name_val, "client_pred": res["client_pred"].flatten(),
"hits_vt": hits_vt, "hits_trusted": hits_trusted
})
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
return res
logger.info("plot pr curves") logger.info("plot pr curves")
visualize.plot_clf() visualize.plot_clf()
for model_args in get_model_args(args): for model_args in get_model_args(args):
results = dataset.load_predictions(model_args["future_prediction"]) df = load_df(model_args["model_path"])
client_pred = results["client_pred"].flatten() visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"])
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png") visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
logger.info("plot roc curves") logger.info("plot roc curves")
visualize.plot_clf() visualize.plot_clf()
for model_args in get_model_args(args): for model_args in get_model_args(args):
results = dataset.load_predictions(model_args["future_prediction"]) df = load_df(model_args["model_path"])
client_pred = results["client_pred"].flatten() visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"])
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png") visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
logger.info("plot user pr curves") logger.info("plot user pr curves")
visualize.plot_clf() visualize.plot_clf()
for model_args in get_model_args(args): for model_args in get_model_args(args):
results = dataset.load_predictions(model_args["future_prediction"]) df = load_df(model_args["model_path"])
client_pred = results["client_pred"].flatten() df = df.groupby(df.names).max()
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png") visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
logger.info("plot user roc curves") logger.info("plot user roc curves")
visualize.plot_clf() visualize.plot_clf()
for model_args in get_model_args(args): for model_args in get_model_args(args):
results = dataset.load_predictions(model_args["future_prediction"]) df = load_df(model_args["model_path"])
client_pred = results["client_pred"].flatten() df = df.groupby(df.names).max()
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred}) visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend() visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png") visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")