From 5741f8ee0e24182dc805365b8eaf138d71f92fd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Fri, 6 Oct 2017 10:38:00 +0200 Subject: [PATCH] fix staggered training --- main.py | 31 +++++++++++++++++++++++++++++++ utils.py | 6 +++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index b904961..a96eb2d 100644 --- a/main.py +++ b/main.py @@ -201,6 +201,7 @@ def main_train(param=None): loss_weights={"client": 0.0, "server": 1.0}, metrics=['accuracy'] + custom_metrics) + model.summary() model.fit(features, labels, batch_size=args.batch_size, epochs=args.epochs, @@ -208,6 +209,7 @@ def main_train(param=None): logger.info("fix server model") model.get_layer("domain_cnn").trainable = False + model.get_layer("domain_cnn").layer.trainable = False model.get_layer("dense_server").trainable = False model.get_layer("server").trainable = False loss_weights = {"client": 1.0, "server": 0.0} @@ -649,6 +651,33 @@ def train_server_only(): callbacks=callbacks) +def test_server_only(): + logger.info("start test: load data") + domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, + args.data, + args.domain_length, + args.window) + domain_val = domain_val.value.reshape(-1, 40) + flow_val = flow_val.value.reshape(-1, 3) + domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length) + + for model_args in get_model_args(args): + results = {} + logger.info(f"process model {model_args['model_path']}") + embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects()) + + pred = clf_model.predict([domain_val, flow_val], + batch_size=args.batch_size, + verbose=1) + + results["server_pred"] = pred + + domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) + results["domain_embds"] = domain_embeddings + + dataset.save_predictions(model_args["model_path"], results) + + def main(): if "train" == args.mode: main_train() @@ -668,6 +697,8 @@ def main(): plot_overall_result() if "server" == args.mode: train_server_only() + if "server_test" == args.mode: + test_server_only() if __name__ == "__main__": diff --git a/utils.py b/utils.py index fc8d4c9..9fa6bf6 100644 --- a/utils.py +++ b/utils.py @@ -37,6 +37,10 @@ def load_model(path, custom_objects=None): except Exception: # in some version i forgot to specify domain_cnn # this bug fix is for certain compatibility - embd = clf.layers[1].layer + try: + embd = clf.layers[1].layer + except Exception: + embd = clf.get_layer("domain_cnn") + return embd, clf