add retrain mode
This commit is contained in:
parent
b157ca6a19
commit
090c89a127
40
Makefile
40
Makefile
@ -1,38 +1,38 @@
|
||||
run:
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth small \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth small \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth medium \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth medium \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth small \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth small \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth small \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output client
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth medium \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth medium \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 128 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
|
||||
|
||||
test:
|
||||
|
11
arguments.py
11
arguments.py
@ -19,6 +19,12 @@ parser.add_argument("--test", action="store", dest="test_data",
|
||||
parser.add_argument("--model", action="store", dest="model_path",
|
||||
default="results/model_x")
|
||||
|
||||
parser.add_argument("--model_src", action="store", dest="model_source",
|
||||
default="results/model_x")
|
||||
|
||||
parser.add_argument("--model_dest", action="store", dest="model_destination",
|
||||
default="results/model_x")
|
||||
|
||||
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
|
||||
default=[])
|
||||
|
||||
@ -37,6 +43,9 @@ parser.add_argument("--batch", action="store", dest="batch_size",
|
||||
parser.add_argument("--epochs", action="store", dest="epochs",
|
||||
default=10, type=int)
|
||||
|
||||
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
||||
default=0, type=int)
|
||||
|
||||
# parser.add_argument("--samples", action="store", dest="samples",
|
||||
# default=100000, type=int)
|
||||
#
|
||||
@ -98,7 +107,6 @@ parser.add_argument("--gpu", action="store_true", dest="gpu")
|
||||
parser.add_argument("--new_model", action="store_true", dest="new_model")
|
||||
|
||||
|
||||
|
||||
def get_model_args(args):
|
||||
return [{
|
||||
"model_path": model_path,
|
||||
@ -111,6 +119,7 @@ def get_model_args(args):
|
||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
||||
} for model_path in args.model_paths]
|
||||
|
||||
|
||||
def parse():
|
||||
args = parser.parse_args()
|
||||
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]
|
||||
|
96
main.py
96
main.py
@ -5,8 +5,8 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
from keras.callbacks import CSVLogger, EarlyStopping, LambdaCallback, ModelCheckpoint
|
||||
from keras.models import Model, load_model
|
||||
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
|
||||
from keras.models import Model, load_model as load_keras_model
|
||||
|
||||
import arguments
|
||||
import dataset
|
||||
@ -86,6 +86,12 @@ def create_model(model, output_type):
|
||||
raise Exception("unknown model output")
|
||||
|
||||
|
||||
def load_model(path, custom_objects=None):
|
||||
clf = load_keras_model(path, custom_objects)
|
||||
embd = clf.get_layer("domain_cnn").layer
|
||||
return embd, clf
|
||||
|
||||
|
||||
def main_paul_best():
|
||||
pauls_best_params = models.pauls_networks.best_config
|
||||
main_train(pauls_best_params)
|
||||
@ -161,10 +167,6 @@ def main_train(param=None):
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
embedding, model, new_model = models.get_models_by_params(param)
|
||||
|
||||
callbacks.append(LambdaCallback(
|
||||
on_epoch_end=lambda epoch, logs: embedding.save(args.embedding_model))
|
||||
)
|
||||
|
||||
model = create_model(model, args.model_output)
|
||||
new_model = create_model(new_model, args.model_output)
|
||||
|
||||
@ -222,6 +224,67 @@ def main_train(param=None):
|
||||
class_weight=custom_class_weights)
|
||||
|
||||
|
||||
def main_retrain():
|
||||
source = os.path.join(args.model_source, "clf.h5")
|
||||
destination = os.path.join(args.model_destination, "clf.h5")
|
||||
|
||||
logger.info(f"Use command line arguments: {args}")
|
||||
exists_or_make_path(destination)
|
||||
|
||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
|
||||
args.train_data,
|
||||
args.domain_length,
|
||||
args.window)
|
||||
logger.info("define callbacks")
|
||||
callbacks = []
|
||||
callbacks.append(ModelCheckpoint(filepath=destination,
|
||||
monitor='loss',
|
||||
verbose=False,
|
||||
save_best_only=True))
|
||||
callbacks.append(CSVLogger(args.train_log))
|
||||
logger.info(f"Use early stopping: {args.stop_early}")
|
||||
if args.stop_early:
|
||||
callbacks.append(EarlyStopping(monitor='val_loss',
|
||||
patience=5,
|
||||
verbose=False))
|
||||
|
||||
server_tr = np.max(server_windows_tr, axis=1)
|
||||
|
||||
if args.class_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
||||
logger.info(custom_class_weights)
|
||||
else:
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
logger.info(f"Load pretrained model")
|
||||
embedding, model = load_model(source, custom_objects=models.get_metrics())
|
||||
|
||||
if args.model_type in ("inter", "staggered"):
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
|
||||
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||
if args.model_output == "both":
|
||||
labels = {"client": client_tr.value, "server": server_tr}
|
||||
elif args.model_output == "client":
|
||||
labels = {"client": client_tr.value}
|
||||
elif args.model_output == "server":
|
||||
labels = {"server": server_tr}
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
|
||||
logger.info("re-train model")
|
||||
embedding.summary()
|
||||
model.summary()
|
||||
model.fit(features, labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
class_weight=custom_class_weights,
|
||||
initial_epoch=args.initial_epoch)
|
||||
|
||||
|
||||
def main_test():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
|
||||
@ -233,7 +296,7 @@ def main_test():
|
||||
for model_args in get_model_args(args):
|
||||
results = {}
|
||||
logger.info(f"process model {model_args['model_path']}")
|
||||
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
||||
embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
||||
|
||||
pred = clf_model.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
@ -248,7 +311,6 @@ def main_test():
|
||||
else:
|
||||
results["server_pred"] = pred
|
||||
|
||||
embd_model = load_model(model_args["embedding_model"], custom_objects=models.get_metrics())
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
results["domain_embds"] = domain_embeddings
|
||||
|
||||
@ -278,7 +340,7 @@ def main_visualization():
|
||||
df_paul_user = df_paul.groupby(df_paul.names).max()
|
||||
|
||||
logger.info("plot model")
|
||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
embd, model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||
|
||||
# logger.info("plot training curve")
|
||||
@ -491,6 +553,16 @@ def main_beta():
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc_all.png")
|
||||
|
||||
joblib.dump(results, f"{path}/curves.joblib")
|
||||
|
||||
plot_overall_result()
|
||||
|
||||
|
||||
def plot_overall_result():
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||
try:
|
||||
results = joblib.load(f"{path}/curves.joblib")
|
||||
except Exception:
|
||||
results = {}
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
x = np.linspace(0, 1, 10000)
|
||||
@ -500,7 +572,7 @@ def main_beta():
|
||||
for model_key in results.keys():
|
||||
ys_mean, ys_std, score = results[model_key][vis]
|
||||
plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
|
||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
|
||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
|
||||
if vis.endswith("prc"):
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
@ -516,6 +588,8 @@ def main_beta():
|
||||
def main():
|
||||
if "train" == args.mode:
|
||||
main_train()
|
||||
if "retrain" == args.mode:
|
||||
main_retrain()
|
||||
if "hyperband" == args.mode:
|
||||
main_hyperband()
|
||||
if "test" == args.mode:
|
||||
@ -530,6 +604,8 @@ def main():
|
||||
main_paul_best()
|
||||
if "beta" == args.mode:
|
||||
main_beta()
|
||||
if "beta_all" == args.mode:
|
||||
plot_overall_result()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
21
visualize.py
21
visualize.py
@ -84,14 +84,17 @@ def calc_pr_mean(y, y_preds):
|
||||
return ys_mean, ys_std, scores_mean
|
||||
|
||||
|
||||
def plot_mean_curve(x, ys, std, score, label):
|
||||
plt.plot(x, ys, label=f"{label} - {score:5.4}")
|
||||
plt.fill_between(x, ys - std, ys + std, alpha=0.1)
|
||||
plt.ylim([0.0, 1.0])
|
||||
plt.xlim([0.0, 1.0])
|
||||
|
||||
|
||||
def plot_pr_mean(y, y_preds, label=""):
|
||||
x = np.linspace(0, 1, 10000)
|
||||
ys_mean, ys_std, score = calc_pr_mean(y, y_preds)
|
||||
|
||||
plt.plot(x, ys_mean, label=f"{label} - {score:5.4}")
|
||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
|
||||
plt.ylim([0.0, 1.0])
|
||||
plt.xlim([0.0, 1.0])
|
||||
plot_mean_curve(x, ys_mean, ys_std, score, label)
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
|
||||
@ -142,13 +145,9 @@ def calc_roc_mean(y, y_preds):
|
||||
|
||||
def plot_roc_mean(y, y_preds, label=""):
|
||||
x = np.linspace(0, 1, 10000)
|
||||
ys_mean, ys_std, auc_mean = calc_roc_mean(y, y_preds)
|
||||
ys_mean, ys_std, score = calc_roc_mean(y, y_preds)
|
||||
plt.xscale('log')
|
||||
plt.ylim([0.0, 1.0])
|
||||
plt.xlim([0.0, 1.0])
|
||||
|
||||
plt.plot(x, ys_mean, label=f"{label} - {auc_mean:5.4}")
|
||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, color='grey', alpha=0.1)
|
||||
plot_mean_curve(x, ys_mean, ys_std, score, label)
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user