fix lazy domain loading and generation process

This commit is contained in:
René Knaebel 2017-08-03 12:27:17 +02:00
parent 7f1d13658f
commit 6e7dc1297c
3 changed files with 35 additions and 25 deletions

View File

@ -152,6 +152,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
:param max_len: :param max_len:
:return: :return:
""" """
def get_domain_features_reduced(d): def get_domain_features_reduced(d):
return get_domain_features(d[0], vocab, max_len) return get_domain_features(d[0], vocab, max_len)
@ -230,33 +231,40 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
return load_h5dataset(h5data) return load_h5dataset(h5data)
# TODO: implement csv loading if already generated
def load_or_generate_domains(train_data, domain_length): def load_or_generate_domains(train_data, domain_length):
char_dict = get_character_dict() fn = f"{train_data}_domains.gz"
user_flow_df = get_user_flow_data(train_data)
try:
user_flow_df = pd.read_csv(fn)
except Exception:
char_dict = get_character_dict()
user_flow_df = get_user_flow_data(train_data)
user_flow_df.reset_index(inplace=True)
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
how="any")
user_flow_df = user_flow_df.groupby(user_flow_df.domain).mean()
user_flow_df.reset_index(inplace=True)
user_flow_df["clientLabel"] = np.where(
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), True, False)
user_flow_df[["serverLabel", "clientLabel"]] = user_flow_df[["serverLabel", "clientLabel"]].astype(bool)
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
user_flow_df.to_csv(fn, compression="gzip")
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length)) domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, char_dict, domain_length))
domain_encs = np.stack(domain_encs) domain_encs = np.stack(domain_encs)
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, how="any") return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix().astype(bool)
user_flow_df.reset_index(inplace=True)
user_flow_df["clientLabel"] = np.where(
np.logical_or(user_flow_df.trustedHits > 0, user_flow_df.virusTotalHits >= 3), 1.0, 0.0)
user_flow_df = user_flow_df[["domain", "serverLabel", "clientLabel"]]
user_flow_df.groupby(user_flow_df.domain).mean()
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
def save_predictions(path, c_pred, s_pred, embd, labels): def save_predictions(path, c_pred, s_pred):
f = h5py.File(path, "w") f = h5py.File(path, "w")
f.create_dataset("client", data=c_pred) f.create_dataset("client", data=c_pred)
f.create_dataset("server", data=s_pred) f.create_dataset("server", data=s_pred)
f.create_dataset("embedding", data=embd)
f.create_dataset("labels", data=labels)
f.close() f.close()
def load_predictions(path): def load_predictions(path):
f = h5py.File(path, "r") f = h5py.File(path, "r")
return f["client"], f["server"], f["embedding"], f["labels"] return f["client"], f["server"]

14
main.py
View File

@ -194,13 +194,12 @@ def main_test():
else: else:
c_pred = np.zeros(0) c_pred = np.zeros(0)
s_pred = pred s_pred = pred
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
model = load_model(args.embedding_model) model = load_model(args.embedding_model)
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length) domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1) domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
np.save(args.model_path + "/domain_embds.npy", domain_embedding)
dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels)
def main_visualization(): def main_visualization():
@ -213,12 +212,15 @@ def main_visualization():
try: try:
logger.info("plot training curve") logger.info("plot training curve")
logs = pd.read_csv(args.train_log) logs = pd.read_csv(args.train_log)
visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path)) if args.model_output == "client":
visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path)) visualize.plot_training_curve(logs, "", "{}/client_train.png".format(args.model_path))
else:
visualize.plot_training_curve(logs, "client_", "{}/client_train.png".format(args.model_path))
visualize.plot_training_curve(logs, "server_", "{}/server_train.png".format(args.model_path))
except Exception as e: except Exception as e:
logger.warning(f"could not generate training curves: {e}") logger.warning(f"could not generate training curves: {e}")
client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction) client_pred, server_pred = dataset.load_predictions(args.future_prediction)
client_pred, server_pred = client_pred.value, server_pred.value client_pred, server_pred = client_pred.value, server_pred.value
logger.info("plot pr curve") logger.info("plot pr curve")
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path)) visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))

View File

@ -132,11 +132,11 @@ def plot_confusion_matrix(y_true, y_pred, path,
def plot_training_curve(logs, key, path, dpi=600): def plot_training_curve(logs, key, path, dpi=600):
plt.clf() plt.clf()
plt.plot(logs[f"{key}_acc"], label="accuracy") plt.plot(logs[f"{key}acc"], label="accuracy")
plt.plot(logs[f"{key}_f1_score"], label="f1_score") plt.plot(logs[f"{key}f1_score"], label="f1_score")
plt.plot(logs[f"val_{key}_acc"], label="accuracy") plt.plot(logs[f"val_{key}acc"], label="accuracy")
plt.plot(logs[f"val_{key}_f1_score"], label="val_f1_score") plt.plot(logs[f"val_{key}f1_score"], label="val_f1_score")
plt.xlabel('epoch') plt.xlabel('epoch')
plt.ylabel('percentage') plt.ylabel('percentage')