ma_cisco_malware/main.py

855 lines
36 KiB
Python

import logging
import operator
import os
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
from sklearn.metrics import confusion_matrix
import arguments
import dataset
import hyperband
import models
# create logger
import visualize
from arguments import get_model_args
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
# Module-wide logger. Propagation is switched off, so records are emitted only
# by the handler attached below and are not passed on to ancestor loggers.
logger = logging.getLogger('cisco_logger')
logger.setLevel(logging.DEBUG)
logger.propagate = False
# Console handler that lets every record from DEBUG upwards through.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# Record layout: timestamp - logger name - level - message.
formatter1 = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Attach the formatter to the console handler ...
ch.setFormatter(formatter1)
# ... and the handler to the logger.
logger.addHandler(ch)
# Optional additional logging into a file ("info.log"); currently disabled.
# ch = logging.FileHandler("info.log")
# ch.setLevel(logging.DEBUG)
#
# # create formatter
# formatter2 = logging.Formatter('!! %(asctime)s - %(name)s - %(levelname)s - %(message)s')
#
# # add formatter to ch
# ch.setFormatter(formatter2)
#
# # add ch to logger
# logger.addHandler(ch)
# The command line is parsed once at import time; the functions in this module
# read the resulting ``args`` namespace as a global.
args = arguments.parse()
if args.gpu:
    # Log device placement, let the allocation grow on demand and cap this
    # process at half of the GPU memory.
    # NOTE(review): the session is only created here; nothing visible in this
    # file registers it with Keras - confirm that this is intended.
    config = tf.ConfigProto(log_device_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    session = tf.Session(config=config)
# Default model parameters, filled from the command line. main_train falls
# back to this dict when it gets neither an explicit parameter dict nor
# hyperband results.
# The literal used to list "flow_features" twice (both times with value 3);
# the redundant second entry has been removed, the resulting dict is unchanged.
PARAMS = {
    "type": args.model_type,
    "embedding_type": args.embedding_type,
    # "depth": args.model_depth,
    "batch_size": args.batch_size,
    "window_size": args.window,
    "domain_length": args.domain_length,
    "flow_features": 3,
    #
    "dropout": 0.5,  # currently fixed
    "embedding": args.embedding,
    "filter_embedding": args.filter_embedding,
    "dense_embedding": args.dense_embedding,
    "kernel_embedding": args.kernel_embedding,
    "filter_main": args.filter_main,
    "dense_main": args.dense_main,
    "kernel_main": args.kernel_main,
    "model_output": args.model_output,
}
# TODO: remove inner global params
def get_param_dist(dist_size="small"):
    """Return the hyper-parameter search space as a dict of candidate lists.

    The static entries hold exactly one candidate taken from the global
    ``args``; the model entries span a range of candidates.

    :param dist_size: ``"small"`` selects the reduced search space, every
        other value selects the large one.
    :return: dict mapping parameter name to the list of values to sample from.
    """
    small = dist_size == "small"

    def powers_of_two(first_exp, stop_exp):
        # 2 ** first_exp, ..., 2 ** (stop_exp - 1)
        return [2 ** exp for exp in range(first_exp, stop_exp)]

    kernel_sizes = [1, 3, 5] if small else [1, 3, 5, 7, 9]
    # exclusive upper exponent shared by the filter and embedding-dense ranges
    width_stop = 8 if small else 10

    param_dist = {
        # static params
        "type": [args.model_type],
        "embedding_type": [args.embedding_type],
        # "depth": [args.model_depth],
        "model_output": [args.model_output],
        "batch_size": [args.batch_size],
        "window_size": [args.window],
        "flow_features": [3],
        "domain_length": [args.domain_length],
        # model params
        "embedding": powers_of_two(3, 6 if small else 7),
        "filter_embedding": powers_of_two(1, width_stop),
        "kernel_embedding": list(kernel_sizes),
        "dense_embedding": powers_of_two(4, width_stop),
        "dropout": [0.5],
        "filter_main": powers_of_two(1, width_stop),
        "kernel_main": list(kernel_sizes),
        "dense_main": powers_of_two(1, 8 if small else 12),
    }
    return param_dist
def shuffle_training_data(domain, flow, client, server):
    """Shuffle the four data sets with one shared random permutation.

    The permutation is drawn from numpy's global random state over
    ``len(domain)`` positions and applied to every input by indexing, so rows
    that belonged together before still do afterwards.

    :return: tuple ``(domain, flow, client, server)`` in shuffled order.
    """
    order = np.random.permutation(len(domain))
    return tuple(data[order] for data in (domain, flow, client, server))
def main_paul_best():
    """Train with the fixed parameter set stored for the "paul" model type.

    The parameter dict is passed to ``main_train`` unchanged. The literal used
    to contain the key ``'flow_features'`` twice (both with value 3); the
    redundant entry has been dropped, the resulting dict is identical.
    """
    # NOTE(review): some key names differ from the ones in PARAMS
    # ('drop_out' / 'dropout', 'kernels_main' / 'kernel_main', ...) - verify
    # against what the model builder reads for type "paul".
    pauls_best_params = {
        "type": "paul",
        "batch_size": 64,
        "window_size": 10,
        "domain_length": 40,
        "flow_features": 3,
        #
        'dropout': 0.5,
        'domain_features': 32,
        'drop_out': 0.5,
        'embedding_size': 64,
        'filter_main': 512,
        'dense_main': 32,
        'filter_embedding': 32,
        'hidden_embedding': 32,
        'kernel_embedding': 8,
        'kernels_main': 8,
        'input_length': 40,
    }
    main_train(pauls_best_params)
def main_hyperband(data, domain_length, window_size, model_type, result_file, max_iter, dist_size="small"):
    """Load the shuffled training data and run a hyperband search on it.

    :param data: data source handed to ``load_data``.
    :param result_file: file the hyperband run stores its results in.
    :param max_iter: ``max_iter`` setting of the hyperband run.
    :param dist_size: size of the search space, see ``get_param_dist``.
    :return: whatever ``run_hyperband`` returns.
    """
    logger.info("create training dataset")
    training_data = load_data(data, domain_length, window_size, model_type, shuffled=True)
    domain_tr, flow_tr, client_tr, server_tr = training_data
    return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
    """Run a hyperband parameter search and return its results.

    ``domain`` and ``flow`` are handed over as the model inputs, ``client``
    and ``server`` as the targets.

    :param dist_size: size of the search space, see ``get_param_dist``.
    :param max_iter: ``max_iter`` setting of the hyperband run.
    :param savefile: file the hyperband run stores its results in.
    :return: the value returned by ``Hyperband.run()``.
    """
    search = hyperband.Hyperband(
        get_param_dist(dist_size),
        [domain, flow],
        [client, server],
        max_iter=max_iter,
        savefile=savefile,
    )
    return search.run()
def train(parameters, features, labels):
    """Placeholder without implementation; ignores its arguments and returns None."""
def load_data(data, domain_length, window_size, model_type, shuffled=False):
    """Load (or generate) the h5 training data and derive the server labels.

    For the model types ``"inter"`` and ``"staggered"`` the per-window server
    labels are kept and get an additional axis at position 2; for every other
    type they are reduced with a maximum over axis 1.

    :param shuffled: if true, all four data sets are shuffled with one shared
        permutation (see ``shuffle_training_data``).
    :return: tuple ``(domain, flow, client, server)``.
    """
    # data preparation
    loaded = dataset.load_or_generate_h5data(data, domain_length, window_size)
    domain_tr, flow_tr, _, client_tr, server_windows_tr = loaded

    keep_windows = model_type in ("inter", "staggered")
    server_tr = (np.expand_dims(server_windows_tr, 2) if keep_windows
                 else np.max(server_windows_tr, axis=1))

    if not shuffled:
        return domain_tr, flow_tr, client_tr, server_tr
    return shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr)
def get_weighting(class_weights, sample_weights, client, server):
    """Compute the optional class and sample weights for training.

    :param class_weights: truthy to compute custom class weights from the labels.
    :param sample_weights: truthy to compute custom sample weights from the labels.
    :param client: client labels, handed to the weight helpers.
    :param server: server labels, handed to the weight helpers.
    :return: tuple ``(custom_class_weights, custom_sample_weights)``; an entry
        is ``None`` when the respective weighting is switched off.
    """
    if class_weights:
        logger.info("class weights: compute custom weights")
        custom_class_weights = get_custom_class_weights(client, server)
        logger.info(custom_class_weights)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None
    # The messages of this branch used to say "class weights" as well, which
    # made the two settings indistinguishable in the log.
    if sample_weights:
        logger.info("sample weights: compute custom weights")
        custom_sample_weights = get_custom_sample_weights(client, server)
        logger.info(custom_sample_weights)
    else:
        logger.info("sample weights: set default")
        custom_sample_weights = None
    return custom_class_weights, custom_sample_weights
def main_train(param=None):
    """Train ``args.runs`` models and store them below ``args.model_path``.

    The model parameters are chosen in this order: the best (lowest loss)
    entry of the hyperband results if ``args.hyperband_results`` is set, else
    the given ``param``, else the module-level ``PARAMS``. If the hyperband
    result file cannot be loaded, a new search is started first.

    Run ``i`` writes its checkpoint to ``clf_{i}.h5`` and its training log to
    ``train_{i}.log.csv``. For the model type ``"staggered"`` the model is
    first trained on the server output only, the server part is frozen, and
    the model is then trained on the client output.

    :param param: optional parameter dict for the model builder.
    :raises ValueError: if ``args.model_output`` is not one of ``"both"``,
        ``"client"``, ``"server"``.
    """
    logger.info(f"Create model path {args.model_path}")
    exists_or_make_path(args.model_path)
    logger.info(f"Use command line arguments: {args}")
    # data preparation
    domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length,
                                                         args.window, args.model_type)
    # call hyperband if used
    if args.hyperband_results:
        try:
            hyper_results = joblib.load(args.hyperband_results)
        except Exception:
            # Deliberate best effort: whenever the stored results cannot be
            # read, a new search is run and written to that file.
            logger.info("start hyperband parameter search")
            hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter,
                                          args.hyperband_results)
        param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
        param["type"] = args.model_type
        logger.info(f"select params from result: {param}")
    if not param:
        param = PARAMS
    # custom class or sample weights
    client_labels = client_tr.value
    custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights,
                                                                client_labels, server_tr)
    # Inputs and targets are the same for every run, so they are built once.
    features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
    if args.model_output == "both":
        labels = {"client": client_labels, "server": server_tr}
    elif args.model_output == "client":
        labels = {"client": client_labels}
    elif args.model_output == "server":
        labels = {"server": server_tr}
    else:
        raise ValueError("unknown model output")

    for i in range(args.runs):
        model_path = os.path.join(args.model_path, f"clf_{i}.h5")
        train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv")
        # define training call backs
        logger.info("define callbacks")
        callbacks = [
            ModelCheckpoint(filepath=model_path,
                            monitor='loss',
                            verbose=False,
                            save_best_only=True),
            CSVLogger(train_log_path),
        ]
        logger.info(f"Use early stopping: {args.stop_early}")
        if args.stop_early:
            # fit() below gets no validation data, so 'val_loss' is never
            # reported and early stopping on it could not trigger. Monitor the
            # training loss instead, like the checkpoint callback does.
            callbacks.append(EarlyStopping(monitor='loss',
                                           patience=5,
                                           verbose=False))
        custom_metrics = models.get_metric_functions()
        logger.info(f"Generator model with params: {param}")
        model = models.get_models_by_params(param)
        # every selected output contributes with weight 1.0 by default
        loss_weights = {output: 1.0 for output in labels}

        logger.info(f"select model: {args.model_type}")
        if args.model_type == "staggered":
            logger.info("compile and pre-train server model")
            logger.info(model.get_config())
            model.compile(optimizer='adam',
                          loss='binary_crossentropy',
                          loss_weights={"client": 0.0, "server": 1.0},
                          metrics=['accuracy'] + custom_metrics)
            model.summary()
            model.fit(features, labels,
                      batch_size=args.batch_size,
                      epochs=args.epochs,
                      class_weight=custom_class_weights,
                      sample_weight=custom_sample_weights)
            logger.info("fix server model")
            model.get_layer("domain_cnn").trainable = False
            model.get_layer("domain_cnn").layer.trainable = False
            model.get_layer("dense_server").trainable = False
            model.get_layer("server").trainable = False
            # second stage: only the client output contributes to the loss
            loss_weights = {"client": 1.0, "server": 0.0}

        logger.info("compile and train model")
        logger.info(model.get_config())
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      loss_weights=loss_weights,
                      metrics=['accuracy'] + custom_metrics)
        model.summary()
        model.fit(features, labels,
                  batch_size=args.batch_size,
                  epochs=args.epochs,
                  callbacks=callbacks,
                  class_weight=custom_class_weights,
                  sample_weight=custom_sample_weights)
def main_retrain():
    """Continue training a stored model on the data given by ``args.data``.

    The model is loaded from ``clf.h5`` in ``args.model_source``; the best
    checkpoint (lowest training loss) is written to ``clf.h5`` in
    ``args.model_destination``. Training resumes at ``args.initial_epoch``.

    :raises ValueError: if ``args.model_output`` is not one of ``"both"``,
        ``"client"``, ``"server"``.
    """
    source = os.path.join(args.model_source, "clf.h5")
    destination = os.path.join(args.model_destination, "clf.h5")
    logger.info(f"Use command line arguments: {args}")
    exists_or_make_path(args.model_destination)
    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
                                                                                                 args.domain_length,
                                                                                                 args.window)
    logger.info("define callbacks")
    callbacks = [
        ModelCheckpoint(filepath=destination,
                        monitor='loss',
                        verbose=False,
                        save_best_only=True),
        CSVLogger(args.train_log),
    ]
    logger.info(f"Use early stopping: {args.stop_early}")
    if args.stop_early:
        # fit() below gets no validation data, so 'val_loss' is never reported
        # and early stopping on it could not trigger. Monitor the training
        # loss instead, like the checkpoint callback does.
        callbacks.append(EarlyStopping(monitor='loss',
                                       patience=5,
                                       verbose=False))
    # The class weights are computed from the per-sample maximum of the server
    # windows, also for the window-based model types handled further below.
    server_tr = np.max(server_windows_tr, axis=1)
    if args.class_weights:
        logger.info("class weights: compute custom weights")
        custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
        logger.info(custom_class_weights)
    else:
        logger.info("class weights: set default")
        custom_class_weights = None
    logger.info("Load pretrained model")
    embedding, model = load_model(source, custom_objects=models.get_custom_objects())
    if args.model_type in ("inter", "staggered"):
        # these model types are trained on the per-window server labels
        server_tr = np.expand_dims(server_windows_tr, 2)
    features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
    if args.model_output == "both":
        labels = {"client": client_tr.value, "server": server_tr}
    elif args.model_output == "client":
        labels = {"client": client_tr.value}
    elif args.model_output == "server":
        labels = {"server": server_tr}
    else:
        raise ValueError("unknown model output")
    logger.info("re-train model")
    embedding.summary()
    model.summary()
    model.fit(features, labels,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=callbacks,
              class_weight=custom_class_weights,
              initial_epoch=args.initial_epoch)
def main_test():
    """Run every model in ``args.model_paths`` on the test data and store the predictions.

    The predictions of a model are collected under the last component of its
    path. The domain embeddings are stored under the shared key
    ``"domain_embds"``, so that entry always holds the embeddings of the model
    processed last. The accumulated results are saved into the directory of
    the current model after every model.
    """
    logger.info("load test data")
    domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window)
    logger.info("load test domains")
    domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

    results = {}
    for model_path in args.model_paths:
        model_dir, file = os.path.split(os.path.normpath(model_path))
        model_results = {}
        results[file] = model_results
        logger.info(f"process model {model_path}")
        embd_model, clf_model = load_model(model_path, custom_objects=models.get_custom_objects())
        pred = clf_model.predict([domain_val, flow_val],
                                 batch_size=args.batch_size,
                                 verbose=1)
        if args.model_output == "both":
            model_results["client_pred"], model_results["server_pred"] = pred
        elif args.model_output == "client":
            model_results["client_pred"] = pred
        else:
            model_results["server_pred"] = pred
        results["domain_embds"] = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
        # store results every round - safety first!
        dataset.save_predictions(model_dir, results)
def main_visualization():
    """Plot model graphs, PR/ROC curves and confusion matrices for one model.

    The predictions stored in ``args.model_path`` are compared against the
    reference predictions in ``results/paul/``, per window and aggregated per
    user (maximum over all windows of a user). All figures are written into
    ``args.model_path``.

    ``Series.as_matrix()`` has been replaced by ``Series.values``: it returns
    the same array and is still available in pandas versions that no longer
    ship ``as_matrix``.
    """
    def plot_model(clf_model, path):
        # draw the embedding and the classifier network into two pdf files
        embd, model = load_model(clf_model, custom_objects=models.get_custom_objects())
        visualize.plot_model_as(embd, os.path.join(path, "model_embd.pdf"), shapes=False)
        visualize.plot_model_as(model, os.path.join(path, "model_clf.pdf"), shapes=False)

    def vis(model_name, model_path, df, df_paul, aggregation, curve):
        # one figure holding the curve of this model and the reference curve
        visualize.plot_clf()
        if aggregation == "user":
            df = df.groupby(df.names).max()
            df_paul = df_paul.groupby(df_paul.names).max()
        if curve == "prc":
            visualize.plot_precision_recall(df.client_val.values, df.client_pred.values, model_name)
            visualize.plot_precision_recall(df_paul.client_val.values, df_paul.client_pred.values, "paul")
        elif curve == "roc":
            visualize.plot_roc_curve(df.client_val.values, df.client_pred.values, model_name)
            visualize.plot_roc_curve(df_paul.client_val.values, df_paul.client_pred.values, "paul")
        visualize.plot_legend()
        visualize.plot_save("{}/{}_{}.pdf".format(model_path, aggregation, curve))

    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
                                                                                             args.domain_length,
                                                                                             args.window)
    results = dataset.load_predictions(args.model_path)
    df = pd.DataFrame(data={
        "names": name_val, "client_pred": results["client_pred"].flatten(),
        "hits_vt": hits_vt, "hits_trusted": hits_trusted
    })
    # a window is labelled positive on a VirusTotal hit or at least three trusted hits
    df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3)
    df_user = df.groupby(df.names).max()
    paul = dataset.load_predictions("results/paul/")
    df_paul = pd.DataFrame(data={
        "names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
        "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
    })
    df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
    logger.info("plot model")
    plot_model(args.clf_model, args.model_path)
    # Plotting of the training curves from args.train_log is currently disabled.
    for aggregation in ("window", "user"):
        for curve in ("prc", "roc"):
            logger.info(f"plot {aggregation} {curve}")
            vis(args.model_name, args.model_path, df, df_paul, aggregation, curve)
    # confusion matrices: absolute values first, then normalized
    for normalize, suffix in ((False, ""), (True, "_norm")):
        for frame, prefix, title in ((df, "client", "Client Confusion Matrix"),
                                     (df_user, "user", "User Confusion Matrix")):
            visualize.plot_confusion_matrix(frame.client_val.values, frame.client_pred.values.round(),
                                            "{}/{}_cov{}.pdf".format(args.model_path, prefix, suffix),
                                            normalize=normalize, title=title)
def main_visualize_all():
    """Plot the PR and ROC curves of all configured models into shared figures.

    The models come from ``get_model_args(args)``; the reference predictions
    in ``results/paul/`` are added to every figure. Curves are drawn per
    window and aggregated per user (maximum over all windows of a user). The
    figures are stored as ``{args.output_prefix}_{aggregation}_{curve}.pdf``.

    ``Series.as_matrix()`` has been replaced by ``Series.values``: it returns
    the same array and is still available in pandas versions that no longer
    ship ``as_matrix``.
    """
    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
                                                                                             args.domain_length,
                                                                                             args.window)

    def load_df(path):
        # predictions of one model together with the ground truth per window
        res = dataset.load_predictions(path)
        res = pd.DataFrame(data={
            "names": name_val, "client_pred": res["client_pred"].flatten(),
            "hits_vt": hits_vt, "hits_trusted": hits_trusted
        })
        res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
        return res

    dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
    paul = dataset.load_predictions("results/paul/")
    df_paul = pd.DataFrame(data={
        "names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
        "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
    })
    df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)

    def vis(output_prefix, dfs, df_paul, aggregation, curve):
        # one figure with the requested curve of every model plus the reference
        visualize.plot_clf()
        plot_curve = {
            "prc": visualize.plot_precision_recall,
            "roc": visualize.plot_roc_curve,
        }.get(curve)
        # as before, an unknown curve type draws nothing but still saves the figure
        if plot_curve is not None:
            for model_name, df in dfs:
                if aggregation == "user":
                    df = df.groupby(df.names).max()
                plot_curve(df.client_val.values, df.client_pred.values, model_name)
            if aggregation == "user":
                df_paul = df_paul.groupby(df_paul.names).max()
            plot_curve(df_paul.client_val.values, df_paul.client_pred.values, "paul")
        visualize.plot_legend()
        visualize.plot_save("{}_{}_{}.pdf".format(output_prefix, aggregation, curve))

    logger.info("plot pr curves")
    vis(args.output_prefix, dfs, df_paul, "window", "prc")
    logger.info("plot roc curves")
    vis(args.output_prefix, dfs, df_paul, "window", "roc")
    logger.info("plot user pr curves")
    vis(args.output_prefix, dfs, df_paul, "user", "prc")
    logger.info("plot user roc curves")
    vis(args.output_prefix, dfs, df_paul, "user", "roc")
def main_visualize_all_embds():
    """Scatter-plot a 2-d projection of the domain embeddings of every model.

    A random sample of 1000 client, 1000 server and 6000 benign domains is
    drawn once (with numpy's global random state) and reused for all models.
    The embeddings of the sample are reduced to two components with a
    truncated SVD and stored as ``{args.output_prefix}_{model_name}.pdf``.
    """
    def load_embeddings(path):
        return dataset.load_predictions(path)["domain_embds"]

    embeddings = [(model_args["model_name"], load_embeddings(model_args["model_path"]))
                  for model_args in get_model_args(args)]
    from sklearn.decomposition import TruncatedSVD

    def vis2(domain_embedding, labels):
        # Reads benign_idx from the enclosing function; it is assigned before
        # the first call.
        logger.info(f"reduction for {len(domain_embedding)} points")
        svd = TruncatedSVD(n_components=2, algorithm="arpack")
        domains = svd.fit_transform(domain_embedding)
        logger.info("plot kde")
        benign = domains[labels.sum(axis=1) == 0]
        plt.scatter(benign[benign_idx, 0], benign[benign_idx, 1],
                    cmap="Blues", label="benign", alpha=0.35, s=10)
        # NOTE(review): using a label column as index selects rows as a mask
        # only if labels is a boolean array - confirm the dtype.
        for column, colors, name in ((1, "Greens", "server"), (0, "Reds", "client")):
            selected = labels[:, column]
            plt.scatter(domains[selected, 0], domains[selected, 1],
                        cmap=colors, label=name, alpha=0.35, s=10)
        return np.concatenate((domains[:1000], domains[1000:2000], domains[2000:3000]), axis=0)

    domain_encs, _, labels = dataset.load_or_generate_domains(args.data, args.domain_length)
    client = labels[:, 0]
    server = labels[:, 1]
    benign = np.logical_not(np.logical_or(client, server))
    print(client.sum(), server.sum(), benign.sum())
    positions = np.arange(len(labels))
    # the three samples are drawn in this order: client, server, benign
    idx = np.concatenate([np.random.choice(positions[group], size)
                          for group, size in ((client, 1000), (server, 1000), (benign, 6000))], axis=0)
    # thin the 6000 sampled benign points down to 1000 plotted ones
    benign_idx = np.random.choice(np.arange(6000), 1000)
    print(idx.shape)
    sampled_labels = labels[idx]
    for model_name, embd in embeddings:
        logger.info(f"plot embedding for {model_name}")
        visualize.plot_clf()
        points = vis2(embd[idx], sampled_labels)
        # np.savetxt("{}_{}.csv".format(args.output_prefix, model_name), points, delimiter=",")
        visualize.plot_save("{}_{}.pdf".format(args.output_prefix, model_name))
def main_beta():
    """Compute mean PR/ROC curves over all runs of one model and store them.

    The predictions of all runs (keys starting with ``"clf"``) are loaded from
    ``args.model_path``. Curves are computed per window and per user and, if
    server predictions exist, also per flow and per domain. The results are
    merged into ``curves.joblib`` in the parent directory of
    ``args.model_path`` under the last component of that path.

    ``Series.as_matrix()`` has been replaced by ``Series.values``: it returns
    the same array and is still available in pandas versions that no longer
    ship ``as_matrix``.

    :raises ValueError: if the stored predictions contain no ``clf*`` entry.
    """
    domain_val, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
                                                                                                      args.domain_length,
                                                                                                      args.window)
    path, model_prefix = os.path.split(os.path.normpath(args.model_path))
    print(path, model_prefix)
    try:
        curves = joblib.load(f"{path}/curves.joblib")
        logger.info(f"load file {path}/curves.joblib successfully")
    except Exception as err:
        # best effort: start with an empty collection, but say why
        logger.info(f"could not load {path}/curves.joblib ({err!r}); start with empty curves")
        curves = {}
    logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
    curves[model_prefix] = {"all": {}}
    # NOTE(review): 40 is hard coded here; presumably it has to equal
    # args.domain_length - confirm.
    domains = domain_val.value.reshape(-1, 40)
    domains = np.apply_along_axis(dataset.decode_domain, 1, domains)

    def load_df(res):
        # Build the per-window frame of one run and, if the server predictions
        # have one entry per flow, an additional per-flow frame.
        df_server = None
        data = {
            "names": name_val, "client_pred": res["client_pred"].flatten(),
            "hits_vt": hits_vt, "hits_trusted": hits_trusted,
        }
        if "server_pred" in res:
            server = res["server_pred"] if len(res["server_pred"].shape) == 2 else res["server_pred"].max(axis=1)
            val = server_val.value.max(axis=1)
            data["server_pred"] = server.flatten()
            data["server_val"] = val.flatten()
            if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
                df_server = pd.DataFrame(data={
                    "server_pred": res["server_pred"].flatten(),
                    "domain": domains,
                    "server_val": server_val.value.flatten()
                })
        res = pd.DataFrame(data=data)
        res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
        return res, df_server

    logger.info(f"load results from {args.model_path}")
    res = dataset.load_predictions(args.model_path)
    # keys look like "clf_<run>.h5"; sort them by run number
    model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
    if not model_keys:
        # the code after the loop relies on the frames of the last run
        raise ValueError(f"no model predictions found in {args.model_path}")
    client_preds = []
    server_preds = []
    server_flow_preds = []
    client_user_preds = []
    server_user_preds = []
    server_domain_preds = []
    server_domain_avg_preds = []
    for model_name in model_keys:
        logger.info(f"load model {model_name}")
        df, df_server = load_df(res[model_name])
        client_preds.append(df.client_pred.values)
        if "server_val" in df.columns:
            server_preds.append(df.server_pred.values)
        if df_server is not None:
            logger.info(" group servers")
            server_flow_preds.append(df_server.server_pred.values)
            df_domain = df_server.groupby(df_server.domain).max()
            server_domain_preds.append(df_domain.server_pred.values)
            df_domain_avg = df_server.groupby(df_server.domain).rolling(10).mean()
            server_domain_avg_preds.append(df_domain_avg.server_pred.values)
        curves[model_prefix][model_name] = confusion_matrix(df.client_val.values,
                                                            df.client_pred.values.round())
        logger.info(" group users")
        df_user = df.groupby(df.names).max()
        client_user_preds.append(df_user.client_pred.values)
        if "server_val" in df.columns:
            server_user_preds.append(df_user.server_pred.values)

    # The ground truth is taken from the frames of the last run.
    all_curves = curves[model_prefix]["all"]
    logger.info("compute client curves")
    all_curves["client_window_prc"] = visualize.calc_pr_mean(df.client_val.values, client_preds)
    all_curves["client_window_roc"] = visualize.calc_roc_mean(df.client_val.values, client_preds)
    all_curves["client_user_prc"] = visualize.calc_pr_mean(df_user.client_val.values, client_user_preds)
    all_curves["client_user_roc"] = visualize.calc_roc_mean(df_user.client_val.values, client_user_preds)
    if "server_val" in df.columns:
        logger.info("compute server curves")
        all_curves["server_window_prc"] = visualize.calc_pr_mean(df.server_val.values, server_preds)
        all_curves["server_window_roc"] = visualize.calc_roc_mean(df.server_val.values, server_preds)
        all_curves["server_user_prc"] = visualize.calc_pr_mean(df_user.server_val.values, server_user_preds)
        all_curves["server_user_roc"] = visualize.calc_roc_mean(df_user.server_val.values, server_user_preds)
    if df_server is not None:
        logger.info("compute server flow curves")
        all_curves["server_flow_prc"] = visualize.calc_pr_mean(df_server.server_val.values, server_flow_preds)
        all_curves["server_flow_roc"] = visualize.calc_roc_mean(df_server.server_val.values, server_flow_preds)
        all_curves["server_domain_prc"] = visualize.calc_pr_mean(df_domain.server_val.values,
                                                                 server_domain_preds)
        all_curves["server_domain_roc"] = visualize.calc_roc_mean(df_domain.server_val.values,
                                                                  server_domain_preds)
        all_curves["server_domain_avg_prc"] = visualize.calc_pr_mean(df_domain_avg.server_val.values,
                                                                     server_domain_avg_preds)
        all_curves["server_domain_avg_roc"] = visualize.calc_roc_mean(df_domain_avg.server_val.values,
                                                                      server_domain_avg_preds)
    joblib.dump(curves, f"{path}/curves.joblib")
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
def plot_overall_result():
    """Plot the mean PR/ROC curves of all models stored in ``curves.joblib``.

    The file is read from the parent directory of ``args.model_path``. For
    each curve type one figure with the mean curve and a +/- one standard
    deviation band per model is written to ``figs/curves/{curve}_all.pdf`` in
    that directory.

    The former second loop (one appendix figure per model and curve) came
    after an unconditional ``return`` and could never run; that unreachable
    code has been removed.
    """
    path, _ = os.path.split(os.path.normpath(args.model_path))
    exists_or_make_path(f"{path}/figs/curves/")
    try:
        results = joblib.load(f"{path}/curves.joblib")
        logger.info("curves successfully loaded")
    except Exception as err:
        # best effort: without stored curves only empty figures are written
        logger.warning(f"could not load {path}/curves.joblib ({err!r}); plot without curves")
        results = {}
    # the stored curves are assumed to be sampled on this grid - TODO confirm
    x = np.linspace(0, 1, 10000)
    for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
                "server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
                "server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc"]:
        logger.info(f"plot {vis}")
        visualize.plot_clf()
        for model_key in results.keys():
            if vis not in results[model_key]["all"]:
                continue
            # flow curves of the "final" models are left out of the figure
            if "final" in model_key and vis.startswith("server_flow"):
                continue
            ys_mean, ys_std, ys = results[model_key]["all"][vis]
            plt.plot(x, ys_mean, label=f"{model_key} - {np.mean(ys_mean):5.4} ({np.mean(ys_std):4.3})")
            plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
        if vis.endswith("prc"):
            plt.xlabel('Recall')
            plt.ylabel('Precision')
        else:
            # diagonal as reference for the ROC plots
            plt.plot(x, x, label="random classifier", ls="--", c=".3", alpha=0.4)
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.xscale('log')
        plt.ylim([0.0, 1.0])
        plt.xlim([0.0, 1.0])
        visualize.plot_legend()
        visualize.plot_save(f"{path}/figs/curves/{vis}_all.pdf")
def main_stats():
    """Write histograms of the flow features of the current and future data set.

    For each of ``duration``, ``bytes_down`` and ``bytes_up`` a histogram of
    the raw values and one of the ``log1p`` transformed values is stored below
    ``figs/`` in the parent directory of ``args.output_prefix``.
    """
    path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
    # (file name infix, value transformation, progress marker)
    variants = (
        ("", lambda values: values, "."),
        ("norm_", np.log1p, "-"),
    )
    for time in ("current", "future"):
        df = dataset.get_user_flow_data(f"data/{time}Data.csv.gz")
        # NOTE(review): the thresholds differ from the client label used by
        # the other routines (hits_vt == 1.0 or hits_trusted >= 3) - confirm.
        # The column is not read again in this function.
        df["clientlabel"] = np.logical_or(df.virusTotalHits > 3, df.trustedHits > 0)
        for col in ["duration", "bytes_down", "bytes_up"]:
            for infix, transform, marker in variants:
                plt.clf()
                plt.hist(transform(df[col]))
                visualize.plot_save(f"{path}/figs/hist_{time}_{infix}{col}.pdf")
                print(marker)
def main_stats2():
    """Export the per-run mean of every stored curve as csv file and LaTeX table.

    The curves are read from ``results/variance_test_hyper/curves.joblib``.
    For every curve type a table with one row per run and one column per model
    is written to ``{curve}.csv`` and printed as LaTeX.

    Fixes over the previous version:

    * the curves of a model live in its ``"all"`` entry; the old membership
      test looked at the model dict itself, so every model was skipped and no
      table was ever produced;
    * the column names are the models that actually provide the curve, and
      the index follows the number of runs instead of a fixed 1..20.
    """
    res = joblib.load("results/variance_test_hyper/curves.joblib")
    for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
                "server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
                "server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc",
                "server_domain_avg_prc", "server_domain_avg_roc"]:
        names = []
        tab = []
        for m, r in res.items():
            model_curves = r.get("all", {})
            if vis not in model_curves:
                continue
            names.append(m)
            # the third entry holds the curves of the single runs; average each run
            tab.append(model_curves[vis][2].mean(axis=1))
        if not tab:
            continue
        table = np.vstack(tab).T  # rows: runs, columns: models
        df = pd.DataFrame(data=table, columns=names,
                          index=range(1, len(table) + 1))
        df.to_csv(f"{vis}.csv")
        print(f"% {vis}")
        print(df.round(4).to_latex())
        print()
def main():
    """Run the routine selected by ``args.mode``; unknown modes do nothing."""
    def start_hyperband():
        main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
                       args.hyper_max_iter)

    routines = {
        "train": main_train,
        "retrain": main_retrain,
        "hyperband": start_hyperband,
        "test": main_test,
        "beta": main_beta,
        "all_beta": plot_overall_result,
        "embedding": main_visualize_all_embds,
    }
    routine = routines.get(args.mode)
    if routine is not None:
        routine()


if __name__ == "__main__":
    main()