refactor hyperband; fix domain generation
integrate hyperband option in training procedure; start refactoring; remove the index error in domain generation and add helper functions
This commit is contained in:
parent
8b17bd0701
commit
88e3eda595
62
dataset.py
62
dataset.py
@ -12,7 +12,7 @@ from tqdm import tqdm
|
|||||||
logger = logging.getLogger('cisco_logger')
|
logger = logging.getLogger('cisco_logger')
|
||||||
|
|
||||||
char2idx = dict((char, idx + 1) for (idx, char) in
|
char2idx = dict((char, idx + 1) for (idx, char) in
|
||||||
enumerate(string.ascii_lowercase + string.punctuation + string.digits))
|
enumerate(string.ascii_lowercase + string.punctuation + string.digits + " "))
|
||||||
|
|
||||||
idx2char = {v: k for k, v in char2idx.items()}
|
idx2char = {v: k for k, v in char2idx.items()}
|
||||||
|
|
||||||
@ -34,50 +34,18 @@ def decode_char(i):
|
|||||||
|
|
||||||
|
|
||||||
encode_char = np.vectorize(encode_char)
|
encode_char = np.vectorize(encode_char)
|
||||||
|
decode_char = np.vectorize(decode_char)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_domain(domain: string):
|
||||||
|
return encode_char(list(domain))
|
||||||
|
|
||||||
|
|
||||||
|
def decode_domain(domain):
|
||||||
|
return "".join(decode_char(domain))
|
||||||
|
|
||||||
|
|
||||||
# TODO: ask for correct refactoring
|
|
||||||
def get_user_chunks(user_flow, window=10):
|
def get_user_chunks(user_flow, window=10):
|
||||||
# TODO: what is maxLengthInSeconds for?!?
|
|
||||||
# maxMilliSeconds = maxLengthInSeconds * 1000
|
|
||||||
# domains = []
|
|
||||||
# flows = []
|
|
||||||
# if not overlapping:
|
|
||||||
# numBlocks = int(np.ceil(len(user_flow) / window))
|
|
||||||
# userIDs = np.arange(len(user_flow))
|
|
||||||
# for blockID in np.arange(numBlocks):
|
|
||||||
# curIDs = userIDs[(blockID * window):((blockID + 1) * window)]
|
|
||||||
# useData = user_flow.iloc[curIDs]
|
|
||||||
# curDomains = useData['domain']
|
|
||||||
# if maxLengthInSeconds != -1:
|
|
||||||
# curMinMilliSeconds = np.min(useData['timeStamp']) + maxMilliSeconds
|
|
||||||
# underTimeOutIDs = np.where(np.array(useData['timeStamp']) <= curMinMilliSeconds)
|
|
||||||
# if len(underTimeOutIDs) != len(curIDs):
|
|
||||||
# curIDs = curIDs[underTimeOutIDs]
|
|
||||||
# useData = user_flow.iloc[curIDs]
|
|
||||||
# curDomains = useData['domain']
|
|
||||||
# domains.append(list(curDomains))
|
|
||||||
# flows.append(useData)
|
|
||||||
# else:
|
|
||||||
# numBlocks = len(user_flow) + 1 - window
|
|
||||||
# userIDs = np.arange(len(user_flow))
|
|
||||||
# for blockID in np.arange(numBlocks):
|
|
||||||
# curIDs = userIDs[blockID:blockID + window]
|
|
||||||
# useData = user_flow.iloc[curIDs]
|
|
||||||
# curDomains = useData['domain']
|
|
||||||
# if maxLengthInSeconds != -1:
|
|
||||||
# curMinMilliSeconds = np.min(useData['timeStamp']) + maxMilliSeconds
|
|
||||||
# underTimeOutIDs = np.where(np.array(useData['timeStamp']) <= curMinMilliSeconds)
|
|
||||||
# if len(underTimeOutIDs) != len(curIDs):
|
|
||||||
# curIDs = curIDs[underTimeOutIDs]
|
|
||||||
# useData = user_flow.iloc[curIDs]
|
|
||||||
# curDomains = useData['domain']
|
|
||||||
# domains.append(list(curDomains))
|
|
||||||
# flows.append(useData)
|
|
||||||
# if domains and len(domains[-1]) != window:
|
|
||||||
# domains.pop(-1)
|
|
||||||
# flows.pop(-1)
|
|
||||||
# return domains, flows
|
|
||||||
result = []
|
result = []
|
||||||
chunk_size = (len(user_flow) // window)
|
chunk_size = (len(user_flow) // window)
|
||||||
for i in range(chunk_size):
|
for i in range(chunk_size):
|
||||||
@ -87,12 +55,11 @@ def get_user_chunks(user_flow, window=10):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# TODO: DATA CORRUPTION; reverse, 0! to n
|
def get_domain_features(domain: string, max_length=40):
|
||||||
def get_domain_features(domain, max_length=40):
|
|
||||||
encoding = np.zeros((max_length,))
|
encoding = np.zeros((max_length,))
|
||||||
for j in range(min(len(domain), max_length)):
|
for j in range(min(len(domain), max_length)):
|
||||||
char = domain[-j] # TODO: why -j -> order reversed for domain url?
|
c = domain[len(domain) - 1 - j]
|
||||||
encoding[j] = encode_char(char)
|
encoding[max_length - 1 - j] = encode_char(c)
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
@ -207,6 +174,7 @@ def get_user_flow_data(csv_file):
|
|||||||
"bytes_up": int,
|
"bytes_up": int,
|
||||||
"domain": object,
|
"domain": object,
|
||||||
"timeStamp": float,
|
"timeStamp": float,
|
||||||
|
"http_method": object,
|
||||||
"server_ip": object,
|
"server_ip": object,
|
||||||
"user_hash": float,
|
"user_hash": float,
|
||||||
"virusTotalHits": int,
|
"virusTotalHits": int,
|
||||||
@ -314,7 +282,7 @@ def load_or_generate_domains(train_data, domain_length):
|
|||||||
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
|
domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length))
|
||||||
domain_encs = np.stack(domain_encs)
|
domain_encs = np.stack(domain_encs)
|
||||||
|
|
||||||
return domain_encs, user_flow_df[["clientLabel", "serverLabel"]].as_matrix().astype(bool)
|
return domain_encs, user_flow_df.domain, user_flow_df[["clientLabel", "serverLabel"]].as_matrix().astype(bool)
|
||||||
|
|
||||||
|
|
||||||
def save_predictions(path, results):
|
def save_predictions(path, results):
|
||||||
|
2
fancy.sh
2
fancy.sh
@ -27,4 +27,4 @@ DATADIR=$4
|
|||||||
python3 main.py --mode embedding --batch 1024 --models ${RESDIR}/client_final_{1..20}/ ${RESDIR}/both_final_{1..20}/ \
|
python3 main.py --mode embedding --batch 1024 --models ${RESDIR}/client_final_{1..20}/ ${RESDIR}/both_final_{1..20}/ \
|
||||||
${RESDIR}/both_inter_{1..20}/ ${RESDIR}/both_staggered_{1..20}/ \
|
${RESDIR}/both_inter_{1..20}/ ${RESDIR}/both_staggered_{1..20}/ \
|
||||||
--data ${DATADIR} \
|
--data ${DATADIR} \
|
||||||
--out-prefix ${RESDIR}/figs/tsne/tsne
|
--out-prefix ${RESDIR}/figs/svd/svd
|
||||||
|
273
main.py
273
main.py
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import joblib
|
import joblib
|
||||||
@ -78,6 +79,50 @@ PARAMS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: remove inner global params
|
||||||
|
def get_param_dist(size="small"):
|
||||||
|
if dist_type == "small":
|
||||||
|
return {
|
||||||
|
# static params
|
||||||
|
"type": [args.model_type],
|
||||||
|
"depth": [args.model_depth],
|
||||||
|
"model_output": [args.model_output],
|
||||||
|
"batch_size": [args.batch_size],
|
||||||
|
"window_size": [args.window],
|
||||||
|
"flow_features": [3],
|
||||||
|
"domain_length": [args.domain_length],
|
||||||
|
# model params
|
||||||
|
"embedding": [2 ** x for x in range(3, 6)],
|
||||||
|
"filter_embedding": [2 ** x for x in range(1, 8)],
|
||||||
|
"kernel_embedding": [1, 3, 5],
|
||||||
|
"dense_embedding": [2 ** x for x in range(4, 8)],
|
||||||
|
"dropout": [0.5],
|
||||||
|
"filter_main": [2 ** x for x in range(1, 8)],
|
||||||
|
"kernel_main": [1, 3, 5],
|
||||||
|
"dense_main": [2 ** x for x in range(1, 8)],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
# static params
|
||||||
|
"type": [args.model_type],
|
||||||
|
"depth": [args.model_depth],
|
||||||
|
"model_output": [args.model_output],
|
||||||
|
"batch_size": [args.batch_size],
|
||||||
|
"window_size": [args.window],
|
||||||
|
"flow_features": [3],
|
||||||
|
"domain_length": [args.domain_length],
|
||||||
|
# model params
|
||||||
|
"embedding": [2 ** x for x in range(3, 7)],
|
||||||
|
"filter_embedding": [2 ** x for x in range(1, 10)],
|
||||||
|
"kernel_embedding": [1, 3, 5, 7, 9],
|
||||||
|
"dense_embedding": [2 ** x for x in range(4, 10)],
|
||||||
|
"dropout": [0.5],
|
||||||
|
"filter_main": [2 ** x for x in range(1, 10)],
|
||||||
|
"kernel_main": [1, 3, 5, 7, 9],
|
||||||
|
"dense_main": [2 ** x for x in range(1, 12)],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_model(model, output_type):
|
def create_model(model, output_type):
|
||||||
if output_type == "both":
|
if output_type == "both":
|
||||||
return Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client, model.out_server))
|
return Model(inputs=[model.in_domains, model.in_flows], outputs=(model.out_client, model.out_server))
|
||||||
@ -87,53 +132,45 @@ def create_model(model, output_type):
|
|||||||
raise Exception("unknown model output")
|
raise Exception("unknown model output")
|
||||||
|
|
||||||
|
|
||||||
|
def shuffle_training_data(domain, flow, client, server):
|
||||||
|
idx = np.random.permutation(len(domain))
|
||||||
|
domain = domain[idx]
|
||||||
|
flow = flow[idx]
|
||||||
|
client = client[idx]
|
||||||
|
server = server[idx]
|
||||||
|
return domain, flow, client, server
|
||||||
|
|
||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
pauls_best_params = models.pauls_networks.best_config
|
pauls_best_params = models.pauls_networks.best_config
|
||||||
main_train(pauls_best_params)
|
main_train(pauls_best_params)
|
||||||
|
|
||||||
|
|
||||||
def main_hyperband():
|
def main_hyperband(data, domain_length, window_size, model_type, result_file, dist_size="small"):
|
||||||
param_dist = {
|
param_dist = get_param_dist(dist_size)
|
||||||
# static params
|
|
||||||
"type": [args.model_type],
|
|
||||||
"depth": [args.model_depth],
|
|
||||||
"model_output": [args.model_output],
|
|
||||||
"batch_size": [args.batch_size],
|
|
||||||
"window_size": [args.window],
|
|
||||||
"flow_features": [3],
|
|
||||||
"domain_length": [args.domain_length],
|
|
||||||
# model params
|
|
||||||
"embedding": [2 ** x for x in range(3, 7)],
|
|
||||||
"filter_embedding": [2 ** x for x in range(1, 10)],
|
|
||||||
"kernel_embedding": [1, 3, 5, 7, 9],
|
|
||||||
"dense_embedding": [2 ** x for x in range(4, 10)],
|
|
||||||
"dropout": [0.5],
|
|
||||||
"filter_main": [2 ** x for x in range(1, 10)],
|
|
||||||
"kernel_main": [1, 3, 5, 7, 9],
|
|
||||||
"dense_main": [2 ** x for x in range(1, 12)],
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(data,
|
||||||
args.data,
|
data,
|
||||||
args.domain_length,
|
domain_length,
|
||||||
args.window)
|
window)
|
||||||
server_tr = np.max(server_windows_tr, axis=1)
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
|
|
||||||
if args.model_type in ("inter", "staggered"):
|
if model_type in ("inter", "staggered"):
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
|
|
||||||
idx = np.random.permutation(len(domain_tr))
|
|
||||||
domain_tr = domain_tr[idx]
|
|
||||||
flow_tr = flow_tr[idx]
|
|
||||||
client_tr = client_tr[idx]
|
|
||||||
server_tr = server_tr[idx]
|
|
||||||
|
|
||||||
|
domain_tr, flow_tr, client_tr, server_tr = shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr)
|
||||||
|
|
||||||
|
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, 81, result_file)
|
||||||
|
|
||||||
|
|
||||||
|
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
|
||||||
|
param_dist = get_param_dist(dist_size)
|
||||||
hp = hyperband.Hyperband(param_dist,
|
hp = hyperband.Hyperband(param_dist,
|
||||||
[domain_tr, flow_tr],
|
[domain, flow],
|
||||||
[client_tr, server_tr],
|
[client, server],
|
||||||
max_iter=81,
|
max_iter=max_iter,
|
||||||
savefile=args.hyperband_results)
|
savefile=savefile)
|
||||||
results = hp.run()
|
results = hp.run()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -148,10 +185,23 @@ def main_train(param=None):
|
|||||||
exists_or_make_path(args.model_path)
|
exists_or_make_path(args.model_path)
|
||||||
logger.info(f"Use command line arguments: {args}")
|
logger.info(f"Use command line arguments: {args}")
|
||||||
|
|
||||||
|
# data preparation
|
||||||
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
|
||||||
args.data,
|
args.data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
|
server_tr = np.max(server_windows_tr, axis=1)
|
||||||
|
if args.model_type in ("inter", "staggered"):
|
||||||
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
|
|
||||||
|
# call hyperband if used
|
||||||
|
if args.hyperband_results:
|
||||||
|
logger.info("start hyperband parameter search")
|
||||||
|
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
|
||||||
|
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]
|
||||||
|
logger.info(f"select params from result: {param}")
|
||||||
|
|
||||||
|
# define training call backs
|
||||||
logger.info("define callbacks")
|
logger.info("define callbacks")
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
callbacks.append(ModelCheckpoint(filepath=args.clf_model,
|
||||||
@ -166,8 +216,7 @@ def main_train(param=None):
|
|||||||
verbose=False))
|
verbose=False))
|
||||||
custom_metrics = models.get_metric_functions()
|
custom_metrics = models.get_metric_functions()
|
||||||
|
|
||||||
server_tr = np.max(server_windows_tr, axis=1)
|
# custom class or sample weights
|
||||||
|
|
||||||
if args.class_weights:
|
if args.class_weights:
|
||||||
logger.info("class weights: compute custom weights")
|
logger.info("class weights: compute custom weights")
|
||||||
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
custom_class_weights = get_custom_class_weights(client_tr.value, server_tr)
|
||||||
@ -193,7 +242,6 @@ def main_train(param=None):
|
|||||||
new_model = create_model(new_model, args.model_output)
|
new_model = create_model(new_model, args.model_output)
|
||||||
|
|
||||||
if args.model_type in ("inter", "staggered"):
|
if args.model_type in ("inter", "staggered"):
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
|
||||||
model = new_model
|
model = new_model
|
||||||
|
|
||||||
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||||
@ -317,7 +365,7 @@ def main_test():
|
|||||||
args.data,
|
args.data,
|
||||||
args.domain_length,
|
args.domain_length,
|
||||||
args.window)
|
args.window)
|
||||||
domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
results = {}
|
results = {}
|
||||||
@ -488,41 +536,52 @@ def main_visualize_all():
|
|||||||
|
|
||||||
|
|
||||||
def main_visualize_all_embds():
|
def main_visualize_all_embds():
|
||||||
import seaborn as sns
|
|
||||||
|
|
||||||
def load_df(path):
|
def load_df(path):
|
||||||
res = dataset.load_predictions(path)
|
res = dataset.load_predictions(path)
|
||||||
return res["domain_embds"]
|
return res["domain_embds"]
|
||||||
|
|
||||||
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
|
dfs = [(model_args["model_name"], load_df(model_args["model_path"])) for model_args in get_model_args(args)]
|
||||||
|
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.decomposition import TruncatedSVD
|
||||||
|
|
||||||
def vis2(domain_embedding, labels):
|
def vis2(domain_embedding, labels):
|
||||||
n_levels = 7
|
n_levels = 7
|
||||||
logger.info(f"reduction for {sub_sample} of {len(domain_embedding)} points")
|
logger.info(f"reduction for {len(domain_embedding)} points")
|
||||||
red = TSNE(n_components=2)
|
red = TruncatedSVD(n_components=2, algorithm="arpack")
|
||||||
domains = red.fit_transform(domain_embedding)
|
domains = red.fit_transform(domain_embedding)
|
||||||
logger.info("plot kde")
|
logger.info("plot kde")
|
||||||
sns.kdeplot(domains[labels.sum(axis=1) == 0, 0], domains[labels.sum(axis=1) == 0, 1],
|
benign = domains[labels.sum(axis=1) == 0]
|
||||||
cmap="Blues", label="benign", n_levels=9, alpha=0.45, shade=True, shade_lowest=False)
|
# print(domains.shape)
|
||||||
sns.kdeplot(domains[labels[:, 1], 0], domains[labels[:, 1], 1],
|
# print(benign.shape)
|
||||||
cmap="Greens", label="server", n_levels=5, alpha=0.45, shade=True, shade_lowest=False)
|
# benign_idx
|
||||||
sns.kdeplot(domains[labels[:, 0], 0], domains[labels[:, 0], 1],
|
# sns.kdeplot(domains[labels.sum(axis=1) == 0, 0], domains[labels.sum(axis=1) == 0, 1],
|
||||||
cmap="Reds", label="client", n_levels=5, alpha=0.45, shade=True, shade_lowest=False)
|
# cmap="Blues", label="benign", n_levels=9, alpha=0.35, shade=True, shade_lowest=False)
|
||||||
|
# sns.kdeplot(domains[labels[:, 1], 0], domains[labels[:, 1], 1],
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(args.data, args.domain_length)
|
# cmap="Greens", label="server", n_levels=5, alpha=0.35, shade=True, shade_lowest=False)
|
||||||
|
# sns.kdeplot(domains[labels[:, 0], 0], domains[labels[:, 0], 1],
|
||||||
|
# cmap="Reds", label="client", n_levels=5, alpha=0.35, shade=True, shade_lowest=False)
|
||||||
|
plt.scatter(benign[benign_idx, 0], benign[benign_idx, 1],
|
||||||
|
cmap="Blues", label="benign", alpha=0.35, s=10)
|
||||||
|
plt.scatter(domains[labels[:, 1], 0], domains[labels[:, 1], 1],
|
||||||
|
cmap="Greens", label="server", alpha=0.35, s=10)
|
||||||
|
plt.scatter(domains[labels[:, 0], 0], domains[labels[:, 0], 1],
|
||||||
|
cmap="Reds", label="client", alpha=0.35, s=10)
|
||||||
|
|
||||||
|
return np.concatenate((domains[:1000], domains[1000:2000], domains[2000:3000]), axis=0)
|
||||||
|
|
||||||
|
domain_encs, _, labels = dataset.load_or_generate_domains(args.data, args.domain_length)
|
||||||
|
|
||||||
idx = np.arange(len(labels))
|
idx = np.arange(len(labels))
|
||||||
client = labels[:, 0]
|
client = labels[:, 0]
|
||||||
server = labels[:, 1]
|
server = labels[:, 1]
|
||||||
benign = np.logical_not(np.logical_and(client, server))
|
benign = np.logical_not(np.logical_or(client, server))
|
||||||
print(client.sum(), server.sum(), benign.sum())
|
print(client.sum(), server.sum(), benign.sum())
|
||||||
|
|
||||||
idx = np.concatenate((
|
idx = np.concatenate((
|
||||||
np.random.choice(idx[client], 1000),
|
np.random.choice(idx[client], 1000),
|
||||||
np.random.choice(idx[server], 1000),
|
np.random.choice(idx[server], 1000),
|
||||||
np.random.choice(idx[benign], 6000)), axis=0)
|
np.random.choice(idx[benign], 6000)), axis=0)
|
||||||
|
benign_idx = np.random.choice(np.arange(6000), 1000)
|
||||||
|
|
||||||
print(idx.shape)
|
print(idx.shape)
|
||||||
lls = labels[idx]
|
lls = labels[idx]
|
||||||
@ -531,7 +590,8 @@ def main_visualize_all_embds():
|
|||||||
logger.info(f"plot embedding for {model_name}")
|
logger.info(f"plot embedding for {model_name}")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
embd = embd[idx]
|
embd = embd[idx]
|
||||||
vis2(embd, lls)
|
points = vis2(embd, lls)
|
||||||
|
# np.savetxt("{}_{}.csv".format(args.output_prefix, model_name), points, delimiter=",")
|
||||||
visualize.plot_save("{}_{}.pdf".format(args.output_prefix, model_name))
|
visualize.plot_save("{}_{}.pdf".format(args.output_prefix, model_name))
|
||||||
|
|
||||||
|
|
||||||
@ -644,6 +704,8 @@ def main_beta():
|
|||||||
# plot_overall_result()
|
# plot_overall_result()
|
||||||
|
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
def plot_overall_result():
|
def plot_overall_result():
|
||||||
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||||
try:
|
try:
|
||||||
@ -651,12 +713,10 @@ def plot_overall_result():
|
|||||||
except Exception:
|
except Exception:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
x = np.linspace(0, 1, 10000)
|
x = np.linspace(0, 1, 10000)
|
||||||
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
||||||
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
||||||
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc",
|
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc"]:
|
||||||
"server_domain_avg_prc", "server_domain_avg_roc"]:
|
|
||||||
logger.info(f"plot {vis}")
|
logger.info(f"plot {vis}")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_key in results.keys():
|
for model_key in results.keys():
|
||||||
@ -667,22 +727,23 @@ def plot_overall_result():
|
|||||||
ys_mean, ys_std, ys = results[model_key]["all"][vis]
|
ys_mean, ys_std, ys = results[model_key]["all"][vis]
|
||||||
plt.plot(x, ys_mean, label=f"{model_key} - {np.mean(ys_mean):5.4} ({np.mean(ys_std):4.3})")
|
plt.plot(x, ys_mean, label=f"{model_key} - {np.mean(ys_mean):5.4} ({np.mean(ys_std):4.3})")
|
||||||
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
|
plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
|
||||||
if vis.endswith("prc"):
|
if vis.endswith("prc"):
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
plt.ylabel('Precision')
|
plt.ylabel('Precision')
|
||||||
else:
|
else:
|
||||||
plt.xlabel('False Positive Rate')
|
plt.plot(x, x, label="random classifier", ls="--", c=".3", alpha=0.4)
|
||||||
plt.ylabel('True Positive Rate')
|
plt.xlabel('False Positive Rate')
|
||||||
plt.xscale('log')
|
plt.ylabel('True Positive Rate')
|
||||||
plt.ylim([0.0, 1.0])
|
plt.xscale('log')
|
||||||
plt.xlim([0.0, 1.0])
|
plt.ylim([0.0, 1.0])
|
||||||
|
plt.xlim([0.0, 1.0])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{path}/figs/curves/{vis}_all.pdf")
|
visualize.plot_save(f"{path}/figs/curves/{vis}_all.pdf")
|
||||||
|
return
|
||||||
|
|
||||||
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
||||||
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
||||||
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc",
|
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc"]:
|
||||||
"server_domain_avg_prc", "server_domain_avg_roc"]:
|
|
||||||
logger.info(f"plot {vis}")
|
logger.info(f"plot {vis}")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_key in results.keys():
|
for model_key in results.keys():
|
||||||
@ -693,26 +754,76 @@ def plot_overall_result():
|
|||||||
_, _, ys = results[model_key]["all"][vis]
|
_, _, ys = results[model_key]["all"][vis]
|
||||||
for y in ys:
|
for y in ys:
|
||||||
plt.plot(x, y, label=f"{model_key} - {np.mean(y):5.4}")
|
plt.plot(x, y, label=f"{model_key} - {np.mean(y):5.4}")
|
||||||
if vis.endswith("prc"):
|
if vis.endswith("prc"):
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
plt.ylabel('Precision')
|
plt.ylabel('Precision')
|
||||||
else:
|
else:
|
||||||
plt.xlabel('False Positive Rate')
|
plt.xlabel('False Positive Rate')
|
||||||
plt.ylabel('True Positive Rate')
|
plt.ylabel('True Positive Rate')
|
||||||
plt.xscale('log')
|
plt.xscale('log')
|
||||||
plt.ylim([0.0, 1.0])
|
plt.ylim([0.0, 1.0])
|
||||||
plt.xlim([0.0, 1.0])
|
plt.xlim([0.0, 1.0])
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{path}/figs/appendix/{model_key}_{vis}.pdf")
|
visualize.plot_save(f"{path}/figs/Appendices/{model_key}_{vis}.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
def main_stats():
|
||||||
|
path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
|
||||||
|
|
||||||
|
for time in ("current", "future"):
|
||||||
|
df = dataset.get_user_flow_data(f"data/{time}Data.csv.gz")
|
||||||
|
df["clientlabel"] = np.logical_or(df.virusTotalHits > 3, df.trustedHits > 0)
|
||||||
|
# df_user = df.groupby(df.user_hash).max()
|
||||||
|
# df_server = df.groupby(df.domain).max()
|
||||||
|
|
||||||
|
# len(df)
|
||||||
|
# df.clientlabel.sum()
|
||||||
|
# df.serverLabel.sum()
|
||||||
|
|
||||||
|
for col in ["duration", "bytes_down", "bytes_up"]:
|
||||||
|
# visualize.plot_clf()
|
||||||
|
plt.clf()
|
||||||
|
plt.hist(df[col])
|
||||||
|
visualize.plot_save(f"{path}/figs/hist_{time}_{col}.pdf")
|
||||||
|
print(".")
|
||||||
|
# visualize.plot_clf()
|
||||||
|
plt.clf()
|
||||||
|
plt.hist(np.log1p(df[col]))
|
||||||
|
visualize.plot_save(f"{path}/figs/hist_{time}_norm_{col}.pdf")
|
||||||
|
print("-")
|
||||||
|
|
||||||
|
|
||||||
|
def main_stats2():
|
||||||
|
import joblib
|
||||||
|
res = joblib.load("results/variance_test_hyper/curves.joblib")
|
||||||
|
|
||||||
|
for vis in ["client_window_prc", "client_window_roc", "client_user_prc", "client_user_roc",
|
||||||
|
"server_window_prc", "server_window_roc", "server_user_prc", "server_user_roc",
|
||||||
|
"server_flow_prc", "server_flow_roc", "server_domain_prc", "server_domain_roc",
|
||||||
|
"server_domain_avg_prc", "server_domain_avg_roc"]:
|
||||||
|
tab = []
|
||||||
|
for m, r in res.items():
|
||||||
|
if vis not in r: continue
|
||||||
|
tab.append(r["all"][vis][2].mean(axis=1))
|
||||||
|
if not tab: continue
|
||||||
|
|
||||||
|
df = pd.DataFrame(data=np.vstack(tab).T, columns=list(res.keys()),
|
||||||
|
index=range(1, 21))
|
||||||
|
df.to_csv(f"{vis}.csv")
|
||||||
|
|
||||||
|
print(f"% {vis}")
|
||||||
|
print(df.round(4).to_latex())
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
main_train()
|
main_train()
|
||||||
if "retrain" == args.mode:
|
if "retrain" == args.mode:
|
||||||
main_retrain()
|
main_retrain()
|
||||||
if "hyperband" == args.mode:
|
if "hyperband" == args.mode:
|
||||||
main_hyperband()
|
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results)
|
||||||
if "test" == args.mode:
|
if "test" == args.mode:
|
||||||
main_test()
|
main_test()
|
||||||
if "fancy" == args.mode:
|
if "fancy" == args.mode:
|
||||||
@ -729,6 +840,8 @@ def main():
|
|||||||
test_server_only()
|
test_server_only()
|
||||||
if "embedding" == args.mode:
|
if "embedding" == args.mode:
|
||||||
main_visualize_all_embds()
|
main_visualize_all_embds()
|
||||||
|
if "stats" == args.mode:
|
||||||
|
main_stats()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user