update plotting function according the test and beta results
This commit is contained in:
parent
c19d649bc4
commit
3ce385eca6
11
main.py
11
main.py
@ -629,7 +629,6 @@ def main_beta():
|
|||||||
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
||||||
}
|
}
|
||||||
if "server_pred" in res:
|
if "server_pred" in res:
|
||||||
print(res["server_pred"].shape, server_val.value.shape)
|
|
||||||
server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1)
|
server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1)
|
||||||
val = server_val.value.max(axis=1)
|
val = server_val.value.max(axis=1)
|
||||||
data["server_pred"] = server.flatten()
|
data["server_pred"] = server.flatten()
|
||||||
@ -680,7 +679,7 @@ def main_beta():
|
|||||||
if "server_val" in df.columns:
|
if "server_val" in df.columns:
|
||||||
server_user_preds.append(df_user.server_pred.as_matrix())
|
server_user_preds.append(df_user.server_pred.as_matrix())
|
||||||
|
|
||||||
logger.info("plot client curves")
|
logger.info("compute client curves")
|
||||||
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
||||||
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
||||||
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
||||||
@ -689,7 +688,7 @@ def main_beta():
|
|||||||
client_user_preds)
|
client_user_preds)
|
||||||
|
|
||||||
if "server_val" in df.columns:
|
if "server_val" in df.columns:
|
||||||
logger.info("plot server curves")
|
logger.info("compute server curves")
|
||||||
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
||||||
server_preds)
|
server_preds)
|
||||||
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
||||||
@ -701,7 +700,7 @@ def main_beta():
|
|||||||
server_user_preds)
|
server_user_preds)
|
||||||
|
|
||||||
if df_server is not None:
|
if df_server is not None:
|
||||||
logger.info("plot server flow curves")
|
logger.info("compute server flow curves")
|
||||||
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
||||||
server_flow_preds)
|
server_flow_preds)
|
||||||
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
||||||
@ -727,9 +726,11 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
|
|
||||||
def plot_overall_result():
|
def plot_overall_result():
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
||||||
|
exists_or_make_path(f"{path}/figs/curves/")
|
||||||
try:
|
try:
|
||||||
results = joblib.load(f"{path}/curves.joblib")
|
results = joblib.load(f"{path}/curves.joblib")
|
||||||
|
logger.info("curves successfully loaded")
|
||||||
except Exception:
|
except Exception:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user