refactor test function working on full unfiltered data
This commit is contained in:
parent
edc75f4f44
commit
9a51b6ea34
1
Makefile
1
Makefile
@ -66,4 +66,5 @@ hyper:
|
||||
|
||||
clean:
|
||||
rm -r results/test/test*
|
||||
rm data/rk_mini.csv.gz_raw.h5
|
||||
rm data/rk_mini.csv.gz.h5
|
||||
|
12
arguments.py
12
arguments.py
@ -105,9 +105,9 @@ def get_model_args(args):
|
||||
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||
"train_h5data": args.train_data + ".h5",
|
||||
"test_h5data": args.test_data + ".h5",
|
||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
||||
"train_h5data": args.train_data,
|
||||
"test_h5data": args.test_data,
|
||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||
} for model_path in args.model_paths]
|
||||
|
||||
def parse():
|
||||
@ -115,7 +115,7 @@ def parse():
|
||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
||||
args.train_h5data = args.train_data + ".h5"
|
||||
args.test_h5data = args.test_data + ".h5"
|
||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
||||
args.train_h5data = args.train_data
|
||||
args.test_h5data = args.test_data
|
||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||
return args
|
||||
|
61
dataset.py
61
dataset.py
@ -4,6 +4,7 @@ import string
|
||||
from multiprocessing import Pool
|
||||
|
||||
import h5py
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
@ -139,14 +140,18 @@ def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=
|
||||
|
||||
|
||||
def store_h5dataset(path, data: dict):
|
||||
f = h5py.File(path, "w")
|
||||
f = h5py.File(path + ".h5", "w")
|
||||
for key, val in data.items():
|
||||
f.create_dataset(key, data=val)
|
||||
f.close()
|
||||
|
||||
|
||||
def check_h5dataset(path):
|
||||
return open(path + ".h5", "r")
|
||||
|
||||
|
||||
def load_h5dataset(path):
|
||||
f = h5py.File(path, "r")
|
||||
f = h5py.File(path + ".h5", "r")
|
||||
data = {}
|
||||
for k in f.keys():
|
||||
data[k] = f[k]
|
||||
@ -225,17 +230,17 @@ def get_flow_per_user(df):
|
||||
|
||||
|
||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
char_dict = get_character_dict()
|
||||
logger.info(f"check for h5data {h5data}")
|
||||
try:
|
||||
open(h5data, "r")
|
||||
check_h5dataset(h5data)
|
||||
except FileNotFoundError:
|
||||
logger.info("h5 data not found - load csv file")
|
||||
user_flow_df = get_user_flow_data(train_data)
|
||||
logger.info("create training dataset")
|
||||
domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
|
||||
max_len=domain_length,
|
||||
window_size=window_size)
|
||||
logger.info("load raw training dataset")
|
||||
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data,
|
||||
domain_length, window_size)
|
||||
logger.info("filter training dataset")
|
||||
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
||||
name.value, hits.value,
|
||||
trusted_hits.value, server.value)
|
||||
logger.info("store training dataset as h5 file")
|
||||
data = {
|
||||
"domain": domain.astype(np.int8),
|
||||
@ -250,6 +255,32 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
||||
|
||||
|
||||
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
||||
char_dict = get_character_dict()
|
||||
logger.info(f"check for h5data {h5data}")
|
||||
try:
|
||||
check_h5dataset(h5data)
|
||||
except FileNotFoundError:
|
||||
logger.info("h5 data not found - load csv file")
|
||||
user_flow_df = get_user_flow_data(train_data)
|
||||
logger.info("create raw training dataset")
|
||||
domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
|
||||
domain_length, window_size)
|
||||
logger.info("store raw training dataset as h5 file")
|
||||
data = {
|
||||
"domain": domain.astype(np.int8),
|
||||
"flow": flow,
|
||||
"name": name,
|
||||
"hits_vt": hits.astype(np.int8),
|
||||
"hits_trusted": hits.astype(np.int8),
|
||||
"server": server.astype(np.bool)
|
||||
}
|
||||
store_h5dataset(h5data, data)
|
||||
logger.info("load h5 dataset")
|
||||
data = load_h5dataset(h5data)
|
||||
return data["domain"], data["flow"], data["name"], data["hits_vt"], data["hits_trusted"], data["server"]
|
||||
|
||||
|
||||
def generate_names(train_data, window_size):
|
||||
user_flow_df = get_user_flow_data(train_data)
|
||||
with Pool() as pool:
|
||||
@ -291,13 +322,9 @@ def load_or_generate_domains(train_data, domain_length):
|
||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
|
||||
|
||||
|
||||
def save_predictions(path, c_pred, s_pred):
|
||||
f = h5py.File(path, "w")
|
||||
f.create_dataset("client", data=c_pred)
|
||||
f.create_dataset("server", data=s_pred)
|
||||
f.close()
|
||||
def save_predictions(path, results):
|
||||
joblib.dump(results, path + "/results.joblib", compress=3)
|
||||
|
||||
|
||||
def load_predictions(path):
|
||||
f = h5py.File(path, "r")
|
||||
return f["client"], f["server"]
|
||||
return joblib.load(path + "/results.joblib")
|
||||
|
129
main.py
129
main.py
@ -1,13 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import joblib
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||
from keras.models import load_model, Model
|
||||
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
|
||||
from keras.models import Model, load_model
|
||||
|
||||
import arguments
|
||||
import dataset
|
||||
@ -15,9 +14,8 @@ import hyperband
|
||||
import models
|
||||
# create logger
|
||||
import visualize
|
||||
from dataset import load_or_generate_h5data
|
||||
from utils import exists_or_make_path, get_custom_class_weights
|
||||
from arguments import get_model_args
|
||||
from utils import exists_or_make_path, get_custom_class_weights
|
||||
|
||||
logger = logging.getLogger('logger')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -115,8 +113,9 @@ def main_hyperband():
|
||||
}
|
||||
|
||||
logger.info("create training dataset")
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
|
||||
args.domain_length, args.window)
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.load_or_generate_h5data(args.train_h5data,
|
||||
args.train_data,
|
||||
args.domain_length, args.window)
|
||||
hp = hyperband.Hyperband(params,
|
||||
[domain_tr, flow_tr],
|
||||
[client_tr, server_tr])
|
||||
@ -129,10 +128,10 @@ def main_train(param=None):
|
||||
exists_or_make_path(args.model_path)
|
||||
logger.info(f"Use command line arguments: {args}")
|
||||
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data,
|
||||
args.train_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
|
||||
args.train_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
logger.info("define callbacks")
|
||||
callbacks = []
|
||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||
@ -245,12 +244,12 @@ def main_train(param=None):
|
||||
|
||||
def main_test():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
|
||||
args.test_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||
args.test_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_encs, _ = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
|
||||
for model_args in get_model_args(args):
|
||||
results = {}
|
||||
logger.info(f"process model {model_args['model_path']}")
|
||||
@ -268,73 +267,67 @@ def main_test():
|
||||
results["client_pred"] = pred
|
||||
else:
|
||||
results["server_pred"] = pred
|
||||
# dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
|
||||
|
||||
embd_model = load_model(model_args["embedding_model"])
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
# np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
|
||||
|
||||
results["domain_embds"] = domain_embeddings
|
||||
joblib.dump(results, model_args["model_path"] + "/results.joblib", compress=3)
|
||||
|
||||
dataset.save_predictions(model_args["model_path"], results)
|
||||
|
||||
|
||||
def main_visualization():
|
||||
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
|
||||
args.test_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
# client_val, server_val = client_val.value, server_val.value
|
||||
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||
args.test_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
client_val = client_val.value
|
||||
|
||||
|
||||
logger.info("plot model")
|
||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||
|
||||
try:
|
||||
logger.info("plot training curve")
|
||||
logs = pd.read_csv(args.train_log)
|
||||
if args.model_output == "client":
|
||||
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||
else:
|
||||
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||
except Exception as e:
|
||||
logger.warning(f"could not generate training curves: {e}")
|
||||
|
||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
|
||||
|
||||
logger.info("plot training curve")
|
||||
logs = pd.read_csv(args.train_log)
|
||||
if "acc" in logs.keys():
|
||||
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||
elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
|
||||
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||
else:
|
||||
logger.warning("Error while plotting training curves")
|
||||
|
||||
results = dataset.load_predictions(args.future_prediction)
|
||||
client_pred = results["client_pred"].flatten()
|
||||
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(client_val, client_pred, args.model_path)
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||
|
||||
logger.info("plot roc curve")
|
||||
visualize.plot_clf()
|
||||
visualize.plot_roc_curve(client_val, client_pred, args.model_path)
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||
|
||||
|
||||
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
||||
|
||||
|
||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
|
||||
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(user_vals, user_preds, args.model_path)
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
||||
|
||||
|
||||
visualize.plot_clf()
|
||||
visualize.plot_roc_curve(user_vals, user_preds, args.model_path)
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
||||
|
||||
|
||||
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
||||
"{}/client_cov.png".format(args.model_path),
|
||||
normalize=False, title="Client Confusion Matrix")
|
||||
@ -348,44 +341,48 @@ def main_visualization():
|
||||
|
||||
|
||||
def main_visualize_all():
|
||||
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
|
||||
args.test_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||
args.test_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
logger.info("plot pr curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||
results = dataset.load_predictions(model_args["future_prediction"])
|
||||
client_pred = results["client_pred"].flatten()
|
||||
visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||
|
||||
|
||||
logger.info("plot roc curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||
results = dataset.load_predictions(model_args["future_prediction"])
|
||||
client_pred = results["client_pred"].flatten()
|
||||
visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||
|
||||
|
||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||
|
||||
|
||||
logger.info("plot user pr curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
|
||||
results = dataset.load_predictions(model_args["future_prediction"])
|
||||
client_pred = results["client_pred"].flatten()
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
||||
|
||||
|
||||
logger.info("plot user roc curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
|
||||
results = dataset.load_predictions(model_args["future_prediction"])
|
||||
client_pred = results["client_pred"].flatten()
|
||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
|
Loading…
x
Reference in New Issue
Block a user