refactor all visualization for pauls changes - evaluate on max windows per users
This commit is contained in:
parent
9a51b6ea34
commit
6fef2b8b84
3
Makefile
3
Makefile
@ -59,7 +59,8 @@ fancy:
|
|||||||
python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz
|
python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
all-fancy:
|
all-fancy:
|
||||||
python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz
|
python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz \
|
||||||
|
--out-prefix results/test/
|
||||||
|
|
||||||
hyper:
|
hyper:
|
||||||
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
|
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
|
||||||
|
@ -235,7 +235,7 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
check_h5dataset(h5data)
|
check_h5dataset(h5data)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.info("load raw training dataset")
|
logger.info("load raw training dataset")
|
||||||
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data,
|
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data, train_data,
|
||||||
domain_length, window_size)
|
domain_length, window_size)
|
||||||
logger.info("filter training dataset")
|
logger.info("filter training dataset")
|
||||||
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
||||||
@ -256,6 +256,7 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
|
|
||||||
|
|
||||||
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
||||||
|
h5data = h5data + "_raw"
|
||||||
char_dict = get_character_dict()
|
char_dict = get_character_dict()
|
||||||
logger.info(f"check for h5data {h5data}")
|
logger.info(f"check for h5data {h5data}")
|
||||||
try:
|
try:
|
||||||
|
88
main.py
88
main.py
@ -276,11 +276,18 @@ def main_test():
|
|||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||||
args.test_data,
|
args.test_data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
client_val = client_val.value
|
|
||||||
|
results = dataset.load_predictions(args.model_path)
|
||||||
|
df = pd.DataFrame(data={
|
||||||
|
"names": name_val, "client_pred": results["client_pred"].flatten(),
|
||||||
|
"hits_vt": hits_vt, "hits_trusted": hits_trusted
|
||||||
|
})
|
||||||
|
df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3)
|
||||||
|
df_user = df.groupby(df.names).max()
|
||||||
|
|
||||||
logger.info("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
@ -296,95 +303,86 @@ def main_visualization():
|
|||||||
else:
|
else:
|
||||||
logger.warning("Error while plotting training curves")
|
logger.warning("Error while plotting training curves")
|
||||||
|
|
||||||
results = dataset.load_predictions(args.future_prediction)
|
|
||||||
client_pred = results["client_pred"].flatten()
|
|
||||||
|
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(client_val, client_pred, args.model_path)
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||||
|
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(client_val, client_pred, args.model_path)
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||||
|
|
||||||
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
|
||||||
|
|
||||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
|
||||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
|
||||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
|
||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
|
||||||
|
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(user_vals, user_preds, args.model_path)
|
visualize.plot_precision_recall(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
||||||
|
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(user_vals, user_preds, args.model_path)
|
visualize.plot_roc_curve(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
||||||
|
|
||||||
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
|
||||||
"{}/client_cov.png".format(args.model_path),
|
"{}/client_cov.png".format(args.model_path),
|
||||||
normalize=False, title="Client Confusion Matrix")
|
normalize=False, title="Client Confusion Matrix")
|
||||||
visualize.plot_confusion_matrix(user_vals, user_preds.flatten().round(),
|
visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
|
||||||
"{}/user_cov.png".format(args.model_path),
|
"{}/user_cov.png".format(args.model_path),
|
||||||
normalize=False, title="User Confusion Matrix")
|
normalize=False, title="User Confusion Matrix")
|
||||||
logger.info("visualize embedding")
|
logger.info("visualize embedding")
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||||
domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
domain_embedding = results["domain_embds"]
|
||||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||||
args.test_data,
|
args.test_data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
|
|
||||||
|
def load_df(path):
|
||||||
|
res = dataset.load_predictions(path)
|
||||||
|
res = pd.DataFrame(data={
|
||||||
|
"names": name_val, "client_pred": res["client_pred"].flatten(),
|
||||||
|
"hits_vt": hits_vt, "hits_trusted": hits_trusted
|
||||||
|
})
|
||||||
|
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||||
|
return res
|
||||||
|
|
||||||
logger.info("plot pr curves")
|
logger.info("plot pr curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = dataset.load_predictions(model_args["future_prediction"])
|
df = load_df(model_args["model_path"])
|
||||||
client_pred = results["client_pred"].flatten()
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
||||||
visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"])
|
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = dataset.load_predictions(model_args["future_prediction"])
|
df = load_df(model_args["model_path"])
|
||||||
client_pred = results["client_pred"].flatten()
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
||||||
visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"])
|
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||||
|
|
||||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
|
||||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
|
||||||
|
|
||||||
logger.info("plot user pr curves")
|
logger.info("plot user pr curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = dataset.load_predictions(model_args["future_prediction"])
|
df = load_df(model_args["model_path"])
|
||||||
client_pred = results["client_pred"].flatten()
|
df = df.groupby(df.names).max()
|
||||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
|
||||||
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
||||||
|
|
||||||
logger.info("plot user roc curves")
|
logger.info("plot user roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = dataset.load_predictions(model_args["future_prediction"])
|
df = load_df(model_args["model_path"])
|
||||||
client_pred = results["client_pred"].flatten()
|
df = df.groupby(df.names).max()
|
||||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
|
||||||
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user