fix lazy domain loading and generation process
This commit is contained in:
parent
7f1d13658f
commit
6e7dc1297c
38
dataset.py
38
dataset.py
@ -152,6 +152,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
|
|||||||
:param max_len:
|
:param max_len:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_domain_features_reduced(d):
|
def get_domain_features_reduced(d):
|
||||||
return get_domain_features(d[0], vocab, max_len)
|
return get_domain_features(d[0], vocab, max_len)
|
||||||
|
|
||||||
@ -230,33 +231,40 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
return load_h5dataset(h5data)
|
return load_h5dataset(h5data)
|
||||||
|
|
||||||
|
|
||||||
# TODO: implement csv loading if already generated
|
|
||||||
def load_or_generate_domains(train_data, domain_length):
|
def load_or_generate_domains(train_data, domain_length):
|
||||||
char_dict = get_character_dict()
|
fn = f"{train_data}_domains.gz"
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
|
||||||
|
try:
|
||||||
|
user_flow_df = pd.read_csv(fn)
|
||||||
|
except Exception:
|
||||||
|
char_dict = get_character_dict()
|
||||||
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
|
user_flow_df.reset_index(inplace=True)
|
||||||
|
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
||||||
|
how="any")
|
||||||
|
user_flow_df = user_flow_df.groupby(user_flow_df.domain).mean()
|
||||||
|
user_flow_df.reset_index(inplace=True)
|
||||||
|
|
||||||
|
user_flow_df["clientLabel"] = np.where(
|
||||||
|
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), True, False)
|
||||||
|
user_flow_df[["serverLabel", "clientLabel"]] = user_flow_df[["serverLabel", "clientLabel"]].astype(bool)
|
||||||
|
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
|
||||||
|
|
||||||
|
user_flow_df.to_csv(fn, compression="gzip")
|
||||||
|
|
||||||
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length))
|
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length))
|
||||||
domain_encs = np.stack(domain_encs)
|
domain_encs = np.stack(domain_encs)
|
||||||
|
|
||||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, how="any")
|
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
|
||||||
user_flow_df.reset_index(inplace=True)
|
|
||||||
user_flow_df["clientLabel"] = np.where(
|
|
||||||
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), 1.0, 0.0)
|
|
||||||
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
|
|
||||||
user_flow_df.groupby(user_flow_df.domain).mean()
|
|
||||||
|
|
||||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
|
|
||||||
|
|
||||||
|
|
||||||
def save_predictions(path, c_pred, s_pred, embd, labels):
|
def save_predictions(path, c_pred, s_pred):
|
||||||
f = h5py.File(path, "w")
|
f = h5py.File(path, "w")
|
||||||
f.create_dataset("client", data=c_pred)
|
f.create_dataset("client", data=c_pred)
|
||||||
f.create_dataset("server", data=s_pred)
|
f.create_dataset("server", data=s_pred)
|
||||||
f.create_dataset("embedding", data=embd)
|
|
||||||
f.create_dataset("labels", data=labels)
|
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
def load_predictions(path):
|
def load_predictions(path):
|
||||||
f = h5py.File(path, "r")
|
f = h5py.File(path, "r")
|
||||||
return f["client"], f["server"], f["embedding"], f["labels"]
|
return f["client"], f["server"]
|
||||||
|
14
main.py
14
main.py
@ -194,13 +194,12 @@ def main_test():
|
|||||||
else:
|
else:
|
||||||
c_pred = np.zeros(0)
|
c_pred = np.zeros(0)
|
||||||
s_pred = pred
|
s_pred = pred
|
||||||
|
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
|
||||||
|
|
||||||
model = load_model(args.embedding_model)
|
model = load_model(args.embedding_model)
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
|
np.save(args.model_path + "/domain_embds.npy", domain_embedding)
|
||||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
@ -213,12 +212,15 @@ def main_visualization():
|
|||||||
try:
|
try:
|
||||||
logger.info("plot training curve")
|
logger.info("plot training curve")
|
||||||
logs = pd.read_csv(args.train_log)
|
logs = pd.read_csv(args.train_log)
|
||||||
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path))
|
if args.model_output == "client":
|
||||||
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path))
|
visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
|
||||||
|
else:
|
||||||
|
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
|
||||||
|
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"could not generate training curves: {e}")
|
logger.warning(f"could not generate training curves: {e}")
|
||||||
|
|
||||||
client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction)
|
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||||
client_pred, server_pred = client_pred.value, server_pred.value
|
client_pred, server_pred = client_pred.value, server_pred.value
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
||||||
|
@ -132,11 +132,11 @@ def plot_confusion_matrix(y_true, y_pred, path,
|
|||||||
|
|
||||||
def plot_training_curve(logs, key, path, dpi=600):
|
def plot_training_curve(logs, key, path, dpi=600):
|
||||||
plt.clf()
|
plt.clf()
|
||||||
plt.plot(logs[f"{key}_acc"], label="accuracy")
|
plt.plot(logs[f"{key}acc"], label="accuracy")
|
||||||
plt.plot(logs[f"{key}_f1_score"], label="f1_score")
|
plt.plot(logs[f"{key}f1_score"], label="f1_score")
|
||||||
|
|
||||||
plt.plot(logs[f"val_{key}_acc"], label="accuracy")
|
plt.plot(logs[f"val_{key}acc"], label="accuracy")
|
||||||
plt.plot(logs[f"val_{key}_f1_score"], label="val_f1_score")
|
plt.plot(logs[f"val_{key}f1_score"], label="val_f1_score")
|
||||||
|
|
||||||
plt.xlabel('epoch')
|
plt.xlabel('epoch')
|
||||||
plt.ylabel('percentage')
|
plt.ylabel('percentage')
|
||||||
|
Loading…
Reference in New Issue
Block a user