refactor test function working on full unfiltered data

This commit is contained in:
René Knaebel 2017-09-08 19:10:23 +02:00
parent edc75f4f44
commit 9a51b6ea34
4 changed files with 114 additions and 89 deletions

View File

@ -66,4 +66,5 @@ hyper:
clean:
rm -r results/test/test*
rm data/rk_mini.csv.gz_raw.h5
rm data/rk_mini.csv.gz.h5

View File

@ -105,9 +105,9 @@ def get_model_args(args):
"embedding_model": os.path.join(model_path, "embd.h5"),
"clf_model": os.path.join(model_path, "clf.h5"),
"train_log": os.path.join(model_path, "train.log.csv"),
"train_h5data": args.train_data + ".h5",
"test_h5data": args.test_data + ".h5",
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5")
"train_h5data": args.train_data,
"test_h5data": args.test_data,
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
} for model_path in args.model_paths]
def parse():
@ -115,7 +115,7 @@ def parse():
args.embedding_model = os.path.join(args.model_path, "embd.h5")
args.clf_model = os.path.join(args.model_path, "clf.h5")
args.train_log = os.path.join(args.model_path, "train.log.csv")
args.train_h5data = args.train_data + ".h5"
args.test_h5data = args.test_data + ".h5"
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5")
args.train_h5data = args.train_data
args.test_h5data = args.test_data
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
return args

View File

@ -4,6 +4,7 @@ import string
from multiprocessing import Pool
import h5py
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
@ -139,14 +140,18 @@ def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=
def store_h5dataset(path, data: dict):
f = h5py.File(path, "w")
f = h5py.File(path + ".h5", "w")
for key, val in data.items():
f.create_dataset(key, data=val)
f.close()
def check_h5dataset(path):
return open(path + ".h5", "r")
def load_h5dataset(path):
f = h5py.File(path, "r")
f = h5py.File(path + ".h5", "r")
data = {}
for k in f.keys():
data[k] = f[k]
@ -225,17 +230,17 @@ def get_flow_per_user(df):
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
char_dict = get_character_dict()
logger.info(f"check for h5data {h5data}")
try:
open(h5data, "r")
check_h5dataset(h5data)
except FileNotFoundError:
logger.info("h5 data not found - load csv file")
user_flow_df = get_user_flow_data(train_data)
logger.info("create training dataset")
domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
max_len=domain_length,
window_size=window_size)
logger.info("load raw training dataset")
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data,
domain_length, window_size)
logger.info("filter training dataset")
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
name.value, hits.value,
trusted_hits.value, server.value)
logger.info("store training dataset as h5 file")
data = {
"domain": domain.astype(np.int8),
@ -250,6 +255,32 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
char_dict = get_character_dict()
logger.info(f"check for h5data {h5data}")
try:
check_h5dataset(h5data)
except FileNotFoundError:
logger.info("h5 data not found - load csv file")
user_flow_df = get_user_flow_data(train_data)
logger.info("create raw training dataset")
domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
domain_length, window_size)
logger.info("store raw training dataset as h5 file")
data = {
"domain": domain.astype(np.int8),
"flow": flow,
"name": name,
"hits_vt": hits.astype(np.int8),
"hits_trusted": hits.astype(np.int8),
"server": server.astype(np.bool)
}
store_h5dataset(h5data, data)
logger.info("load h5 dataset")
data = load_h5dataset(h5data)
return data["domain"], data["flow"], data["name"], data["hits_vt"], data["hits_trusted"], data["server"]
def generate_names(train_data, window_size):
user_flow_df = get_user_flow_data(train_data)
with Pool() as pool:
@ -291,13 +322,9 @@ def load_or_generate_domains(train_data, domain_length):
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
def save_predictions(path, c_pred, s_pred):
f = h5py.File(path, "w")
f.create_dataset("client", data=c_pred)
f.create_dataset("server", data=s_pred)
f.close()
def save_predictions(path, results):
joblib.dump(results, path + "/results.joblib", compress=3)
def load_predictions(path):
f = h5py.File(path, "r")
return f["client"], f["server"]
return joblib.load(path + "/results.joblib")

67
main.py
View File

@ -1,13 +1,12 @@
import json
import logging
import os
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model, Model
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
from keras.models import Model, load_model
import arguments
import dataset
@ -15,9 +14,8 @@ import hyperband
import models
# create logger
import visualize
from dataset import load_or_generate_h5data
from utils import exists_or_make_path, get_custom_class_weights
from arguments import get_model_args
from utils import exists_or_make_path, get_custom_class_weights
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
@ -115,7 +113,8 @@ def main_hyperband():
}
logger.info("create training dataset")
domain_tr, flow_tr, name_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.load_or_generate_h5data(args.train_h5data,
args.train_data,
args.domain_length, args.window)
hp = hyperband.Hyperband(params,
[domain_tr, flow_tr],
@ -129,7 +128,7 @@ def main_train(param=None):
exists_or_make_path(args.model_path)
logger.info(f"Use command line arguments: {args}")
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data,
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
args.train_data,
args.domain_length,
args.window)
@ -245,11 +244,11 @@ def main_train(param=None):
def main_test():
logger.info("start test: load data")
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
args.test_data,
args.domain_length,
args.window)
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_encs, _ = dataset.load_or_generate_domains(args.test_data, args.domain_length)
for model_args in get_model_args(args):
results = {}
@ -268,55 +267,49 @@ def main_test():
results["client_pred"] = pred
else:
results["server_pred"] = pred
# dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
embd_model = load_model(model_args["embedding_model"])
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
# np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
results["domain_embds"] = domain_embeddings
joblib.dump(results, model_args["model_path"] + "/results.joblib", compress=3)
dataset.save_predictions(model_args["model_path"], results)
def main_visualization():
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
args.test_data,
args.domain_length,
args.window)
# client_val, server_val = client_val.value, server_val.value
client_val = client_val.value
logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
try:
logger.info("plot training curve")
logs = pd.read_csv(args.train_log)
if args.model_output == "client":
if "acc" in logs.keys():
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
else:
elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
except Exception as e:
logger.warning(f"could not generate training curves: {e}")
else:
logger.warning("Error while plotting training curves")
results = dataset.load_predictions(args.future_prediction)
client_pred = results["client_pred"].flatten()
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
logger.info("plot pr curve")
visualize.plot_clf()
visualize.plot_precision_recall(client_val, client_pred, args.model_path)
visualize.plot_legend()
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
logger.info("plot roc curve")
visualize.plot_clf()
visualize.plot_roc_curve(client_val, client_pred, args.model_path)
visualize.plot_legend()
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
@ -348,23 +341,25 @@ def main_visualization():
def main_visualize_all():
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
args.test_data,
args.domain_length,
args.window)
logger.info("plot pr curves")
visualize.plot_clf()
for model_args in get_model_args(args):
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
results = dataset.load_predictions(model_args["future_prediction"])
client_pred = results["client_pred"].flatten()
visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
logger.info("plot roc curves")
visualize.plot_clf()
for model_args in get_model_args(args):
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
results = dataset.load_predictions(model_args["future_prediction"])
client_pred = results["client_pred"].flatten()
visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
@ -374,8 +369,9 @@ def main_visualize_all():
logger.info("plot user pr curves")
visualize.plot_clf()
for model_args in get_model_args(args):
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
results = dataset.load_predictions(model_args["future_prediction"])
client_pred = results["client_pred"].flatten()
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend()
@ -384,8 +380,9 @@ def main_visualize_all():
logger.info("plot user roc curves")
visualize.plot_clf()
for model_args in get_model_args(args):
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
results = dataset.load_predictions(model_args["future_prediction"])
client_pred = results["client_pred"].flatten()
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend()