refactor test function working on full unfiltered data
This commit is contained in:
parent
edc75f4f44
commit
9a51b6ea34
1
Makefile
1
Makefile
@ -66,4 +66,5 @@ hyper:
|
|||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -r results/test/test*
|
rm -r results/test/test*
|
||||||
|
rm data/rk_mini.csv.gz_raw.h5
|
||||||
rm data/rk_mini.csv.gz.h5
|
rm data/rk_mini.csv.gz.h5
|
||||||
|
12
arguments.py
12
arguments.py
@ -105,9 +105,9 @@ def get_model_args(args):
|
|||||||
"embedding_model": os.path.join(model_path, "embd.h5"),
|
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||||
"clf_model": os.path.join(model_path, "clf.h5"),
|
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||||
"train_log": os.path.join(model_path, "train.log.csv"),
|
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||||
"train_h5data": args.train_data + ".h5",
|
"train_h5data": args.train_data,
|
||||||
"test_h5data": args.test_data + ".h5",
|
"test_h5data": args.test_data,
|
||||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||||
} for model_path in args.model_paths]
|
} for model_path in args.model_paths]
|
||||||
|
|
||||||
def parse():
|
def parse():
|
||||||
@ -115,7 +115,7 @@ def parse():
|
|||||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||||
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
||||||
args.train_h5data = args.train_data + ".h5"
|
args.train_h5data = args.train_data
|
||||||
args.test_h5data = args.test_data + ".h5"
|
args.test_h5data = args.test_data
|
||||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||||
return args
|
return args
|
||||||
|
61
dataset.py
61
dataset.py
@ -4,6 +4,7 @@ import string
|
|||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -139,14 +140,18 @@ def create_raw_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=
|
|||||||
|
|
||||||
|
|
||||||
def store_h5dataset(path, data: dict):
|
def store_h5dataset(path, data: dict):
|
||||||
f = h5py.File(path, "w")
|
f = h5py.File(path + ".h5", "w")
|
||||||
for key, val in data.items():
|
for key, val in data.items():
|
||||||
f.create_dataset(key, data=val)
|
f.create_dataset(key, data=val)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
|
def check_h5dataset(path):
|
||||||
|
return open(path + ".h5", "r")
|
||||||
|
|
||||||
|
|
||||||
def load_h5dataset(path):
|
def load_h5dataset(path):
|
||||||
f = h5py.File(path, "r")
|
f = h5py.File(path + ".h5", "r")
|
||||||
data = {}
|
data = {}
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
data[k] = f[k]
|
data[k] = f[k]
|
||||||
@ -225,17 +230,17 @@ def get_flow_per_user(df):
|
|||||||
|
|
||||||
|
|
||||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||||
char_dict = get_character_dict()
|
|
||||||
logger.info(f"check for h5data {h5data}")
|
logger.info(f"check for h5data {h5data}")
|
||||||
try:
|
try:
|
||||||
open(h5data, "r")
|
check_h5dataset(h5data)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.info("h5 data not found - load csv file")
|
logger.info("load raw training dataset")
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
domain, flow, name, hits, trusted_hits, server = load_or_generate_raw_h5data(h5data + "_raw", train_data,
|
||||||
logger.info("create training dataset")
|
domain_length, window_size)
|
||||||
domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
|
logger.info("filter training dataset")
|
||||||
max_len=domain_length,
|
domain, flow, name, client, server = filter_window_dataset_by_hits(domain.value, flow.value,
|
||||||
window_size=window_size)
|
name.value, hits.value,
|
||||||
|
trusted_hits.value, server.value)
|
||||||
logger.info("store training dataset as h5 file")
|
logger.info("store training dataset as h5 file")
|
||||||
data = {
|
data = {
|
||||||
"domain": domain.astype(np.int8),
|
"domain": domain.astype(np.int8),
|
||||||
@ -250,6 +255,32 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_or_generate_raw_h5data(h5data, train_data, domain_length, window_size):
|
||||||
|
char_dict = get_character_dict()
|
||||||
|
logger.info(f"check for h5data {h5data}")
|
||||||
|
try:
|
||||||
|
check_h5dataset(h5data)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.info("h5 data not found - load csv file")
|
||||||
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
|
logger.info("create raw training dataset")
|
||||||
|
domain, flow, name, hits, trusted_hits, server = create_raw_dataset_from_flows(user_flow_df, char_dict,
|
||||||
|
domain_length, window_size)
|
||||||
|
logger.info("store raw training dataset as h5 file")
|
||||||
|
data = {
|
||||||
|
"domain": domain.astype(np.int8),
|
||||||
|
"flow": flow,
|
||||||
|
"name": name,
|
||||||
|
"hits_vt": hits.astype(np.int8),
|
||||||
|
"hits_trusted": hits.astype(np.int8),
|
||||||
|
"server": server.astype(np.bool)
|
||||||
|
}
|
||||||
|
store_h5dataset(h5data, data)
|
||||||
|
logger.info("load h5 dataset")
|
||||||
|
data = load_h5dataset(h5data)
|
||||||
|
return data["domain"], data["flow"], data["name"], data["hits_vt"], data["hits_trusted"], data["server"]
|
||||||
|
|
||||||
|
|
||||||
def generate_names(train_data, window_size):
|
def generate_names(train_data, window_size):
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
with Pool() as pool:
|
with Pool() as pool:
|
||||||
@ -291,13 +322,9 @@ def load_or_generate_domains(train_data, domain_length):
|
|||||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
|
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
|
||||||
|
|
||||||
|
|
||||||
def save_predictions(path, c_pred, s_pred):
|
def save_predictions(path, results):
|
||||||
f = h5py.File(path, "w")
|
joblib.dump(results, path + "/results.joblib", compress=3)
|
||||||
f.create_dataset("client", data=c_pred)
|
|
||||||
f.create_dataset("server", data=s_pred)
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
|
|
||||||
def load_predictions(path):
|
def load_predictions(path):
|
||||||
f = h5py.File(path, "r")
|
return joblib.load(path + "/results.joblib")
|
||||||
return f["client"], f["server"]
|
|
||||||
|
103
main.py
103
main.py
@ -1,13 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import joblib
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
|
||||||
from keras.models import load_model, Model
|
from keras.models import Model, load_model
|
||||||
|
|
||||||
import arguments
|
import arguments
|
||||||
import dataset
|
import dataset
|
||||||
@ -15,9 +14,8 @@ import hyperband
|
|||||||
import models
|
import models
|
||||||
# create logger
|
# create logger
|
||||||
import visualize
|
import visualize
|
||||||
from dataset import load_or_generate_h5data
|
|
||||||
from utils import exists_or_make_path, get_custom_class_weights
|
|
||||||
from arguments import get_model_args
|
from arguments import get_model_args
|
||||||
|
from utils import exists_or_make_path, get_custom_class_weights
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('logger')
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@ -115,8 +113,9 @@ def main_hyperband():
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
|
domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.load_or_generate_h5data(args.train_h5data,
|
||||||
args.domain_length, args.window)
|
args.train_data,
|
||||||
|
args.domain_length, args.window)
|
||||||
hp = hyperband.Hyperband(params,
|
hp = hyperband.Hyperband(params,
|
||||||
[domain_tr, flow_tr],
|
[domain_tr, flow_tr],
|
||||||
[client_tr, server_tr])
|
[client_tr, server_tr])
|
||||||
@ -129,10 +128,10 @@ def main_train(param=None):
|
|||||||
exists_or_make_path(args.model_path)
|
exists_or_make_path(args.model_path)
|
||||||
logger.info(f"Use command line arguments: {args}")
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
|
||||||
args.train_data,
|
args.train_data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
logger.info("define callbacks")
|
logger.info("define callbacks")
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||||
@ -245,11 +244,11 @@ def main_train(param=None):
|
|||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
logger.info("start test: load data")
|
logger.info("start test: load data")
|
||||||
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||||
args.test_data,
|
args.test_data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
domain_encs, _ = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||||
|
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = {}
|
results = {}
|
||||||
@ -268,55 +267,49 @@ def main_test():
|
|||||||
results["client_pred"] = pred
|
results["client_pred"] = pred
|
||||||
else:
|
else:
|
||||||
results["server_pred"] = pred
|
results["server_pred"] = pred
|
||||||
# dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
|
|
||||||
|
|
||||||
embd_model = load_model(model_args["embedding_model"])
|
embd_model = load_model(model_args["embedding_model"])
|
||||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
# np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
|
|
||||||
|
|
||||||
results["domain_embds"] = domain_embeddings
|
results["domain_embds"] = domain_embeddings
|
||||||
joblib.dump(results, model_args["model_path"] + "/results.joblib", compress=3)
|
|
||||||
|
dataset.save_predictions(model_args["model_path"], results)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
|
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||||
args.test_data,
|
args.test_data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
# client_val, server_val = client_val.value, server_val.value
|
|
||||||
client_val = client_val.value
|
client_val = client_val.value
|
||||||
|
|
||||||
logger.info("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||||
|
|
||||||
try:
|
logger.info("plot training curve")
|
||||||
logger.info("plot training curve")
|
logs = pd.read_csv(args.train_log)
|
||||||
logs = pd.read_csv(args.train_log)
|
if "acc" in logs.keys():
|
||||||
if args.model_output == "client":
|
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||||
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
elif "client_acc" in logs.keys() and "server_acc" in logs.keys():
|
||||||
else:
|
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||||
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||||
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
else:
|
||||||
except Exception as e:
|
logger.warning("Error while plotting training curves")
|
||||||
logger.warning(f"could not generate training curves: {e}")
|
|
||||||
|
results = dataset.load_predictions(args.future_prediction)
|
||||||
|
client_pred = results["client_pred"].flatten()
|
||||||
|
|
||||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
|
||||||
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
|
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(client_val, client_pred, args.model_path)
|
visualize.plot_precision_recall(client_val, client_pred, args.model_path)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
|
||||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
|
||||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(client_val, client_pred, args.model_path)
|
visualize.plot_roc_curve(client_val, client_pred, args.model_path)
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
|
||||||
|
|
||||||
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
print(f"names {name_val.shape} vals {client_val.shape} preds {client_pred.shape}")
|
||||||
|
|
||||||
@ -348,23 +341,25 @@ def main_visualization():
|
|||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
domain_val, flow_val, name_val, client_val, server_val = load_or_generate_h5data(args.test_h5data,
|
domain_val, flow_val, name_val, client_val, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||||
args.test_data,
|
args.test_data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
logger.info("plot pr curves")
|
logger.info("plot pr curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
results = dataset.load_predictions(model_args["future_prediction"])
|
||||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
client_pred = results["client_pred"].flatten()
|
||||||
|
visualize.plot_precision_recall(client_val, client_pred, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||||
|
|
||||||
logger.info("plot roc curves")
|
logger.info("plot roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
results = dataset.load_predictions(model_args["future_prediction"])
|
||||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
client_pred = results["client_pred"].flatten()
|
||||||
|
visualize.plot_roc_curve(client_val, client_pred, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||||
|
|
||||||
@ -374,8 +369,9 @@ def main_visualize_all():
|
|||||||
logger.info("plot user pr curves")
|
logger.info("plot user pr curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
results = dataset.load_predictions(model_args["future_prediction"])
|
||||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
|
client_pred = results["client_pred"].flatten()
|
||||||
|
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||||
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
@ -384,8 +380,9 @@ def main_visualize_all():
|
|||||||
logger.info("plot user roc curves")
|
logger.info("plot user roc curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
results = dataset.load_predictions(model_args["future_prediction"])
|
||||||
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred.value.flatten()})
|
client_pred = results["client_pred"].flatten()
|
||||||
|
df_pred = pd.DataFrame(data={"names": name_val, "client_val": client_pred})
|
||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||||
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
|
Loading…
Reference in New Issue
Block a user