add window to file names for visualization
This commit is contained in:
parent
595c2ea894
commit
1ab0108c78
8
main.py
8
main.py
@ -303,14 +303,14 @@ def main_visualization():
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(client_val, client_pred)
|
||||
visualize.plot_save("{}/client_prc.png".format(args.model_path))
|
||||
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||
logger.info("plot roc curve")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_roc_curve(client_val, client_pred)
|
||||
visualize.plot_save("{}/client_roc.png".format(args.model_path))
|
||||
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||
|
||||
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
||||
@ -351,7 +351,7 @@ def main_visualize_all():
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_client_prc.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||
|
||||
logger.info("plot roc curves")
|
||||
visualize.plot_clf()
|
||||
@ -359,7 +359,7 @@ def main_visualize_all():
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_client_roc.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||
|
||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||
|
Loading…
x
Reference in New Issue
Block a user