refactor visualization, change arguments for model type and its depth
This commit is contained in:
parent
933eaae04a
commit
dc9180da10
31
Makefile
31
Makefile
@ -1,16 +1,33 @@
|
||||
run:
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
|
||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 10 --depth small \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||
|
||||
run_new:
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
|
||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 --depth small \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 10 --depth medium \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 10 --depth medium \
|
||||
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||
|
||||
test:
|
||||
python3 main.py --mode test --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
||||
|
||||
fancy:
|
||||
python3 main.py --mode fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
python3 main.py --mode fancy --batch 128 --model results/test1 --test data/rk_mini.csv.gz
|
||||
|
||||
python3 main.py --mode fancy --batch 128 --model results/test2 --test data/rk_mini.csv.gz
|
||||
|
||||
python3 main.py --mode fancy --batch 128 --model results/test3 --test data/rk_mini.csv.gz
|
||||
|
||||
python3 main.py --mode fancy --batch 128 --model results/test4 --test data/rk_mini.csv.gz
|
||||
|
||||
all-fancy:
|
||||
python3 main.py --mode all_fancy --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
||||
|
||||
hyper:
|
||||
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
|
||||
|
||||
clean:
|
||||
rm -r results/test*
|
19
arguments.py
19
arguments.py
@ -19,8 +19,14 @@ parser.add_argument("--test", action="store", dest="test_data",
|
||||
parser.add_argument("--model", action="store", dest="model_path",
|
||||
default="results/model_x")
|
||||
|
||||
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
|
||||
default=[])
|
||||
|
||||
parser.add_argument("--type", action="store", dest="model_type",
|
||||
default="paul")
|
||||
default="final") # inter, final, staggered
|
||||
|
||||
parser.add_argument("--depth", action="store", dest="model_depth",
|
||||
default="small") # small, medium
|
||||
|
||||
parser.add_argument("--model_output", action="store", dest="model_output",
|
||||
default="both")
|
||||
@ -74,6 +80,17 @@ parser.add_argument("--new_model", action="store_true", dest="new_model")
|
||||
|
||||
|
||||
|
||||
def get_model_args(args):
|
||||
return [{
|
||||
"model_path": model_path,
|
||||
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||
"train_h5data": args.train_data + ".h5",
|
||||
"test_h5data": args.test_data + ".h5",
|
||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
||||
} for model_path in args.model_paths]
|
||||
|
||||
def parse():
|
||||
args = parser.parse_args()
|
||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||
|
@ -233,11 +233,11 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||
|
||||
def load_or_generate_domains(train_data, domain_length):
|
||||
fn = f"{train_data}_domains.gz"
|
||||
char_dict = get_character_dict()
|
||||
|
||||
try:
|
||||
user_flow_df = pd.read_csv(fn)
|
||||
except FileNotFoundError:
|
||||
char_dict = get_character_dict()
|
||||
user_flow_df = get_user_flow_data(train_data)
|
||||
# user_flow_df.reset_index(inplace=True)
|
||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
||||
|
87
main.py
87
main.py
@ -16,6 +16,7 @@ import models
|
||||
import visualize
|
||||
from dataset import load_or_generate_h5data
|
||||
from utils import exists_or_make_path, get_custom_class_weights
|
||||
from arguments import get_model_args
|
||||
|
||||
logger = logging.getLogger('logger')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -56,6 +57,7 @@ if args.gpu:
|
||||
# default parameter
|
||||
PARAMS = {
|
||||
"type": args.model_type,
|
||||
"depth": args.model_depth,
|
||||
"batch_size": 64,
|
||||
"window_size": args.window,
|
||||
"domain_length": args.domain_length,
|
||||
@ -149,8 +151,8 @@ def main_train(param=None):
|
||||
logger.info("class weights: set default")
|
||||
custom_class_weights = None
|
||||
|
||||
logger.info(f"select model: {'new' if args.new_model else 'old'}")
|
||||
if args.new_model:
|
||||
logger.info(f"select model: {args.model_type}")
|
||||
if args.model_type == "inter":
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
model = new_model
|
||||
logger.info("compile and train model")
|
||||
@ -181,35 +183,44 @@ def main_train(param=None):
|
||||
|
||||
|
||||
def main_test():
|
||||
logger.info("start test: load data")
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||
args.domain_length, args.window)
|
||||
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
pred = clf.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
verbose=1)
|
||||
if args.model_output == "both":
|
||||
c_pred, s_pred = pred
|
||||
elif args.model_output == "client":
|
||||
c_pred = pred
|
||||
s_pred = np.zeros(0)
|
||||
else:
|
||||
c_pred = np.zeros(0)
|
||||
s_pred = pred
|
||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
|
||||
|
||||
model = load_model(args.embedding_model)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
np.save(args.model_path + "/domain_embds.npy", domain_embedding)
|
||||
|
||||
for model_args in get_model_args(args):
|
||||
logger.info(f"process model {model_args['model_path']}")
|
||||
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
||||
|
||||
pred = clf_model.predict([domain_val, flow_val],
|
||||
batch_size=args.batch_size,
|
||||
verbose=1)
|
||||
|
||||
if args.model_output == "both":
|
||||
c_pred, s_pred = pred
|
||||
elif args.model_output == "client":
|
||||
c_pred = pred
|
||||
s_pred = np.zeros(0)
|
||||
else:
|
||||
c_pred = np.zeros(0)
|
||||
s_pred = pred
|
||||
dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
|
||||
|
||||
embd_model = load_model(model_args["embedding_model"])
|
||||
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
|
||||
|
||||
|
||||
def main_visualization():
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||
args.domain_length, args.window)
|
||||
client_val, server_val = client_val.value, server_val.value
|
||||
logger.info("plot model")
|
||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
||||
# client_val, server_val = client_val.value, server_val.value
|
||||
client_val = client_val.value
|
||||
|
||||
# logger.info("plot model")
|
||||
# model = load_model(model_args.clf_model, custom_objects=models.get_metrics())
|
||||
# visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
||||
|
||||
try:
|
||||
logger.info("plot training curve")
|
||||
logs = pd.read_csv(args.train_log)
|
||||
@ -224,12 +235,16 @@ def main_visualization():
|
||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = client_pred.value, server_pred.value
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
||||
visualize.plot_clf()
|
||||
visualize.plot_precision_recall(client_val, client_pred)
|
||||
visualize.plot_save("{}/client_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||
logger.info("plot roc curve")
|
||||
visualize.plot_roc_curve(client_val, client_pred.flatten(), "{}/client_roc.png".format(args.model_path))
|
||||
visualize.plot_clf()
|
||||
visualize.plot_roc_curve(client_val, client_pred)
|
||||
visualize.plot_save("{}/client_roc.png".format(args.model_path))
|
||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
||||
"{}/client_cov.png".format(args.model_path),
|
||||
@ -243,6 +258,26 @@ def main_visualization():
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
|
||||
|
||||
def main_visualize_all():
|
||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||
args.domain_length, args.window)
|
||||
logger.info("plot pr curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_client_prc.png")
|
||||
|
||||
logger.info("plot roc curves")
|
||||
visualize.plot_clf()
|
||||
for model_args in get_model_args(args):
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_client_roc.png")
|
||||
|
||||
|
||||
def main_data():
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
@ -265,6 +300,8 @@ def main():
|
||||
main_test()
|
||||
if "fancy" == args.mode:
|
||||
main_visualization()
|
||||
if "all_fancy" == args.mode:
|
||||
main_visualize_all()
|
||||
if "paul" == args.mode:
|
||||
main_paul_best()
|
||||
if "data" == args.mode:
|
||||
|
@ -8,6 +8,7 @@ def get_models_by_params(params: dict):
|
||||
# decomposing param section
|
||||
# mainly embedding model
|
||||
network_type = params.get("type")
|
||||
network_depth = params.get("depth")
|
||||
embedding_size = params.get("embedding_size")
|
||||
input_length = params.get("input_length")
|
||||
filter_embedding = params.get("filter_embedding")
|
||||
@ -24,7 +25,12 @@ def get_models_by_params(params: dict):
|
||||
dense_dim = params.get("dense_main")
|
||||
model_output = params.get("model_output", "both")
|
||||
# create models
|
||||
networks = renes_networks if network_type == "rene" else pauls_networks
|
||||
if network_depth == "small":
|
||||
networks = pauls_networks
|
||||
elif network_depth == "medium":
|
||||
networks = renes_networks
|
||||
else:
|
||||
raise Exception("network not found")
|
||||
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
|
||||
hidden_embedding, dropout)
|
||||
|
||||
|
114
run.sh
114
run.sh
@ -1,98 +1,26 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/small_both \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_chaar_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output both
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/small_client \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output client
|
||||
for output in client both
|
||||
do
|
||||
for depth in small medium
|
||||
do
|
||||
for mtype in inter final staggered
|
||||
do
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/small_new_both \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output both \
|
||||
--new_model
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/${output}_${depth}_${mtype} \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_chaar_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output ${output} \
|
||||
--type ${mtype} \
|
||||
--depth ${depth}
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/small_new_client \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output client \
|
||||
--new_model
|
||||
##
|
||||
##
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/medium_both \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output both \
|
||||
--type rene
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/medium_client \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output client \
|
||||
--type rene
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/medium_new_both \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output both \
|
||||
--new_model \
|
||||
--type rene
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/medium_new_client \
|
||||
--epochs 25 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output client \
|
||||
--new_model \
|
||||
--type rene
|
||||
done
|
||||
done
|
||||
done
|
||||
|
45
visualize.py
45
visualize.py
@ -2,7 +2,6 @@ import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from keras.utils.vis_utils import plot_model
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.metrics import (
|
||||
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
||||
@ -32,38 +31,43 @@ def scores(y_true, y_pred):
|
||||
print(" f0.5 score:", f05_score)
|
||||
|
||||
|
||||
def plot_precision_recall(mask, prediction, path):
|
||||
y = mask.flatten()
|
||||
y_pred = prediction.flatten()
|
||||
def plot_clf():
|
||||
plt.clf()
|
||||
|
||||
|
||||
def plot_save(path, dpi=600):
|
||||
plt.savefig(path, dpi=dpi)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_legend():
|
||||
plt.legend()
|
||||
|
||||
|
||||
def plot_precision_recall(y, y_pred, label=""):
|
||||
y = y.flatten()
|
||||
y_pred = y_pred.flatten()
|
||||
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||
decreasing_max_precision = np.maximum.accumulate(precision)[::-1]
|
||||
|
||||
plt.clf()
|
||||
# fig, ax = plt.subplots(1, 1)
|
||||
# ax.hold(True)
|
||||
plt.plot(recall, precision, '--b')
|
||||
plt.plot(recall, precision, '--', label=label)
|
||||
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
|
||||
plt.savefig(path, dpi=600)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_precision_recall_curves(mask, prediction, path):
|
||||
y = mask.flatten()
|
||||
y_pred = prediction.flatten()
|
||||
def plot_precision_recall_curves(y, y_pred):
|
||||
y = y.flatten()
|
||||
y_pred = y_pred.flatten()
|
||||
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||
|
||||
plt.clf()
|
||||
plt.plot(recall, label="Recall")
|
||||
plt.plot(precision, label="Precision")
|
||||
plt.xlabel('Threshold')
|
||||
plt.ylabel('Score')
|
||||
|
||||
plt.savefig(path, dpi=600)
|
||||
plt.close()
|
||||
|
||||
|
||||
def score_model(y, prediction):
|
||||
y = y.flatten()
|
||||
@ -78,16 +82,12 @@ def score_model(y, prediction):
|
||||
print("F0.5 Score", fbeta_score(y, y_pred.round(), 0.5))
|
||||
|
||||
|
||||
def plot_roc_curve(mask, prediction, path):
|
||||
def plot_roc_curve(mask, prediction, label=""):
|
||||
y = mask.flatten()
|
||||
y_pred = prediction.flatten()
|
||||
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
plt.clf()
|
||||
plt.plot(fpr, tpr)
|
||||
plt.savefig(path, dpi=600)
|
||||
plt.close()
|
||||
|
||||
plt.plot(fpr, tpr, label=label)
|
||||
print("roc_auc", roc_auc)
|
||||
|
||||
|
||||
@ -161,4 +161,5 @@ def plot_embedding(domain_embedding, labels, path, dpi=600):
|
||||
|
||||
|
||||
def plot_model_as(model, path):
|
||||
from keras.utils.vis_utils import plot_model
|
||||
plot_model(model, to_file=path, show_shapes=True, show_layer_names=True)
|
||||
|
Loading…
Reference in New Issue
Block a user