From 7f49021a63a68cce11c256221b0e8881e2870c0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Knaebel?=
Date: Tue, 12 Sep 2017 08:36:23 +0200
Subject: [PATCH] refactor training - separate staggered training; make
 differences as small as possible

---
 main.py                  | 116 +++++++++++++++++----------------------
 models/pauls_networks.py |   3 +-
 2 files changed, 50 insertions(+), 69 deletions(-)

diff --git a/main.py b/main.py
index 17ab37b..2a2f736 100644
--- a/main.py
+++ b/main.py
@@ -156,19 +156,37 @@ def main_train(param=None):
         logger.info("class weights: set default")
         custom_class_weights = None
 
+    if not param:
+        param = PARAMS
+    logger.info(f"Generator model with params: {param}")
+    embedding, model, new_model = models.get_models_by_params(param)
+
+    callbacks.append(LambdaCallback(
+        on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
+    )
+
+    model = create_model(model, args.model_output)
+    new_model = create_model(new_model, args.model_output)
+
+    if args.model_type in ("inter", "staggered"):
+        server_tr = np.expand_dims(server_windows_tr, 2)
+        model = new_model
+
+    if args.model_output == "both":
+        labels = {"client": client_tr, "server": server_tr}
+        loss_weights = {"client": 1.0, "server": 1.0}
+    elif args.model_output == "client":
+        labels = {"client": client_tr}
+        loss_weights = {"client": 1.0}
+    elif args.model_output == "server":
+        labels = {"server": server_tr}
+        loss_weights = {"server": 1.0}
+    else:
+        raise ValueError("unknown model output")
+
     logger.info(f"select model: {args.model_type}")
     if args.model_type == "staggered":
-        if not param:
-            param = PARAMS
-        logger.info(f"Generator model with params: {param}")
-        embedding, model, new_model = models.get_models_by_params(param)
-
-        model = create_model(new_model, args.model_output)
-
-        server_tr = np.expand_dims(server_windows_tr, 2)
-        logger.info("compile and train model")
-        embedding.summary()
-        model.summary()
+        logger.info("compile and pre-train server model")
         logger.info(model.get_config())
 
         model.compile(optimizer='adam',
@@ -183,66 +201,30 @@
                   shuffle=True,
                   validation_split=0.2,
                   class_weight=custom_class_weights)
-
+
+        logger.info("fix server model")
+        model.get_layer("domain_cnn").trainable = False
         model.get_layer("dense_server").trainable = False
         model.get_layer("server").trainable = False
-        model.compile(optimizer='adam',
-                      loss='binary_crossentropy',
-                      loss_weights={"client": 1.0, "server": 0.0},
-                      metrics=['accuracy'] + custom_metrics)
+        loss_weights = {"client": 1.0, "server": 0.0}
 
-        model.summary()
-        callbacks.append(LambdaCallback(
-            on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
-        )
-        model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
-                  {"client": client_tr, "server": server_tr},
-                  batch_size=args.batch_size,
-                  epochs=args.epochs,
-                  callbacks=callbacks,
-                  shuffle=True,
-                  class_weight=custom_class_weights)
+    logger.info("compile and train model")
+    embedding.summary()
+    logger.info(model.get_config())
+    model.compile(optimizer='adam',
+                  loss='binary_crossentropy',
+                  loss_weights=loss_weights,
+                  metrics=['accuracy'] + custom_metrics)
 
-    else:
-        if not param:
-            param = PARAMS
-        logger.info(f"Generator model with params: {param}")
-        embedding, model, new_model = models.get_models_by_params(param)
-
-        model = create_model(model, args.model_output)
-        new_model = create_model(new_model, args.model_output)
-
-        if args.model_type == "inter":
-            server_tr = np.expand_dims(server_windows_tr, 2)
-            model = new_model
-        logger.info("compile and train model")
-        embedding.summary()
-        model.summary()
-        logger.info(model.get_config())
-        model.compile(optimizer='adam',
-                      loss='binary_crossentropy',
-                      metrics=['accuracy'] + custom_metrics)
-
-        if args.model_output == "both":
-            labels = [client_tr, server_tr]
-        elif args.model_output == "client":
-            labels = [client_tr]
-        elif args.model_output == "server":
-            labels = [server_tr]
-        else:
-            raise ValueError("unknown model output")
-
-        callbacks.append(LambdaCallback(
-            on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
-        )
-        model.fit([domain_tr, flow_tr],
-                  labels,
-                  batch_size=args.batch_size,
-                  epochs=args.epochs,
-                  callbacks=callbacks,
-                  shuffle=True,
-                  validation_split=0.3,
-                  class_weight=custom_class_weights)
+    model.summary()
+    model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
+              labels,
+              batch_size=args.batch_size,
+              epochs=args.epochs,
+              callbacks=callbacks,
+              shuffle=True,
+              validation_split=0.2,
+              class_weight=custom_class_weights)
 
 
 def main_test():
diff --git a/models/pauls_networks.py b/models/pauls_networks.py
index 0d26ec3..4f216db 100644
--- a/models/pauls_networks.py
+++ b/models/pauls_networks.py
@@ -3,7 +3,6 @@ from collections import namedtuple
 import keras
 from keras.engine import Input, Model as KerasModel
 from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
-from keras.regularizers import l2
 
 import dataset
 
@@ -58,7 +57,7 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
     # remove temporal dimension by global max pooling
     y = GlobalMaxPooling1D()(y)
     y = Dropout(cnnDropout)(y)
-    y = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')(y)
+    y = Dense(dense_dim, activation='relu')(y)
     out_client = Dense(1, activation='sigmoid', name="client")(y)
     out_server = Dense(1, activation='sigmoid', name="server")(y)
 