add first version of model averaging visualization

This commit is contained in:
René Knaebel 2017-09-26 19:25:37 +02:00
parent 49ad506a96
commit b157ca6a19
6 changed files with 199 additions and 9 deletions

View File

@ -113,6 +113,7 @@ def get_model_args(args):
def parse():
args = parser.parse_args()
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]
args.model_name = os.path.split(os.path.normpath(args.model_path))[1]
args.embedding_model = os.path.join(args.model_path, "embd.h5")
args.clf_model = os.path.join(args.model_path, "clf.h5")

107
main.py
View File

@ -408,6 +408,111 @@ def main_visualize_all():
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
import joblib
def main_beta():
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
args.test_data,
args.domain_length,
args.window)
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
try:
results = joblib.load(f"{path}/curves.joblib")
except Exception:
results = {}
results[model_prefix] = {}
def load_df(path):
res = dataset.load_predictions(path)
res = pd.DataFrame(data={
"names": name_val, "client_pred": res["client_pred"].flatten(),
"hits_vt": hits_vt, "hits_trusted": hits_trusted
})
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
return res
paul = dataset.load_predictions("results/paul/")
df_paul = pd.DataFrame(data={
"names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
})
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
df_paul_user = df_paul.groupby(df_paul.names).max()
logger.info("plot pr curves")
visualize.plot_clf()
predictions = []
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
predictions.append(df.client_pred.as_matrix())
results[model_prefix]["window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_prc_all.png")
logger.info("plot roc curves")
visualize.plot_clf()
predictions = []
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
predictions.append(df.client_pred.as_matrix())
results[model_prefix]["window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_roc_all.png")
logger.info("plot user pr curves")
visualize.plot_clf()
predictions = []
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
df = df.groupby(df.names).max()
predictions.append(df.client_pred.as_matrix())
results[model_prefix]["user_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_prc_all.png")
logger.info("plot user roc curves")
visualize.plot_clf()
predictions = []
for model_args in get_model_args(args):
df = load_df(model_args["model_path"])
df = df.groupby(df.names).max()
predictions.append(df.client_pred.as_matrix())
results[model_prefix]["user_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
joblib.dump(results, f"{path}/curves.joblib")
import matplotlib.pyplot as plt
x = np.linspace(0, 1, 10000)
for vis in ["window_prc", "window_roc", "user_prc", "user_roc"]:
logger.info(f"plot {vis}")
visualize.plot_clf()
for model_key in results.keys():
ys_mean, ys_std, score = results[model_key][vis]
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
if vis.endswith("prc"):
plt.xlabel('Recall')
plt.ylabel('Precision')
else:
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
visualize.plot_legend()
visualize.plot_save(f"{path}/{vis}_all.png")
def main():
if "train" == args.mode:
main_train()
@ -423,6 +528,8 @@ def main():
plot_embedding()
if "paul" == args.mode:
main_paul_best()
if "beta" == args.mode:
main_beta()
if __name__ == "__main__":

2
run.sh
View File

@ -5,7 +5,7 @@ RESDIR=$1
mkdir -p /tmp/rk/${RESDIR}
DATADIR=$2
EPOCHS=100
EPOCHS=10
for output in client both
do

View File

@ -10,7 +10,7 @@ RESDIR=$6
mkdir -p /tmp/rk/${RESDIR}
DATADIR=$7
EPOCHS=100
EPOCHS=10
for ((i = ${N1}; i <= ${N2}; i++))
do
@ -25,5 +25,6 @@ do
--batch 128 \
--model_output ${OUTPUT} \
--type ${TYPE} \
--depth ${DEPTH}
--depth ${DEPTH} \
--gpu
done

30
run_model_rene.sh Normal file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env bash
N1=$1
N2=$2
OUTPUT=$3
DEPTH=$4
TYPE=$5
RESDIR=$6
mkdir -p /tmp/rk/${RESDIR}
DATADIR=$7
EPOCHS=10
for ((i = ${N1}; i <= ${N2}; i++))
do
python main.py --mode train \
--train ${DATADIR} \
--model ${RESDIR}/${OUTPUT}_${TYPE}_${i} \
--epochs ${EPOCHS} \
--embd 64 \
--filter_embd 128 --kernel_embd 5 --dense_embd 64 \
--domain_embd 16 \
--filter_main 32 --kernel_main 5 --dense_main 256 \
--batch 128 \
--model_output ${OUTPUT} \
--type ${TYPE} \
--depth ${DEPTH} \
--gpu
done

View File

@ -2,6 +2,7 @@ import os
import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.metrics import (
@ -65,13 +66,32 @@ def plot_precision_recall(y, y_pred, label=""):
plt.xlim([0.0, 1.0])
def plot_pr_curves(y, y_preds, label=""):
for idx, y in enumerate(y_preds):
def calc_pr_mean(y, y_preds):
appr = []
scores = []
y = y.flatten()
for idx, y_pred in enumerate(y_preds):
y_pred = y_pred.flatten()
precision, recall, thresholds = precision_recall_curve(y, y_pred)
score = fbeta_score(y, y_pred.round(), 1)
plt.plot(recall, precision, '--', label=f"{idx}{label} - {score:5.4}")
appr.append(interpolate.interp1d(recall, precision))
scores.append(fbeta_score(y, y_pred.round(), 1))
x = np.linspace(0, 1, 10000)
ys = np.vstack([f(x) for f in appr])
ys_mean = ys.mean(axis=0)
ys_std = ys.std(axis=0)
scores_mean = np.mean(scores)
return ys_mean, ys_std, scores_mean
def plot_pr_mean(y, y_preds, label=""):
x = np.linspace(0, 1, 10000)
ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
plt.plot(x, ys_mean, label=f"{label} - {score:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
plt.xlabel('Recall')
plt.ylabel('Precision')
@ -102,6 +122,37 @@ def plot_roc_curve(mask, prediction, label=""):
plt.ylabel('True Positive Rate')
def calc_roc_mean(y, y_preds):
appr = []
aucs = []
y = y.flatten()
for idx, y_pred in enumerate(y_preds):
y_pred = y_pred.flatten()
fpr, tpr, thresholds = roc_curve(y, y_pred)
appr.append(interpolate.interp1d(fpr, tpr))
aucs.append(auc(fpr, tpr))
x = np.linspace(0, 1, 10000)
ys = np.vstack([f(x) for f in appr])
ys_mean = ys.mean(axis=0)
ys_std = ys.std(axis=0)
auc_mean = np.mean(aucs)
return ys_mean, ys_std, auc_mean
def plot_roc_mean(y, y_preds, label=""):
x = np.linspace(0, 1, 10000)
ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds)
plt.xscale('log')
plt.ylim([0.0, 1.0])
plt.xlim([0.0, 1.0])
plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}")
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
def plot_confusion_matrix(y_true, y_pred, path,
normalize=False,
classes=("benign", "malicious"),