Refactor training: separate the staggered pre-training step and keep the differences between the model-type code paths as small as possible

This commit is contained in:
René Knaebel 2017-09-12 08:36:23 +02:00
parent 6ce8fb464f
commit 7f49021a63
2 changed files with 50 additions and 69 deletions

80
main.py
View File

@ -156,19 +156,37 @@ def main_train(param=None):
logger.info("class weights: set default") logger.info("class weights: set default")
custom_class_weights = None custom_class_weights = None
logger.info(f"select model: {args.model_type}")
if args.model_type == "staggered":
if not param: if not param:
param = PARAMS param = PARAMS
logger.info(f"Generator model with params: {param}") logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param) embedding, model, new_model = models.get_models_by_params(param)
model = create_model(new_model, args.model_output) callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output)
if args.model_type in ("inter", "staggered"):
server_tr = np.expand_dims(server_windows_tr, 2) server_tr = np.expand_dims(server_windows_tr, 2)
logger.info("compile and train model") model = new_model
embedding.summary()
model.summary() if args.model_output == "both":
labels = {"client": client_tr, "server": server_tr}
loss_weights = {"client": 1.0, "server": 1.0}
elif args.model_output == "client":
labels = {"client": client_tr}
loss_weights = {"client": 1.0}
elif args.model_output == "server":
labels = {"server": server_tr}
loss_weights = {"server": 1.0}
else:
raise ValueError("unknown model output")
logger.info(f"select model: {args.model_type}")
if args.model_type == "staggered":
logger.info("compile and pre-train server model")
logger.info(model.get_config()) logger.info(model.get_config())
model.compile(optimizer='adam', model.compile(optimizer='adam',
@ -184,64 +202,28 @@ def main_train(param=None):
validation_split=0.2, validation_split=0.2,
class_weight=custom_class_weights) class_weight=custom_class_weights)
logger.info("fix server model")
model.get_layer("domain_cnn").trainable = False
model.get_layer("dense_server").trainable = False model.get_layer("dense_server").trainable = False
model.get_layer("server").trainable = False model.get_layer("server").trainable = False
model.compile(optimizer='adam', loss_weights = {"client": 1.0, "server": 0.0}
loss='binary_crossentropy',
loss_weights={"client": 1.0, "server": 0.0},
metrics=['accuracy'] + custom_metrics)
model.summary()
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
{"client": client_tr, "server": server_tr},
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
shuffle=True,
class_weight=custom_class_weights)
else:
if not param:
param = PARAMS
logger.info(f"Generator model with params: {param}")
embedding, model, new_model = models.get_models_by_params(param)
model = create_model(model, args.model_output)
new_model = create_model(new_model, args.model_output)
if args.model_type == "inter":
server_tr = np.expand_dims(server_windows_tr, 2)
model = new_model
logger.info("compile and train model") logger.info("compile and train model")
embedding.summary() embedding.summary()
model.summary()
logger.info(model.get_config()) logger.info(model.get_config())
model.compile(optimizer='adam', model.compile(optimizer='adam',
loss='binary_crossentropy', loss='binary_crossentropy',
loss_weights=loss_weights,
metrics=['accuracy'] + custom_metrics) metrics=['accuracy'] + custom_metrics)
if args.model_output == "both": model.summary()
labels = [client_tr, server_tr] model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
elif args.model_output == "client":
labels = [client_tr]
elif args.model_output == "server":
labels = [server_tr]
else:
raise ValueError("unknown model output")
callbacks.append(LambdaCallback(
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
)
model.fit([domain_tr, flow_tr],
labels, labels,
batch_size=args.batch_size, batch_size=args.batch_size,
epochs=args.epochs, epochs=args.epochs,
callbacks=callbacks, callbacks=callbacks,
shuffle=True, shuffle=True,
validation_split=0.3, validation_split=0.2,
class_weight=custom_class_weights) class_weight=custom_class_weights)

View File

@ -3,7 +3,6 @@ from collections import namedtuple
import keras import keras
from keras.engine import Input, Model as KerasModel from keras.engine import Input, Model as KerasModel
from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
from keras.regularizers import l2
import dataset import dataset
@ -58,7 +57,7 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
# remove temporal dimension by global max pooling # remove temporal dimension by global max pooling
y = GlobalMaxPooling1D()(y) y = GlobalMaxPooling1D()(y)
y = Dropout(cnnDropout)(y) y = Dropout(cnnDropout)(y)
y = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')(y) y = Dense(dense_dim, activation='relu')(y)
out_client = Dense(1, activation='sigmoid', name="client")(y) out_client = Dense(1, activation='sigmoid', name="client")(y)
out_server = Dense(1, activation='sigmoid', name="server")(y) out_server = Dense(1, activation='sigmoid', name="server")(y)