add argument for using the new model architecture
This commit is contained in:
parent
ebaeb6b96e
commit
e24f596f40
@ -63,6 +63,7 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
|||||||
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
|
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
|
||||||
parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
|
parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
|
||||||
parser.add_argument("--gpu", action="store_true", dest="gpu")
|
parser.add_argument("--gpu", action="store_true", dest="gpu")
|
||||||
|
parser.add_argument("--new_model", action="store_true", dest="new_model")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
20
main.py
20
main.py
@ -112,18 +112,18 @@ def main_hyperband():
|
|||||||
json.dump(results, open("hyperband.json"))
|
json.dump(results, open("hyperband.json"))
|
||||||
|
|
||||||
|
|
||||||
def main_train(param=None, train_new_model=False):
|
def main_train(param=None):
|
||||||
|
logger.info(f"Create model path {args.model_path}")
|
||||||
exists_or_make_path(args.model_path)
|
exists_or_make_path(args.model_path)
|
||||||
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
|
||||||
domain_tr, flow_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
|
domain_tr, flow_tr, client_tr, server_windows_tr = load_or_generate_h5data(args.train_h5data, args.train_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
|
|
||||||
if not param:
|
if not param:
|
||||||
param = PARAMS
|
param = PARAMS
|
||||||
|
logger.info(f"Generator model with params: {param}")
|
||||||
embedding, model, new_model = models.get_models_by_params(param)
|
embedding, model, new_model = models.get_models_by_params(param)
|
||||||
embedding.summary()
|
|
||||||
model.summary()
|
|
||||||
logger.info("define callbacks")
|
logger.info("define callbacks")
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||||
@ -131,11 +131,11 @@ def main_train(param=None, train_new_model=False):
|
|||||||
verbose=False,
|
verbose=False,
|
||||||
save_best_only=True))
|
save_best_only=True))
|
||||||
callbacks.append(CSVLogger(args.train_log))
|
callbacks.append(CSVLogger(args.train_log))
|
||||||
|
logger.info(f"Use early stopping: {args.stop_early}")
|
||||||
if args.stop_early:
|
if args.stop_early:
|
||||||
callbacks.append(EarlyStopping(monitor='val_loss',
|
callbacks.append(EarlyStopping(monitor='val_loss',
|
||||||
patience=5,
|
patience=5,
|
||||||
verbose=False))
|
verbose=False))
|
||||||
logger.info("compile model")
|
|
||||||
custom_metrics = models.get_metric_functions()
|
custom_metrics = models.get_metric_functions()
|
||||||
|
|
||||||
server_tr = np.max(server_windows_tr, axis=1)
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
@ -147,12 +147,14 @@ def main_train(param=None, train_new_model=False):
|
|||||||
else:
|
else:
|
||||||
logger.info("class weights: set default")
|
logger.info("class weights: set default")
|
||||||
custom_class_weights = None
|
custom_class_weights = None
|
||||||
logger.info("start training")
|
|
||||||
|
|
||||||
if train_new_model:
|
logger.info(f"select model: {'new' if args.new_model else 'old'}")
|
||||||
|
if args.new_model:
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
model = new_model
|
model = new_model
|
||||||
|
logger.info("compile and train model")
|
||||||
|
embedding.summary()
|
||||||
|
model.summary()
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss='binary_crossentropy',
|
loss='binary_crossentropy',
|
||||||
metrics=['accuracy'] + custom_metrics)
|
metrics=['accuracy'] + custom_metrics)
|
||||||
@ -271,8 +273,6 @@ def main():
|
|||||||
main_paul_best()
|
main_paul_best()
|
||||||
if "data" in args.modes:
|
if "data" in args.modes:
|
||||||
main_data()
|
main_data()
|
||||||
if "train_new" in args.modes:
|
|
||||||
main_train(train_new_model=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user