add window to file names for visualization
This commit is contained in:
parent
595c2ea894
commit
1ab0108c78
8
main.py
8
main.py
@ -303,14 +303,14 @@ def main_visualization():
|
|||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(client_val, client_pred)
|
visualize.plot_precision_recall(client_val, client_pred)
|
||||||
visualize.plot_save("{}/client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(client_val, client_pred)
|
visualize.plot_roc_curve(client_val, client_pred)
|
||||||
visualize.plot_save("{}/client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||||
|
|
||||||
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
||||||
@ -351,7 +351,7 @@ def main_visualize_all():
|
|||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
@ -359,7 +359,7 @@ def main_visualize_all():
|
|||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||||
|
|
||||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user