add retrain mode
This commit is contained in:
parent
b157ca6a19
commit
090c89a127
40
Makefile
40
Makefile
@ -1,38 +1,38 @@
|
|||||||
run:
|
run:
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth medium \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth medium \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth small \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
|
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output client
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth medium \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
|
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
|
||||||
|
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth medium \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \
|
||||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
|
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
11
arguments.py
11
arguments.py
@ -19,6 +19,12 @@ parser.add_argument("--test", action="store", dest="test_data",
|
|||||||
parser.add_argument("--model", action="store", dest="model_path",
|
parser.add_argument("--model", action="store", dest="model_path",
|
||||||
default="results/model_x")
|
default="results/model_x")
|
||||||
|
|
||||||
|
parser.add_argument("--model_src", action="store", dest="model_source",
|
||||||
|
default="results/model_x")
|
||||||
|
|
||||||
|
parser.add_argument("--model_dest", action="store", dest="model_destination",
|
||||||
|
default="results/model_x")
|
||||||
|
|
||||||
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
|
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
|
||||||
default=[])
|
default=[])
|
||||||
|
|
||||||
@ -37,6 +43,9 @@ parser.add_argument("--batch", action="store", dest="batch_size",
|
|||||||
parser.add_argument("--epochs", action="store", dest="epochs",
|
parser.add_argument("--epochs", action="store", dest="epochs",
|
||||||
default=10, type=int)
|
default=10, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
||||||
|
default=0, type=int)
|
||||||
|
|
||||||
# parser.add_argument("--samples", action="store", dest="samples",
|
# parser.add_argument("--samples", action="store", dest="samples",
|
||||||
# default=100000, type=int)
|
# default=100000, type=int)
|
||||||
#
|
#
|
||||||
@ -98,7 +107,6 @@ parser.add_argument("--gpu", action="store_true", dest="gpu")
|
|||||||
parser.add_argument("--new_model", action="store_true", dest="new_model")
|
parser.add_argument("--new_model", action="store_true", dest="new_model")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_args(args):
|
def get_model_args(args):
|
||||||
return [{
|
return [{
|
||||||
"model_path": model_path,
|
"model_path": model_path,
|
||||||
@ -111,6 +119,7 @@ def get_model_args(args):
|
|||||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||||
} for model_path in args.model_paths]
|
} for model_path in args.model_paths]
|
||||||
|
|
||||||
|
|
||||||
def parse():
|
def parse():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]
|
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]
|
||||||
|
96
main.py
96
main.py
@ -5,8 +5,8 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from keras.callbacks import CSVLogger, EarlyStopping, LambdaCallback, ModelCheckpoint
|
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
|
||||||
from keras.models import Model, load_model
|
from keras.models import Model, load_model as load_keras_model
|
||||||
|
|
||||||
import arguments
|
import arguments
|
||||||
import dataset
|
import dataset
|
||||||
@ -86,6 +86,12 @@ def create_model(model, output_type):
|
|||||||
raise Exception("unknown model output")
|
raise Exception("unknown model output")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(path, custom_objects=None):
|
||||||
|
clf = load_keras_model(path, custom_objects)
|
||||||
|
embd = clf.get_layer("domain_cnn").layer
|
||||||
|
return embd, clf
|
||||||
|
|
||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
pauls_best_params = models.pauls_networks.best_config
|
pauls_best_params = models.pauls_networks.best_config
|
||||||
main_train(pauls_best_params)
|
main_train(pauls_best_params)
|
||||||
@ -161,10 +167,6 @@ def main_train(param=None):
|
|||||||
logger.info(f"Generator model with params: {param}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
embedding, model, new_model = models.get_models_by_params(param)
|
embedding, model, new_model = models.get_models_by_params(param)
|
||||||
|
|
||||||
callbacks.append(LambdaCallback(
|
|
||||||
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
|
|
||||||
)
|
|
||||||
|
|
||||||
model = create_model(model, args.model_output)
|
model = create_model(model, args.model_output)
|
||||||
new_model = create_model(new_model, args.model_output)
|
new_model = create_model(new_model, args.model_output)
|
||||||
|
|
||||||
@ -222,6 +224,67 @@ def main_train(param=None):
|
|||||||
class_weight=custom_class_weights)
|
class_weight=custom_class_weights)
|
||||||
|
|
||||||
|
|
||||||
|
def main_retrain():
|
||||||
|
source = os.path.join(args.model_source, "clf.h5")
|
||||||
|
destination = os.path.join(args.model_destination, "clf.h5")
|
||||||
|
|
||||||
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
exists_or_make_path(destination)
|
||||||
|
|
||||||
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
|
||||||
|
args.train_data,
|
||||||
|
args.domain_length,
|
||||||
|
args.window)
|
||||||
|
logger.info("define callbacks")
|
||||||
|
callbacks = []
|
||||||
|
callbacks.append(ModelCheckpoint(filepath=destination,
|
||||||
|
monitor='loss',
|
||||||
|
verbose=False,
|
||||||
|
save_best_only=True))
|
||||||
|
callbacks.append(CSVLogger(args.train_log))
|
||||||
|
logger.info(f"Use early stopping: {args.stop_early}")
|
||||||
|
if args.stop_early:
|
||||||
|
callbacks.append(EarlyStopping(monitor='val_loss',
|
||||||
|
patience=5,
|
||||||
|
verbose=False))
|
||||||
|
|
||||||
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
|
|
||||||
|
if args.class_weights:
|
||||||
|
logger.info("class weights: compute custom weights")
|
||||||
|
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
||||||
|
logger.info(custom_class_weights)
|
||||||
|
else:
|
||||||
|
logger.info("class weights: set default")
|
||||||
|
custom_class_weights = None
|
||||||
|
|
||||||
|
logger.info(f"Load pretrained model")
|
||||||
|
embedding, model = load_model(source, custom_objects=models.get_metrics())
|
||||||
|
|
||||||
|
if args.model_type in ("inter", "staggered"):
|
||||||
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
|
|
||||||
|
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||||
|
if args.model_output == "both":
|
||||||
|
labels = {"client": client_tr.value, "server": server_tr}
|
||||||
|
elif args.model_output == "client":
|
||||||
|
labels = {"client": client_tr.value}
|
||||||
|
elif args.model_output == "server":
|
||||||
|
labels = {"server": server_tr}
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown model output")
|
||||||
|
|
||||||
|
logger.info("re-train model")
|
||||||
|
embedding.summary()
|
||||||
|
model.summary()
|
||||||
|
model.fit(features, labels,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
epochs=args.epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
class_weight=custom_class_weights,
|
||||||
|
initial_epoch=args.initial_epoch)
|
||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
logger.info("start test: load data")
|
logger.info("start test: load data")
|
||||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||||
@ -233,7 +296,7 @@ def main_test():
|
|||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = {}
|
results = {}
|
||||||
logger.info(f"process model {model_args['model_path']}")
|
logger.info(f"process model {model_args['model_path']}")
|
||||||
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
||||||
|
|
||||||
pred = clf_model.predict([domain_val, flow_val],
|
pred = clf_model.predict([domain_val, flow_val],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -248,7 +311,6 @@ def main_test():
|
|||||||
else:
|
else:
|
||||||
results["server_pred"] = pred
|
results["server_pred"] = pred
|
||||||
|
|
||||||
embd_model = load_model(model_args["embedding_model"], custom_objects=models.get_metrics())
|
|
||||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
results["domain_embds"] = domain_embeddings
|
results["domain_embds"] = domain_embeddings
|
||||||
|
|
||||||
@ -278,7 +340,7 @@ def main_visualization():
|
|||||||
df_paul_user = df_paul.groupby(df_paul.names).max()
|
df_paul_user = df_paul.groupby(df_paul.names).max()
|
||||||
|
|
||||||
logger.info("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
embd, model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||||
|
|
||||||
# logger.info("plot training curve")
|
# logger.info("plot training curve")
|
||||||
@ -491,6 +553,16 @@ def main_beta():
|
|||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
|
||||||
|
|
||||||
joblib.dump(results, f"{path}/curves.joblib")
|
joblib.dump(results, f"{path}/curves.joblib")
|
||||||
|
|
||||||
|
plot_overall_result()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_overall_result():
|
||||||
|
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||||
|
try:
|
||||||
|
results = joblib.load(f"{path}/curves.joblib")
|
||||||
|
except Exception:
|
||||||
|
results = {}
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
x = np.linspace(0, 1, 10000)
|
x = np.linspace(0, 1, 10000)
|
||||||
@ -500,7 +572,7 @@ def main_beta():
|
|||||||
for model_key in results.keys():
|
for model_key in results.keys():
|
||||||
ys_mean, ys_std, score = results[model_key][vis]
|
ys_mean, ys_std, score = results[model_key][vis]
|
||||||
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
|
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
|
||||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
|
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
|
||||||
if vis.endswith("prc"):
|
if vis.endswith("prc"):
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
plt.ylabel('Precision')
|
plt.ylabel('Precision')
|
||||||
@ -516,6 +588,8 @@ def main_beta():
|
|||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
main_train()
|
main_train()
|
||||||
|
if "retrain" == args.mode:
|
||||||
|
main_retrain()
|
||||||
if "hyperband" == args.mode:
|
if "hyperband" == args.mode:
|
||||||
main_hyperband()
|
main_hyperband()
|
||||||
if "test" == args.mode:
|
if "test" == args.mode:
|
||||||
@ -530,6 +604,8 @@ def main():
|
|||||||
main_paul_best()
|
main_paul_best()
|
||||||
if "beta" == args.mode:
|
if "beta" == args.mode:
|
||||||
main_beta()
|
main_beta()
|
||||||
|
if "beta_all" == args.mode:
|
||||||
|
plot_overall_result()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
21
visualize.py
21
visualize.py
@ -84,14 +84,17 @@ def calc_pr_mean(y, y_preds):
|
|||||||
return ys_mean, ys_std, scores_mean
|
return ys_mean, ys_std, scores_mean
|
||||||
|
|
||||||
|
|
||||||
|
def plot_mean_curve(x, ys, std, score, label):
|
||||||
|
plt.plot(x, ys, label=f"{label} - {score:5.4}")
|
||||||
|
plt.fill_between(x, ys - std, ys + std, alpha=0.1)
|
||||||
|
plt.ylim([0.0, 1.0])
|
||||||
|
plt.xlim([0.0, 1.0])
|
||||||
|
|
||||||
|
|
||||||
def plot_pr_mean(y, y_preds, label=""):
|
def plot_pr_mean(y, y_preds, label=""):
|
||||||
x = np.linspace(0, 1, 10000)
|
x = np.linspace(0, 1, 10000)
|
||||||
ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
|
ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
|
||||||
|
plot_mean_curve(x, ys_mean, ys_std, score, label)
|
||||||
plt.plot(x, ys_mean, label=f"{label} - {score:5.4}")
|
|
||||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
|
|
||||||
plt.ylim([0.0, 1.0])
|
|
||||||
plt.xlim([0.0, 1.0])
|
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
plt.ylabel('Precision')
|
plt.ylabel('Precision')
|
||||||
|
|
||||||
@ -142,13 +145,9 @@ def calc_roc_mean(y, y_preds):
|
|||||||
|
|
||||||
def plot_roc_mean(y, y_preds, label=""):
|
def plot_roc_mean(y, y_preds, label=""):
|
||||||
x = np.linspace(0, 1, 10000)
|
x = np.linspace(0, 1, 10000)
|
||||||
ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds)
|
ys_mean, ys_std, score = calc_roc_mean(y, y_preds)
|
||||||
plt.xscale('log')
|
plt.xscale('log')
|
||||||
plt.ylim([0.0, 1.0])
|
plot_mean_curve(x, ys_mean, ys_std, score, label)
|
||||||
plt.xlim([0.0, 1.0])
|
|
||||||
|
|
||||||
plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}")
|
|
||||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
|
|
||||||
plt.xlabel('False Positive Rate')
|
plt.xlabel('False Positive Rate')
|
||||||
plt.ylabel('True Positive Rate')
|
plt.ylabel('True Positive Rate')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user