refactor training - separate staggered training; make differences as small as possible

René Knaebel 2017-09-12 08:36:23 +02:00
parent 6ce8fb464f
commit 7f49021a63
2 changed files with 50 additions and 69 deletions
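The commit folds the previously duplicated setup (parameter defaults, model construction, embedding-save callback, label and loss-weight selection) into one shared path; only the staggered case keeps an extra server pre-training phase whose layers are then frozen. Below is a minimal, self-contained sketch of that staggered pattern, assuming toy shapes and synthetic data; the real models come from models.get_models_by_params and create_model:

# Sketch of the staggered pattern: pre-train both heads, then freeze
# the server head and recompile with its loss weight zeroed.
# Input shape, layer sizes, and data are toy placeholders.
import numpy as np
from keras.layers import Dense, Input
from keras.models import Model

ipt = Input(shape=(10,), name="ipt")
hidden = Dense(32, activation="relu")(ipt)
out_client = Dense(1, activation="sigmoid", name="client")(hidden)
out_server = Dense(1, activation="sigmoid", name="server")(hidden)
model = Model(inputs=ipt, outputs=[out_client, out_server])

x = np.random.rand(64, 10)
labels = {"client": np.random.randint(0, 2, (64, 1)),
          "server": np.random.randint(0, 2, (64, 1))}

# phase 1: train with both objectives active
model.compile(optimizer="adam", loss="binary_crossentropy",
              loss_weights={"client": 1.0, "server": 1.0})
model.fit(x, labels, epochs=1, batch_size=16)

# phase 2: freeze the server head; trainable changes only take
# effect once the model is compiled again
model.get_layer("server").trainable = False
model.compile(optimizer="adam", loss="binary_crossentropy",
              loss_weights={"client": 1.0, "server": 0.0})
model.fit(x, labels, epochs=1, batch_size=16)

The recompile between the two phases matters: Keras bakes layer.trainable into the training function at compile time, which is exactly why the diff below recompiles after setting the server layers non-trainable.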

main.py

@@ -156,19 +156,37 @@ def main_train(param=None):
         logger.info("class weights: set default")
         custom_class_weights = None
+    if not param:
+        param = PARAMS
+    logger.info(f"Generator model with params: {param}")
+    embedding, model, new_model = models.get_models_by_params(param)
+    callbacks.append(LambdaCallback(
+        on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
+    )
+    model = create_model(model, args.model_output)
+    new_model = create_model(new_model, args.model_output)
+    if args.model_type in ("inter", "staggered"):
+        server_tr = np.expand_dims(server_windows_tr, 2)
+        model = new_model
+    if args.model_output == "both":
+        labels = {"client": client_tr, "server": server_tr}
+        loss_weights = {"client": 1.0, "server": 1.0}
+    elif args.model_output == "client":
+        labels = {"client": client_tr}
+        loss_weights = {"client": 1.0}
+    elif args.model_output == "server":
+        labels = {"server": server_tr}
+        loss_weights = {"server": 1.0}
+    else:
+        raise ValueError("unknown model output")
     logger.info(f"select model: {args.model_type}")
     if args.model_type == "staggered":
-        if not param:
-            param = PARAMS
-        logger.info(f"Generator model with params: {param}")
-        embedding, model, new_model = models.get_models_by_params(param)
-        model = create_model(new_model, args.model_output)
-        server_tr = np.expand_dims(server_windows_tr, 2)
-        logger.info("compile and train model")
-        embedding.summary()
-        model.summary()
+        logger.info("compile and pre-train server model")
         logger.info(model.get_config())
         model.compile(optimizer='adam',
@@ -183,66 +201,30 @@ def main_train(param=None):
                   shuffle=True,
                   validation_split=0.2,
                   class_weight=custom_class_weights)
+        logger.info("fix server model")
+        model.get_layer("domain_cnn").trainable = False
         model.get_layer("dense_server").trainable = False
         model.get_layer("server").trainable = False
-        model.compile(optimizer='adam',
-                      loss='binary_crossentropy',
-                      loss_weights={"client": 1.0, "server": 0.0},
-                      metrics=['accuracy'] + custom_metrics)
-        model.summary()
-        callbacks.append(LambdaCallback(
-            on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
-        )
-        model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
-                  {"client": client_tr, "server": server_tr},
-                  batch_size=args.batch_size,
-                  epochs=args.epochs,
-                  callbacks=callbacks,
-                  shuffle=True,
-                  class_weight=custom_class_weights)
-    else:
-        if not param:
-            param = PARAMS
-        logger.info(f"Generator model with params: {param}")
-        embedding, model, new_model = models.get_models_by_params(param)
-
-        model = create_model(model, args.model_output)
-        new_model = create_model(new_model, args.model_output)
-
-        if args.model_type == "inter":
-            server_tr = np.expand_dims(server_windows_tr, 2)
-            model = new_model
-        logger.info("compile and train model")
-        embedding.summary()
-        model.summary()
-        logger.info(model.get_config())
-        model.compile(optimizer='adam',
-                      loss='binary_crossentropy',
-                      metrics=['accuracy'] + custom_metrics)
-        if args.model_output == "both":
-            labels = [client_tr, server_tr]
-        elif args.model_output == "client":
-            labels = [client_tr]
-        elif args.model_output == "server":
-            labels = [server_tr]
-        else:
-            raise ValueError("unknown model output")
-        callbacks.append(LambdaCallback(
-            on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
-        )
-        model.fit([domain_tr, flow_tr],
-                  labels,
-                  batch_size=args.batch_size,
-                  epochs=args.epochs,
-                  callbacks=callbacks,
-                  shuffle=True,
-                  validation_split=0.3,
-                  class_weight=custom_class_weights)
+        loss_weights = {"client": 1.0, "server": 0.0}
+    logger.info("compile and train model")
+    embedding.summary()
+    logger.info(model.get_config())
+    model.compile(optimizer='adam',
+                  loss='binary_crossentropy',
+                  loss_weights=loss_weights,
+                  metrics=['accuracy'] + custom_metrics)
+    model.summary()
+    model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
+              labels,
+              batch_size=args.batch_size,
+              epochs=args.epochs,
+              callbacks=callbacks,
+              shuffle=True,
+              validation_split=0.2,
+              class_weight=custom_class_weights)
 
 def main_test():
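The single fit path works because Keras multi-output training minimizes the weighted sum of the per-output losses: loss_weights={"client": 1.0, "server": 0.0} keeps the server output in the graph while removing its contribution to the objective. A small numpy illustration of that weighted sum, with made-up prediction values:

import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    # binary cross-entropy, averaged over samples
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred)
                          + (1 - y_true) * np.log(1 - y_pred)))

loss_client = bce(np.array([1.0, 0.0]), np.array([0.9, 0.2]))
loss_server = bce(np.array([0.0, 1.0]), np.array([0.4, 0.7]))
loss_weights = {"client": 1.0, "server": 0.0}

# total objective = sum over outputs of weight * per-output loss
total = (loss_weights["client"] * loss_client
         + loss_weights["server"] * loss_server)
print(total == loss_client)  # True: the server term is masked out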

models.py

@@ -3,7 +3,6 @@ from collections import namedtuple
 
 import keras
 from keras.engine import Input, Model as KerasModel
 from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
-from keras.regularizers import l2
 
 import dataset
@@ -58,7 +57,7 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
     # remove temporal dimension by global max pooling
     y = GlobalMaxPooling1D()(y)
     y = Dropout(cnnDropout)(y)
-    y = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')(y)
+    y = Dense(dense_dim, activation='relu')(y)
     out_client = Dense(1, activation='sigmoid', name="client")(y)
     out_server = Dense(1, activation='sigmoid', name="server")(y)
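For context on the models.py change: kernel_regularizer=l2(0.1) adds a penalty of 0.1 * sum(W ** 2) over that Dense layer's kernel to the training loss, so dropping it leaves the Dropout layer as the only regularization on that block. A sketch of the two variants, with dense_dim standing in for the actual layer width:

from keras.layers import Dense
from keras.regularizers import l2

dense_dim = 128  # placeholder width

# before: weight penalty 0.1 * sum(W ** 2) added to the loss
regularized = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')
# after: no explicit weight penalty on the kernel
plain = Dense(dense_dim, activation='relu')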