From 18b60e1754d34e45ee5a41333ab994c4fee75249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Mon, 17 Jul 2017 19:30:56 +0200 Subject: [PATCH] add extended test mode for embeddings --- dataset.py | 1 - main.py | 34 ++++++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/dataset.py b/dataset.py index 38d7cd2..3c66df3 100644 --- a/dataset.py +++ b/dataset.py @@ -216,7 +216,6 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size): logger.info(f"check for h5data {h5data}") try: open(h5data, "r") - raise FileNotFoundError except FileNotFoundError: logger.info("h5 data not found - load csv file") user_flow_df = get_user_flow_data(train_data) diff --git a/main.py b/main.py index eb938b4..7464819 100644 --- a/main.py +++ b/main.py @@ -198,6 +198,25 @@ def main_test(): verbose=1) np.save(args.future_prediction, y_pred) + char_dict = dataset.get_character_dict() + user_flow_df = dataset.get_user_flow_data(args.test_data) + domains = user_flow_df.domain.unique() + + def get_domain_features_reduced(d): + return dataset.get_domain_features(d[0], char_dict, args.domain_length) + + domain_features = [] + for ds in domains: + domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds))) + + model = load_model(args.embedding_model) + domain_features = np.stack(domain_features).reshape((-1, 40)) + pred = model.predict(domains, batch_size=args.batch_size, verbose=1) + + np.save("/tmp/rk/domains.npy", domains) + np.save("/tmp/rk/domain_features.npy", domain_features) + np.save("/tmp/rk/domain_embd.npy", pred) + def main_visualization(): domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, @@ -205,10 +224,13 @@ def main_visualization(): logger.info("plot model") model = load_model(args.clf_model, custom_objects=models.get_metrics()) visualize.plot_model(model, os.path.join(args.model_path, "model.png")) - logger.info("plot training curve") - logs = pd.read_csv(args.train_log) - visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path)) - visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path)) + try: + logger.info("plot training curve") + logs = pd.read_csv(args.train_log) + visualize.plot_training_curve(logs, "client", "{}/client_train.png".format(args.model_path)) + visualize.plot_training_curve(logs, "server", "{}/server_train.png".format(args.model_path)) + except Exception as e: + logger.warning(f"could not generate training curves: {e}") client_pred, server_pred = np.load(args.future_prediction) logger.info("plot pr curve") @@ -230,8 +252,8 @@ def main_visualization(): import matplotlib.pyplot as plt model = load_model(args.embedding_model) - domains = np.reshape(domain_val, (12800, 40)) - domain_embedding = model.predict(domains) + domains = np.reshape(domain_val, (domain_val.shape[0] * domain_val.shape[1], 40)) + domain_embedding = model.predict(domains, batch_size=args.batch_size, verbose=1) pca = PCA(n_components=2) domain_reduced = pca.fit_transform(domain_embedding)