train multiple models at once
This commit is contained in:
parent
88e3eda595
commit
14fef66a55
195
main.py
195
main.py
@ -80,8 +80,8 @@ PARAMS = {
|
|||||||
|
|
||||||
|
|
||||||
# TODO: remove inner global params
|
# TODO: remove inner global params
|
||||||
def get_param_dist(size="small"):
|
def get_param_dist(dist_size="small"):
|
||||||
if dist_type == "small":
|
if dist_size == "small":
|
||||||
return {
|
return {
|
||||||
# static params
|
# static params
|
||||||
"type": [args.model_type],
|
"type": [args.model_type],
|
||||||
@ -180,11 +180,7 @@ def train(parameters, features, labels):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def main_train(param=None):
|
def load_data(data, domain_length, window_size, model_type):
|
||||||
logger.info(f"Create model path {args.model_path}")
|
|
||||||
exists_or_make_path(args.model_path)
|
|
||||||
logger.info(f"Use command line arguments: {args}")
|
|
||||||
|
|
||||||
# data preparation
|
# data preparation
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||||
args.data,
|
args.data,
|
||||||
@ -193,110 +189,124 @@ def main_train(param=None):
|
|||||||
server_tr = np.max(server_windows_tr, axis=1)
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
if args.model_type in ("inter", "staggered"):
|
if args.model_type in ("inter", "staggered"):
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
|
return domain_tr, flow_tr, client_tr, server_tr
|
||||||
|
|
||||||
|
|
||||||
|
def main_train(param=None):
|
||||||
|
logger.info(f"Create model path {args.model_path}")
|
||||||
|
exists_or_make_path(args.model_path)
|
||||||
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
|
||||||
|
# data preparation
|
||||||
|
domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length,
|
||||||
|
args.window, args.model_type)
|
||||||
|
|
||||||
# call hyperband if used
|
# call hyperband if used
|
||||||
if args.hyperband_results:
|
if args.hyperband_results:
|
||||||
logger.info("start hyperband parameter search")
|
logger.info("start hyperband parameter search")
|
||||||
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
|
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
|
||||||
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]
|
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
||||||
logger.info(f"select params from result: {param}")
|
logger.info(f"select params from result: {param}")
|
||||||
|
|
||||||
# define training call backs
|
|
||||||
logger.info("define callbacks")
|
|
||||||
callbacks = []
|
|
||||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
|
||||||
monitor='loss',
|
|
||||||
verbose=False,
|
|
||||||
save_best_only=True))
|
|
||||||
callbacks.append(CSVLogger(args.train_log))
|
|
||||||
logger.info(f"Use early stopping: {args.stop_early}")
|
|
||||||
if args.stop_early:
|
|
||||||
callbacks.append(EarlyStopping(monitor='val_loss',
|
|
||||||
patience=5,
|
|
||||||
verbose=False))
|
|
||||||
custom_metrics = models.get_metric_functions()
|
|
||||||
|
|
||||||
# custom class or sample weights
|
|
||||||
if args.class_weights:
|
|
||||||
logger.info("class weights: compute custom weights")
|
|
||||||
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
|
||||||
logger.info(custom_class_weights)
|
|
||||||
else:
|
|
||||||
logger.info("class weights: set default")
|
|
||||||
custom_class_weights = None
|
|
||||||
|
|
||||||
if args.sample_weights:
|
|
||||||
logger.info("class weights: compute custom weights")
|
|
||||||
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
|
|
||||||
logger.info(custom_class_weights)
|
|
||||||
else:
|
|
||||||
logger.info("class weights: set default")
|
|
||||||
custom_sample_weights = None
|
|
||||||
|
|
||||||
if not param:
|
if not param:
|
||||||
param = PARAMS
|
param = PARAMS
|
||||||
logger.info(f"Generator model with params: {param}")
|
|
||||||
embedding, model, new_model = models.get_models_by_params(param)
|
|
||||||
|
|
||||||
model = create_model(model, args.model_output)
|
for i in range(20):
|
||||||
new_model = create_model(new_model, args.model_output)
|
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||||
|
train_log_path = os.path.join(args.model_path, "train_{i}.log.csv")
|
||||||
|
# define training call backs
|
||||||
|
logger.info("define callbacks")
|
||||||
|
callbacks = []
|
||||||
|
callbacks.append(ModelCheckpoint(filepath=model_path,
|
||||||
|
monitor='loss',
|
||||||
|
verbose=False,
|
||||||
|
save_best_only=True))
|
||||||
|
callbacks.append(CSVLogger(train_log_path))
|
||||||
|
logger.info(f"Use early stopping: {args.stop_early}")
|
||||||
|
if args.stop_early:
|
||||||
|
callbacks.append(EarlyStopping(monitor='val_loss',
|
||||||
|
patience=5,
|
||||||
|
verbose=False))
|
||||||
|
custom_metrics = models.get_metric_functions()
|
||||||
|
|
||||||
if args.model_type in ("inter", "staggered"):
|
# custom class or sample weights
|
||||||
model = new_model
|
if args.class_weights:
|
||||||
|
logger.info("class weights: compute custom weights")
|
||||||
|
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
||||||
|
logger.info(custom_class_weights)
|
||||||
|
else:
|
||||||
|
logger.info("class weights: set default")
|
||||||
|
custom_class_weights = None
|
||||||
|
|
||||||
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
if args.sample_weights:
|
||||||
if args.model_output == "both":
|
logger.info("class weights: compute custom weights")
|
||||||
labels = {"client": client_tr.value, "server": server_tr}
|
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
|
||||||
loss_weights = {"client": 1.0, "server": 1.0}
|
logger.info(custom_class_weights)
|
||||||
elif args.model_output == "client":
|
else:
|
||||||
labels = {"client": client_tr.value}
|
logger.info("class weights: set default")
|
||||||
loss_weights = {"client": 1.0}
|
custom_sample_weights = None
|
||||||
elif args.model_output == "server":
|
|
||||||
labels = {"server": server_tr}
|
|
||||||
loss_weights = {"server": 1.0}
|
|
||||||
else:
|
|
||||||
raise ValueError("unknown model output")
|
|
||||||
|
|
||||||
logger.info(f"select model: {args.model_type}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
if args.model_type == "staggered":
|
embedding, model, new_model = models.get_models_by_params(param)
|
||||||
logger.info("compile and pre-train server model")
|
|
||||||
|
model = create_model(model, args.model_output)
|
||||||
|
new_model = create_model(new_model, args.model_output)
|
||||||
|
|
||||||
|
if args.model_type in ("inter", "staggered"):
|
||||||
|
model = new_model
|
||||||
|
|
||||||
|
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||||
|
if args.model_output == "both":
|
||||||
|
labels = {"client": client_tr.value, "server": server_tr}
|
||||||
|
loss_weights = {"client": 1.0, "server": 1.0}
|
||||||
|
elif args.model_output == "client":
|
||||||
|
labels = {"client": client_tr.value}
|
||||||
|
loss_weights = {"client": 1.0}
|
||||||
|
elif args.model_output == "server":
|
||||||
|
labels = {"server": server_tr}
|
||||||
|
loss_weights = {"server": 1.0}
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown model output")
|
||||||
|
|
||||||
|
logger.info(f"select model: {args.model_type}")
|
||||||
|
if args.model_type == "staggered":
|
||||||
|
logger.info("compile and pre-train server model")
|
||||||
|
logger.info(model.get_config())
|
||||||
|
|
||||||
|
model.compile(optimizer='adam',
|
||||||
|
loss='binary_crossentropy',
|
||||||
|
loss_weights={"client": 0.0, "server": 1.0},
|
||||||
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
|
model.summary()
|
||||||
|
model.fit(features, labels,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
epochs=args.epochs,
|
||||||
|
class_weight=custom_class_weights,
|
||||||
|
sample_weight=custom_sample_weights)
|
||||||
|
|
||||||
|
logger.info("fix server model")
|
||||||
|
model.get_layer("domain_cnn").trainable = False
|
||||||
|
model.get_layer("domain_cnn").layer.trainable = False
|
||||||
|
model.get_layer("dense_server").trainable = False
|
||||||
|
model.get_layer("server").trainable = False
|
||||||
|
loss_weights = {"client": 1.0, "server": 0.0}
|
||||||
|
|
||||||
|
logger.info("compile and train model")
|
||||||
|
embedding.summary()
|
||||||
logger.info(model.get_config())
|
logger.info(model.get_config())
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss='binary_crossentropy',
|
loss='binary_crossentropy',
|
||||||
loss_weights={"client": 0.0, "server": 1.0},
|
loss_weights=loss_weights,
|
||||||
metrics=['accuracy'] + custom_metrics)
|
metrics=['accuracy'] + custom_metrics)
|
||||||
|
|
||||||
model.summary()
|
model.summary()
|
||||||
model.fit(features, labels,
|
model.fit(features, labels,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
class_weight=custom_class_weights,
|
class_weight=custom_class_weights,
|
||||||
sample_weight=custom_sample_weights)
|
sample_weight=custom_sample_weights)
|
||||||
|
|
||||||
logger.info("fix server model")
|
|
||||||
model.get_layer("domain_cnn").trainable = False
|
|
||||||
model.get_layer("domain_cnn").layer.trainable = False
|
|
||||||
model.get_layer("dense_server").trainable = False
|
|
||||||
model.get_layer("server").trainable = False
|
|
||||||
loss_weights = {"client": 1.0, "server": 0.0}
|
|
||||||
|
|
||||||
logger.info("compile and train model")
|
|
||||||
embedding.summary()
|
|
||||||
logger.info(model.get_config())
|
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss='binary_crossentropy',
|
|
||||||
loss_weights=loss_weights,
|
|
||||||
metrics=['accuracy'] + custom_metrics)
|
|
||||||
|
|
||||||
model.summary()
|
|
||||||
model.fit(features, labels,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
epochs=args.epochs,
|
|
||||||
callbacks=callbacks,
|
|
||||||
class_weight=custom_class_weights,
|
|
||||||
sample_weight=custom_sample_weights)
|
|
||||||
|
|
||||||
|
|
||||||
def main_retrain():
|
def main_retrain():
|
||||||
source = os.path.join(args.model_source, "clf.h5")
|
source = os.path.join(args.model_source, "clf.h5")
|
||||||
@ -470,15 +480,6 @@ def main_visualization():
|
|||||||
normalize=True, title="User Confusion Matrix")
|
normalize=True, title="User Confusion Matrix")
|
||||||
|
|
||||||
|
|
||||||
# plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)
|
|
||||||
|
|
||||||
|
|
||||||
# def plot_embedding(model_path, domain_embedding, data, domain_length):
|
|
||||||
# logger.info("visualize embedding")
|
|
||||||
# domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
|
||||||
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
|
||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
_, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
|
||||||
args.data,
|
args.data,
|
||||||
@ -706,6 +707,7 @@ def main_beta():
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
def plot_overall_result():
|
def plot_overall_result():
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||||
try:
|
try:
|
||||||
@ -816,7 +818,6 @@ def main_stats2():
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
main_train()
|
main_train()
|
||||||
|
Loading…
Reference in New Issue
Block a user