add argument for using the new model architecture

René Knaebel 2017-07-30 14:07:39 +02:00
parent ebaeb6b96e
commit e24f596f40
2 changed files with 11 additions and 10 deletions


@@ -63,6 +63,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
 parser.add_argument("--stop_early", action="store_true", dest="stop_early")
 parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
 parser.add_argument("--gpu", action="store_true", dest="gpu")
+parser.add_argument("--new_model", action="store_true", dest="new_model")

main.py

@@ -112,18 +112,18 @@ def main_hyperband():
     json.dump(results, open("hyperband.json"))
-def main_train(param=None, train_new_model=False):
+def main_train(param=None):
+    logger.info(f"Create model path {args.model_path}")
     exists_or_make_path(args.model_path)
+    logger.info(f"Use command line arguments: {args}")
     domain_tr, flow_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
                                                                                args.domain_length, args.window)
     if not param:
         param = PARAMS
+    logger.info(f"Generator model with params: {param}")
     embedding, model, new_model = models.get_models_by_params(param)
-    embedding.summary()
-    model.summary()
     logger.info("define callbacks")
     callbacks = []
     callbacks.append(ModelCheckpoint(filepath=args.clf_model,
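
The reworked main_train no longer receives the architecture choice from its caller; everything it needs now comes from the parsed command line arguments. models.get_models_by_params itself is not touched by this commit; the sketch below only illustrates the shape implied by the unpacking above, a factory returning an embedding sub-model plus two alternative classifiers. All layer choices, parameter names and sizes are invented.

from keras.layers import Dense, Embedding, GlobalAveragePooling1D, Input
from keras.models import Model

def get_models_by_params(param):
    # invented layers and sizes; only the returned triple mirrors the call in main_train
    domains = Input(shape=(param["domain_length"],))
    embedded = Embedding(input_dim=param["vocab_size"],
                         output_dim=param["embedding_size"])(domains)
    embedding = Model(inputs=domains, outputs=embedded)

    pooled = GlobalAveragePooling1D()(embedded)
    model = Model(inputs=domains, outputs=Dense(1, activation="sigmoid")(pooled))
    hidden = Dense(32, activation="relu")(pooled)
    new_model = Model(inputs=domains, outputs=Dense(1, activation="sigmoid")(hidden))
    return embedding, model, new_model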
@@ -131,11 +131,11 @@ def main_train(param=None, train_new_model=False):
                                      verbose=False,
                                      save_best_only=True))
     callbacks.append(CSVLogger(args.train_log))
+    logger.info(f"Use early stopping: {args.stop_early}")
     if args.stop_early:
         callbacks.append(EarlyStopping(monitor='val_loss',
                                        patience=5,
                                        verbose=False))
-    logger.info("compile model")
     custom_metrics = models.get_metric_functions()
     server_tr = np.max(server_windows_tr, axis=1)
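
Only the logging around early stopping changes here, but it helps to see the callback setup in one piece. A sketch of that setup as a standalone helper, assuming the standalone Keras package in use at the time; the helper name build_callbacks is made up, in the repository the list is built inline:

from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint

def build_callbacks(args):
    # checkpointing and CSV logging are always active
    callbacks = [ModelCheckpoint(filepath=args.clf_model,
                                 verbose=False,
                                 save_best_only=True),
                 CSVLogger(args.train_log)]
    # early stopping is opt-in via --stop_early
    if args.stop_early:
        callbacks.append(EarlyStopping(monitor='val_loss',
                                       patience=5,
                                       verbose=False))
    return callbacks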
@@ -147,12 +147,14 @@ def main_train(param=None, train_new_model=False):
     else:
         logger.info("class weights: set default")
         custom_class_weights = None
-    logger.info("start training")
-    if train_new_model:
+    logger.info(f"select model: {'new' if args.new_model else 'old'}")
+    if args.new_model:
         server_tr = np.expand_dims(server_windows_tr, 2)
         model = new_model
+    logger.info("compile and train model")
+    embedding.summary()
+    model.summary()
     model.compile(optimizer='adam',
                   loss='binary_crossentropy',
                   metrics=['accuracy'] + custom_metrics)
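
This hunk is the heart of the change: the architecture is now chosen from args.new_model instead of the old train_new_model parameter, and the server labels are reshaped to match the selected model before compiling. The same branch written as a standalone helper for clarity; the helper name is made up, and the else path simply restates the np.max reduction from the context above:

import numpy as np

def select_model(model, new_model, server_windows_tr, use_new_model):
    if use_new_model:
        # the new architecture trains on per-window server labels,
        # so keep the window axis and add a trailing feature axis
        server_tr = np.expand_dims(server_windows_tr, 2)
        return new_model, server_tr
    # the old architecture keeps one server label per window block
    server_tr = np.max(server_windows_tr, axis=1)
    return model, server_tr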
@@ -271,8 +273,6 @@ def main():
         main_paul_best()
     if "data" in args.modes:
         main_data()
-    if "train_new" in args.modes:
-        main_train(train_new_model=True)


 if __name__ == "__main__":
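
With the flag in place, the dedicated "train_new" mode disappears from main(): choosing the architecture is now orthogonal to choosing the mode and travels through args.new_model into main_train. A self-contained sketch of the resulting dispatch; the spelling of the --modes option, the "train" mode name and the stub bodies are assumptions, only the "paul" and "data" modes and the removed "train_new" branch are visible in this diff:

import argparse

def main_train():
    print("train; architecture chosen inside via args.new_model")

def main_paul_best():
    print("paul")

def main_data():
    print("data")

parser = argparse.ArgumentParser()
parser.add_argument("--modes", nargs="+", default=[])  # option spelling assumed
parser.add_argument("--new_model", action="store_true", dest="new_model")
args = parser.parse_args()

def main():
    if "train" in args.modes:
        main_train()  # no separate train_new mode any more
    if "paul" in args.modes:
        main_paul_best()
    if "data" in args.modes:
        main_data()

if __name__ == "__main__":
    main()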