add matplotlib agg mode; update beta vis function according to test results

This commit is contained in:
René Knaebel 2017-11-10 11:04:27 +01:00
parent 4fc2f0c925
commit 27f4d086eb
2 changed files with 46 additions and 36 deletions

74
main.py
View File

@ -610,19 +610,21 @@ def main_beta():
domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
args.domain_length,
args.window)
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
print(path, model_prefix)
try:
results = joblib.load(f"{path}/curves.joblib")
curves = joblib.load(f"{path}/curves.joblib")
logger.info(f"load file {path}/curves.joblib successfully")
except Exception:
results = {}
results[model_prefix] = {"all": {}}
curves = {}
logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
curves[model_prefix] = {"all": {}}
domains = domain_val.value.reshape(-1, 40)
domains = np.apply_along_axis(lambda d: "".join(map(dataset.decode_char, d)), 1, domains)
domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
def load_df(path):
def load_df(res):
df_server = None
res = dataset.load_predictions(path)
data = {
"names": name_val, "client_pred": res["client_pred"].flatten(),
"hits_vt": hits_vt, "hits_trusted": hits_trusted,
@ -646,6 +648,9 @@ def main_beta():
return res, df_server
res = dataset.load_predictions(path)
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
client_preds = []
server_preds = []
server_flow_preds = []
@ -653,8 +658,8 @@ def main_beta():
server_user_preds = []
server_domain_preds = []
server_domain_avg_preds = []
for model_args in get_model_args(args):
df, df_server = load_df(model_args["model_path"])
for model_name in model_keys:
df, df_server = load_df(res[model_name])
client_preds.append(df.client_pred.as_matrix())
if "server_val" in df.columns:
server_preds.append(df.server_pred.as_matrix())
@ -665,55 +670,56 @@ def main_beta():
df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
server_domain_avg_preds.append(df_domain_avg.server_pred.as_matrix())
results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
df.client_pred.as_matrix().round())
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
df.client_pred.as_matrix().round())
df_user = df.groupby(df.names).max()
client_user_preds.append(df_user.client_pred.as_matrix())
if "server_val" in df.columns:
server_user_preds.append(df_user.server_pred.as_matrix())
logger.info("plot client curves")
results[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
results[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
results[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
curves[model_prefix]["all"]["client_window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), client_preds)
curves[model_prefix]["all"]["client_window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), client_preds)
curves[model_prefix]["all"]["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.as_matrix(),
client_user_preds)
curves[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
client_user_preds)
results[model_prefix]["all"]["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.as_matrix(),
client_user_preds)
if "server_val" in df.columns:
logger.info("plot server curves")
results[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
curves[model_prefix]["all"]["server_window_prc"] = visualize.calc_pr_mean(df.server_val.as_matrix(),
server_preds)
curves[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
server_preds)
results[model_prefix]["all"]["server_window_roc"] = visualize.calc_roc_mean(df.server_val.as_matrix(),
server_preds)
results[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
server_user_preds)
curves[model_prefix]["all"]["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.as_matrix(),
server_user_preds)
results[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
server_user_preds)
curves[model_prefix]["all"]["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.as_matrix(),
server_user_preds)
if df_server is not None:
logger.info("plot server flow curves")
results[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
curves[model_prefix]["all"]["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.as_matrix(),
server_flow_preds)
curves[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
server_flow_preds)
results[model_prefix]["all"]["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.as_matrix(),
server_flow_preds)
results[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
curves[model_prefix]["all"]["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.as_matrix(),
server_domain_preds)
curves[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
server_domain_preds)
results[model_prefix]["all"]["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.as_matrix(),
server_domain_preds)
results[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
curves[model_prefix]["all"]["server_domain_avg_prc"] = visualize.calc_pr_mean(
df_domain_avg.server_val.as_matrix(),
server_domain_avg_preds)
results[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
curves[model_prefix]["all"]["server_domain_avg_roc"] = visualize.calc_roc_mean(
df_domain_avg.server_val.as_matrix(),
server_domain_avg_preds)
joblib.dump(results, f"{path}/curves.joblib")
# plot_overall_result()
joblib.dump(curves, f"{path}/curves.joblib")
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

View File

@ -1,6 +1,10 @@
import os
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns