refactor training - separate staggered training; make differences as small as possible
This commit is contained in:
parent
6ce8fb464f
commit
7f49021a63
116
main.py
116
main.py
@ -156,19 +156,37 @@ def main_train(param=None):
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
if not param:
|
||||
param = PARAMS
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
embedding, model, new_model = models.get_models_by_params(param)
|
||||
|
||||
callbacks.append(LambdaCallback(
|
||||
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
|
||||
)
|
||||
|
||||
model = create_model(model, args.model_output)
|
||||
new_model = create_model(new_model, args.model_output)
|
||||
|
||||
if args.model_type in ("inter", "staggered"):
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
model = new_model
|
||||
|
||||
if args.model_output == "both":
|
||||
labels = {"client": client_tr, "server": server_tr}
|
||||
loss_weights = {"client": 1.0, "server": 1.0}
|
||||
elif args.model_output == "client":
|
||||
labels = {"client": client_tr}
|
||||
loss_weights = {"client": 1.0}
|
||||
elif args.model_output == "server":
|
||||
labels = {"server": server_tr}
|
||||
loss_weights = {"server": 1.0}
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
|
||||
logger.info(f"select model: {args.model_type}")
|
||||
if args.model_type == "staggered":
|
||||
if not param:
|
||||
param = PARAMS
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
embedding, model, new_model = models.get_models_by_params(param)
|
||||
|
||||
model = create_model(new_model, args.model_output)
|
||||
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
logger.info("compile and train model")
|
||||
embedding.summary()
|
||||
model.summary()
|
||||
logger.info("compile and pre-train server model")
|
||||
logger.info(model.get_config())
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
@ -183,66 +201,30 @@ def main_train(param=None):
|
||||
shuffle=True,
|
||||
validation_split=0.2,
|
||||
class_weight=custom_class_weights)
|
||||
|
||||
|
||||
logger.info("fix server model")
|
||||
model.get_layer("domain_cnn").trainable = False
|
||||
model.get_layer("dense_server").trainable = False
|
||||
model.get_layer("server").trainable = False
|
||||
model.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
loss_weights={"client": 1.0, "server": 0.0},
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
loss_weights = {"client": 1.0, "server": 0.0}
|
||||
|
||||
model.summary()
|
||||
callbacks.append(LambdaCallback(
|
||||
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
|
||||
)
|
||||
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
|
||||
{"client": client_tr, "server": server_tr},
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
shuffle=True,
|
||||
class_weight=custom_class_weights)
|
||||
logger.info("compile and train model")
|
||||
embedding.summary()
|
||||
logger.info(model.get_config())
|
||||
model.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
loss_weights=loss_weights,
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
|
||||
else:
|
||||
if not param:
|
||||
param = PARAMS
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
embedding, model, new_model = models.get_models_by_params(param)
|
||||
|
||||
model = create_model(model, args.model_output)
|
||||
new_model = create_model(new_model, args.model_output)
|
||||
|
||||
if args.model_type == "inter":
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
model = new_model
|
||||
logger.info("compile and train model")
|
||||
embedding.summary()
|
||||
model.summary()
|
||||
logger.info(model.get_config())
|
||||
model.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
|
||||
if args.model_output == "both":
|
||||
labels = [client_tr, server_tr]
|
||||
elif args.model_output == "client":
|
||||
labels = [client_tr]
|
||||
elif args.model_output == "server":
|
||||
labels = [server_tr]
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
|
||||
callbacks.append(LambdaCallback(
|
||||
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
|
||||
)
|
||||
model.fit([domain_tr, flow_tr],
|
||||
labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
shuffle=True,
|
||||
validation_split=0.3,
|
||||
class_weight=custom_class_weights)
|
||||
model.summary()
|
||||
model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
|
||||
labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
shuffle=True,
|
||||
validation_split=0.2,
|
||||
class_weight=custom_class_weights)
|
||||
|
||||
|
||||
def main_test():
|
||||
|
@ -3,7 +3,6 @@ from collections import namedtuple
|
||||
import keras
|
||||
from keras.engine import Input, Model as KerasModel
|
||||
from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
|
||||
from keras.regularizers import l2
|
||||
|
||||
import dataset
|
||||
|
||||
@ -58,7 +57,7 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
|
||||
# remove temporal dimension by global max pooling
|
||||
y = GlobalMaxPooling1D()(y)
|
||||
y = Dropout(cnnDropout)(y)
|
||||
y = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')(y)
|
||||
y = Dense(dense_dim, activation='relu')(y)
|
||||
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
||||
out_server = Dense(1, activation='sigmoid', name="server")(y)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user