add matplotlib agg mode; update beta vis function according to test results
This commit is contained in:
parent
4fc2f0c925
commit
27f4d086eb
74
main.py
74
main.py
@ -610,19 +610,21 @@ def main_beta():
|
||||
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
||||
print(path, model_prefix)
|
||||
try:
|
||||
results = joblib.load(f"{path}/curves.joblib")
|
||||
curves = joblib.load(f"{path}/curves.joblib")
|
||||
logger.info(f"load file {path}/curves.joblib successfully")
|
||||
except Exception:
|
||||
results = {}
|
||||
results[model_prefix] = {"all": {}}
|
||||
curves = {}
|
||||
logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
|
||||
curves[model_prefix] = {"all": {}}
|
||||
|
||||
domains = domain_val.value.reshape(-1, 40)
|
||||
domains = np.apply_along_axis(lambda d: "".join(map(dataset.decode_char, d)), 1, domains)
|
||||
domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
|
||||
|
||||
def load_df(path):
|
||||
def load_df(res):
|
||||
df_server = None
|
||||
res = dataset.load_predictions(path)
|
||||
data = {
|
||||
"names": name_val, "client_pred": res["client_pred"].flatten(),
|
||||
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
|
||||
@ -646,6 +648,9 @@ def main_beta():
|
||||
|
||||
return res, df_server
|
||||
|
||||
res = dataset.load_predictions(path)
|
||||
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
|
||||
|
||||
client_preds = []
|
||||
server_preds = []
|
||||
server_flow_preds = []
|
||||
@ -653,8 +658,8 @@ def main_beta():
|
||||
server_user_preds = []
|
||||
server_domain_preds = []
|
||||
server_domain_avg_preds = []
|
||||
for model_args in get_model_args(args):
|
||||
df, df_server = load_df(model_args["model_path"])
|
||||
for model_name in model_keys:
|
||||
df, df_server = load_df(res[model_name])
|
||||
client_preds.append(df.client_pred.as_matrix())
|
||||
if "server_val" in df.columns:
|
||||
server_preds.append(df.server_pred.as_matrix())
|
||||
@ -665,55 +670,56 @@ def main_beta():
|
||||
df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
|
||||
server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix())
|
||||
|
||||
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
|
||||
df.client_pred.as_matrix().round())
|
||||
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
|
||||
df.client_pred.as_matrix().round())
|
||||
df_user = df.groupby(df.names).max()
|
||||
client_user_preds.append(df_user.client_pred.as_matrix())
|
||||
if "server_val" in df.columns:
|
||||
server_user_preds.append(df_user.server_pred.as_matrix())
|
||||
|
||||
logger.info("plot client curves")
|
||||
results[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
||||
results[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
||||
results[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
||||
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
|
||||
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
|
||||
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
|
||||
client_user_preds)
|
||||
curves[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
|
||||
client_user_preds)
|
||||
results[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
|
||||
client_user_preds)
|
||||
|
||||
if "server_val" in df.columns:
|
||||
logger.info("plot server curves")
|
||||
results[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
||||
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
|
||||
server_preds)
|
||||
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
||||
server_preds)
|
||||
results[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
|
||||
server_preds)
|
||||
results[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
|
||||
server_user_preds)
|
||||
curves[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
|
||||
server_user_preds)
|
||||
|
||||
results[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
|
||||
server_user_preds)
|
||||
curves[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
|
||||
server_user_preds)
|
||||
|
||||
if df_server is not None:
|
||||
logger.info("plot server flow curves")
|
||||
results[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
||||
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
|
||||
server_flow_preds)
|
||||
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
||||
server_flow_preds)
|
||||
results[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
|
||||
server_flow_preds)
|
||||
results[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
|
||||
curves[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
|
||||
server_domain_preds)
|
||||
curves[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
|
||||
server_domain_preds)
|
||||
results[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
|
||||
server_domain_preds)
|
||||
results[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
|
||||
curves[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
|
||||
df_domain_avg.server_val.as_matrix(),
|
||||
server_domain_avg_preds)
|
||||
results[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
|
||||
curves[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
|
||||
df_domain_avg.server_val.as_matrix(),
|
||||
server_domain_avg_preds)
|
||||
|
||||
joblib.dump(results, f"{path}/curves.joblib")
|
||||
|
||||
# plot_overall_result()
|
||||
joblib.dump(curves, f"{path}/curves.joblib")
|
||||
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
@ -1,6 +1,10 @@
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
|
Loading…
Reference in New Issue
Block a user