update plotting function according the test and beta results
This commit is contained in:
parent
c19d649bc4
commit
3ce385eca6
13
main.py
13
main.py
@ -629,7 +629,6 @@ def main_beta():
|
||||
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
||||
}
|
||||
if "server_pred" in res:
|
||||
print(res["server_pred"].shape, server_val.value.shape)
|
||||
server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1)
|
||||
val = server_val.value.max(axis=1)
|
||||
data["server_pred"] = server.flatten()
|
||||
@ -679,8 +678,8 @@ def main_beta():
|
||||
client_user_preds.append(df_user.client_pred.as_matrix())
|
||||
if "server_val" in df.columns:
|
||||
server_user_preds.append(df_user.server_pred.as_matrix())
|
||||
|
||||
logger.info("plot client curves")
|
||||
|
||||
logger.info("compute client curves")
|
||||
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
||||
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
||||
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
||||
@ -689,7 +688,7 @@ def main_beta():
|
||||
client_user_preds)
|
||||
|
||||
if "server_val" in df.columns:
|
||||
logger.info("plot server curves")
|
||||
logger.info("compute server curves")
|
||||
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
||||
server_preds)
|
||||
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
||||
@ -701,7 +700,7 @@ def main_beta():
|
||||
server_user_preds)
|
||||
|
||||
if df_server is not None:
|
||||
logger.info("plot server flow curves")
|
||||
logger.info("compute server flow curves")
|
||||
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
||||
server_flow_preds)
|
||||
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
||||
@ -727,9 +726,11 @@ import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_overall_result():
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
||||
exists_or_make_path(f"{path}/figs/curves/")
|
||||
try:
|
||||
results = joblib.load(f"{path}/curves.joblib")
|
||||
logger.info("curves successfully loaded")
|
||||
except Exception:
|
||||
results = {}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user