refactor training - separate staggered training; make differences as small as possible
parent 6ce8fb464f
commit 7f49021a63
main.py (116 changed lines)
@@ -156,19 +156,37 @@ def main_train(param=None):
     logger.info("class weights: set default")
     custom_class_weights = None
 
+    if not param:
+        param = PARAMS
+    logger.info(f"Generator model with params: {param}")
+    embedding, model, new_model = models.get_models_by_params(param)
+
+    callbacks.append(LambdaCallback(
+        on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
+    )
+
+    model = create_model(model, args.model_output)
+    new_model = create_model(new_model, args.model_output)
+
+    if args.model_type in ("inter", "staggered"):
+        server_tr = np.expand_dims(server_windows_tr, 2)
+        model = new_model
+
+    if args.model_output == "both":
+        labels = {"client": client_tr, "server": server_tr}
+        loss_weights = {"client": 1.0, "server": 1.0}
+    elif args.model_output == "client":
+        labels = {"client": client_tr}
+        loss_weights = {"client": 1.0}
+    elif args.model_output == "server":
+        labels = {"server": server_tr}
+        loss_weights = {"server": 1.0}
+    else:
+        raise ValueError("unknown model output")
+
     logger.info(f"select model: {args.model_type}")
     if args.model_type == "staggered":
-        if not param:
-            param = PARAMS
-        logger.info(f"Generator model with params: {param}")
-        embedding, model, new_model = models.get_models_by_params(param)
-
-        model = create_model(new_model, args.model_output)
-
-        server_tr = np.expand_dims(server_windows_tr, 2)
-        logger.info("compile and train model")
-        embedding.summary()
-        model.summary()
+        logger.info("compile and pre-train server model")
         logger.info(model.get_config())
         model.compile(optimizer='adam',
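Note on the np.expand_dims(server_windows_tr, 2) call introduced above: it appends a trailing axis of size 1, the usual step for making per-window labels line up with a model output of shape (batch, windows, 1). A minimal sketch, with a made-up label shape:

    import numpy as np

    server_windows_tr = np.zeros((32, 10))            # hypothetical (batch, windows) labels
    server_tr = np.expand_dims(server_windows_tr, 2)  # add trailing feature axis
    assert server_tr.shape == (32, 10, 1)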
@@ -183,66 +201,30 @@ def main_train(param=None):
                   shuffle=True,
                   validation_split=0.2,
                   class_weight=custom_class_weights)
 
+        logger.info("fix server model")
+        model.get_layer("domain_cnn").trainable = False
         model.get_layer("dense_server").trainable = False
         model.get_layer("server").trainable = False
-        model.compile(optimizer='adam',
-                      loss='binary_crossentropy',
-                      loss_weights={"client": 1.0, "server": 0.0},
-                      metrics=['accuracy'] + custom_metrics)
-
-        model.summary()
-        callbacks.append(LambdaCallback(
-            on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
-        )
-        model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
-                  {"client": client_tr, "server": server_tr},
-                  batch_size=args.batch_size,
-                  epochs=args.epochs,
-                  callbacks=callbacks,
-                  shuffle=True,
-                  class_weight=custom_class_weights)
-
-    else:
-        if not param:
-            param = PARAMS
-        logger.info(f"Generator model with params: {param}")
-        embedding, model, new_model = models.get_models_by_params(param)
-
-        model = create_model(model, args.model_output)
-        new_model = create_model(new_model, args.model_output)
-
-        if args.model_type == "inter":
-            server_tr = np.expand_dims(server_windows_tr, 2)
-            model = new_model
-        logger.info("compile and train model")
-        embedding.summary()
-        model.summary()
-        logger.info(model.get_config())
-        model.compile(optimizer='adam',
-                      loss='binary_crossentropy',
-                      metrics=['accuracy'] + custom_metrics)
-
-        if args.model_output == "both":
-            labels = [client_tr, server_tr]
-        elif args.model_output == "client":
-            labels = [client_tr]
-        elif args.model_output == "server":
-            labels = [server_tr]
-        else:
-            raise ValueError("unknown model output")
-
-        callbacks.append(LambdaCallback(
-            on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
-        )
-        model.fit([domain_tr, flow_tr],
-                  labels,
-                  batch_size=args.batch_size,
-                  epochs=args.epochs,
-                  callbacks=callbacks,
-                  shuffle=True,
-                  validation_split=0.3,
-                  class_weight=custom_class_weights)
+        loss_weights = {"client": 1.0, "server": 0.0}
+
+    logger.info("compile and train model")
+    embedding.summary()
+    logger.info(model.get_config())
+    model.compile(optimizer='adam',
+                  loss='binary_crossentropy',
+                  loss_weights=loss_weights,
+                  metrics=['accuracy'] + custom_metrics)
+
+    model.summary()
+    model.fit({"ipt_domains": domain_tr, "ipt_flows": flow_tr},
+              labels,
+              batch_size=args.batch_size,
+              epochs=args.epochs,
+              callbacks=callbacks,
+              shuffle=True,
+              validation_split=0.2,
+              class_weight=custom_class_weights)
 
 
 def main_test():
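Taken together, the refactor makes both branches share one compile-and-fit path: the model, the labels dict, and the loss_weights dict are built up front, and the staggered case only adds a pre-training pass plus a freeze of the server-side layers before the shared step. Below is a condensed, self-contained sketch of that pattern in tf.keras on toy data; the layer and output names mirror the diff, while the shapes, the data, and the server-weighted pre-training loss weights are assumptions (the diff does not show that compile call in full):

    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import layers

    ipt = keras.Input(shape=(16,))
    h = layers.Dense(32, activation="relu", name="domain_cnn")(ipt)  # stand-in for the CNN block
    s = layers.Dense(8, activation="relu", name="dense_server")(h)
    out_server = layers.Dense(1, activation="sigmoid", name="server")(s)
    out_client = layers.Dense(1, activation="sigmoid", name="client")(h)
    model = keras.Model(ipt, [out_client, out_server])

    x = np.random.rand(128, 16)
    # Labels and loss weights are keyed by output-layer name, so one fit call
    # serves the "both", "client", and "server" output modes.
    labels = {"client": np.random.randint(0, 2, (128, 1)),
              "server": np.random.randint(0, 2, (128, 1))}

    # Phase 1 (staggered only): pre-train the server path.
    # Assumption: pre-training weights the server loss, not the client loss.
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  loss_weights={"client": 0.0, "server": 1.0})
    model.fit(x, labels, epochs=1, verbose=0)

    # Phase 2: freeze the server path, zero its loss weight, recompile, train.
    for name in ("domain_cnn", "dense_server", "server"):
        model.get_layer(name).trainable = False
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  loss_weights={"client": 1.0, "server": 0.0})
    model.fit(x, labels, epochs=1, verbose=0)

The recompile in phase 2 is what makes the trainable = False flags take effect, which is why the staggered branch in the diff ends right before the shared model.compile call.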
models.py

@@ -3,7 +3,6 @@ from collections import namedtuple
 import keras
 from keras.engine import Input, Model as KerasModel
 from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
-from keras.regularizers import l2
 
 import dataset
 
@@ -58,7 +57,7 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
     # remove temporal dimension by global max pooling
     y = GlobalMaxPooling1D()(y)
     y = Dropout(cnnDropout)(y)
-    y = Dense(dense_dim, kernel_regularizer=l2(0.1), activation='relu')(y)
+    y = Dense(dense_dim, activation='relu')(y)
     out_client = Dense(1, activation='sigmoid', name="client")(y)
     out_server = Dense(1, activation='sigmoid', name="server")(y)
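One consequence of the models.py change above: the shared dense layer loses its L2 weight penalty, leaving the preceding Dropout as the only regularization on that path as shown. For reference, kernel_regularizer=l2(0.1) adds 0.1 * sum(W ** 2) over the layer's kernel W to the training loss. In current tf.keras terms (the layer width here is hypothetical):

    from tensorflow.keras import layers, regularizers

    # Before this commit: L2 penalty on the dense layer's kernel.
    dense_before = layers.Dense(64, activation="relu",
                                kernel_regularizer=regularizers.l2(0.1))

    # After this commit: no weight penalty.
    dense_after = layers.Dense(64, activation="relu")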
|