refactor using joblib for test results, make h5py store/load more flexible
This commit is contained in:
parent
1ab0108c78
commit
70d00efb01
1
Makefile
1
Makefile
@ -34,3 +34,4 @@ hyper:
|
|||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -r results/test*
|
rm -r results/test*
|
||||||
|
rm data/rk_mini.csv.gz.h5
|
||||||
|
66
dataset.py
66
dataset.py
@ -126,22 +126,48 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
|||||||
return domain, flow, names, client_tr, server
|
return domain, flow, names, client_tr, server
|
||||||
|
|
||||||
|
|
||||||
def store_h5dataset(path, domain, flow, name, client, server):
|
def create_testset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
||||||
|
logger.info("get chunks from user data frames")
|
||||||
|
with Pool() as pool:
|
||||||
|
results = []
|
||||||
|
for user_flow in tqdm(get_flow_per_user(user_flow_df), total=len(user_flow_df['user_hash'].unique().tolist())):
|
||||||
|
results.append(pool.apply_async(get_user_chunks, (user_flow, window_size)))
|
||||||
|
windows = [window for res in results for window in res.get()]
|
||||||
|
logger.info("create training dataset")
|
||||||
|
domain, flow, hits, names, server, trusted_hits = create_dataset_from_lists(chunks=windows,
|
||||||
|
vocab=char_dict,
|
||||||
|
max_len=max_len)
|
||||||
|
# make client labels discrete with 4 different values
|
||||||
|
hits = np.apply_along_axis(lambda x: discretize_label(x, 3), 0, np.atleast_2d(hits))
|
||||||
|
# select only 1.0 and 0.0 from training data
|
||||||
|
pos_idx = np.where(np.logical_or(hits == 1.0, trusted_hits >= 1.0))[0]
|
||||||
|
neg_idx = np.where(hits == 0.0)[0]
|
||||||
|
idx = np.concatenate((pos_idx, neg_idx))
|
||||||
|
# choose selected sample to train on
|
||||||
|
domain = domain[idx]
|
||||||
|
flow = flow[idx]
|
||||||
|
client_tr = np.zeros_like(idx, float)
|
||||||
|
client_tr[:pos_idx.shape[-1]] = 1.0
|
||||||
|
server = server[idx]
|
||||||
|
names = names[idx]
|
||||||
|
|
||||||
|
return domain, flow, names, client_tr, server
|
||||||
|
|
||||||
|
|
||||||
|
def store_h5dataset(path, data: dict):
|
||||||
f = h5py.File(path, "w")
|
f = h5py.File(path, "w")
|
||||||
domain = domain.astype(np.int8)
|
for key, val in data.items():
|
||||||
f.create_dataset("domain", data=domain)
|
f.create_dataset(key, data=val)
|
||||||
f.create_dataset("flow", data=flow)
|
|
||||||
f.create_dataset("name", data=name)
|
|
||||||
server = server.astype(np.bool)
|
|
||||||
client = client.astype(np.bool)
|
|
||||||
f.create_dataset("client", data=client)
|
|
||||||
f.create_dataset("server", data=server)
|
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
def load_h5dataset(path):
|
def load_h5dataset(path):
|
||||||
data = h5py.File(path, "r")
|
f = h5py.File(path, "r")
|
||||||
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
keys = f.keys()
|
||||||
|
data = {}
|
||||||
|
for k in keys:
|
||||||
|
data[k] = f[k]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_lists(chunks, vocab, max_len):
|
def create_dataset_from_lists(chunks, vocab, max_len):
|
||||||
@ -224,13 +250,21 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
logger.info("h5 data not found - load csv file")
|
logger.info("h5 data not found - load csv file")
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain, flow, names, client, server = create_dataset_from_flows(user_flow_df, char_dict,
|
domain, flow, name, client, server = create_dataset_from_flows(user_flow_df, char_dict,
|
||||||
max_len=domain_length,
|
max_len=domain_length,
|
||||||
window_size=window_size)
|
window_size=window_size)
|
||||||
logger.info("store training dataset as h5 file")
|
logger.info("store training dataset as h5 file")
|
||||||
store_h5dataset(h5data, domain, flow, names, client, server)
|
data = {
|
||||||
|
"domain": domain.astype(np.int8),
|
||||||
|
"flow": flow,
|
||||||
|
"name": name,
|
||||||
|
"client": client.astype(np.bool),
|
||||||
|
"server": server.astype(np.bool)
|
||||||
|
}
|
||||||
|
store_h5dataset(h5data, data)
|
||||||
logger.info("load h5 dataset")
|
logger.info("load h5 dataset")
|
||||||
return load_h5dataset(h5data)
|
data = load_h5dataset(h5data)
|
||||||
|
return data["domain"], data["flow"], data["name"], data["client"], data["server"]
|
||||||
|
|
||||||
|
|
||||||
def generate_names(train_data, window_size):
|
def generate_names(train_data, window_size):
|
||||||
|
70
main.py
70
main.py
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import joblib
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -78,6 +79,15 @@ PARAMS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(model, output_type):
|
||||||
|
if output_type == "both":
|
||||||
|
return Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client, model.out_server))
|
||||||
|
elif output_type == "client":
|
||||||
|
return Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client,))
|
||||||
|
else:
|
||||||
|
raise Exception("unknown model output")
|
||||||
|
|
||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
pauls_best_params = models.pauls_networks.best_config
|
pauls_best_params = models.pauls_networks.best_config
|
||||||
main_train(pauls_best_params)
|
main_train(pauls_best_params)
|
||||||
@ -154,11 +164,7 @@ def main_train(param=None):
|
|||||||
logger.info(f"Generator model with params: {param}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
embedding, model, new_model = models.get_models_by_params(param)
|
embedding, model, new_model = models.get_models_by_params(param)
|
||||||
|
|
||||||
if args.model_output == "both":
|
model = create_model(new_model, args.model_output)
|
||||||
model = Model(inputs=[new_model.in_domains, new_model.in_flows],
|
|
||||||
outputs=(new_model.out_server, new_model.out_client))
|
|
||||||
else:
|
|
||||||
raise Exception("unknown model output")
|
|
||||||
|
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
logger.info("compile and train model")
|
logger.info("compile and train model")
|
||||||
@ -202,15 +208,8 @@ def main_train(param=None):
|
|||||||
logger.info(f"Generator model with params: {param}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
embedding, model, new_model = models.get_models_by_params(param)
|
embedding, model, new_model = models.get_models_by_params(param)
|
||||||
|
|
||||||
if args.model_output == "both":
|
model = create_model(model, args.model_output)
|
||||||
model = Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client, model.out_server))
|
new_model = create_model(new_model, args.model_output)
|
||||||
new_model = Model(inputs=[new_model.in_domains, new_model.in_flows],
|
|
||||||
outputs=(new_model.out_client, new_model.out_server))
|
|
||||||
elif args.model_output == "client":
|
|
||||||
model = Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client,))
|
|
||||||
new_model = Model(inputs=[new_model.in_domains, new_model.in_flows], outputs=(new_model.out_client,))
|
|
||||||
else:
|
|
||||||
raise Exception("unknown model output")
|
|
||||||
|
|
||||||
if args.model_type == "inter":
|
if args.model_type == "inter":
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
@ -253,6 +252,7 @@ def main_test():
|
|||||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||||
|
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
|
results = {}
|
||||||
logger.info(f"process model {model_args['model_path']}")
|
logger.info(f"process model {model_args['model_path']}")
|
||||||
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
||||||
|
|
||||||
@ -262,17 +262,20 @@ def main_test():
|
|||||||
|
|
||||||
if args.model_output == "both":
|
if args.model_output == "both":
|
||||||
c_pred, s_pred = pred
|
c_pred, s_pred = pred
|
||||||
|
results["client_pred"] = c_pred
|
||||||
|
results["server_pred"] = s_pred
|
||||||
elif args.model_output == "client":
|
elif args.model_output == "client":
|
||||||
c_pred = pred
|
results["client_pred"] = pred
|
||||||
s_pred = np.zeros(0)
|
|
||||||
else:
|
else:
|
||||||
c_pred = np.zeros(0)
|
results["server_pred"] = pred
|
||||||
s_pred = pred
|
# dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
|
||||||
dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
|
|
||||||
|
|
||||||
embd_model = load_model(model_args["embedding_model"])
|
embd_model = load_model(model_args["embedding_model"])
|
||||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
|
# np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
|
||||||
|
|
||||||
|
results["domain_embds"] = domain_embeddings
|
||||||
|
joblib.dump(results, model_args["model_path"] + "results.joblib", compress=3)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
@ -302,14 +305,16 @@ def main_visualization():
|
|||||||
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
|
client_pred, server_pred = client_pred.value.flatten(), server_pred.value.flatten()
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(client_val, client_pred)
|
visualize.plot_precision_recall(client_val, client_pred, args.model_path)
|
||||||
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(client_val, client_pred)
|
visualize.plot_roc_curve(client_val, client_pred, args.model_path)
|
||||||
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||||
|
|
||||||
@ -321,11 +326,13 @@ def main_visualization():
|
|||||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||||
|
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(user_vals, user_preds)
|
visualize.plot_precision_recall(user_vals, user_preds, args.model_path)
|
||||||
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
||||||
|
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(user_vals, user_preds)
|
visualize.plot_roc_curve(user_vals, user_preds, args.model_path)
|
||||||
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
||||||
|
|
||||||
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
||||||
@ -385,19 +392,6 @@ def main_visualize_all():
|
|||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
||||||
|
|
||||||
|
|
||||||
def main_data():
|
|
||||||
char_dict = dataset.get_character_dict()
|
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
|
||||||
logger.info("create training dataset")
|
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
|
||||||
max_len=args.domain_length,
|
|
||||||
window_size=args.window)
|
|
||||||
print(f"domain shape {domain_tr.shape}")
|
|
||||||
print(f"flow shape {flow_tr.shape}")
|
|
||||||
print(f"client shape {client_tr.shape}")
|
|
||||||
print(f"server shape {server_tr.shape}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
main_train()
|
main_train()
|
||||||
@ -411,8 +405,6 @@ def main():
|
|||||||
main_visualize_all()
|
main_visualize_all()
|
||||||
if "paul" == args.mode:
|
if "paul" == args.mode:
|
||||||
main_paul_best()
|
main_paul_best()
|
||||||
if "data" == args.mode:
|
|
||||||
main_data()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user