refactor argparser into separate file, add logger
parent 9f0bae33d5
commit 2afaccc84b
Makefile (2 changed lines)
@@ -1,5 +1,5 @@
test:
	python3 main.py --modes train --epochs 1 --batch 64 --train data/rk_data.csv.gz
	python3 main.py --modes train --epochs 1 --batch 128 --train data/rk_mini.csv.gz

hyper:
	python3 main.py --modes hyperband --epochs 1 --batch 64 --train data/rk_data.csv.gz
arguments.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import argparse
import os

parser = argparse.ArgumentParser()

parser.add_argument("--modes", action="store", dest="modes", nargs="+",
                    default=[])

parser.add_argument("--train", action="store", dest="train_data",
                    default="data/full_dataset.csv.tar.bz2")

parser.add_argument("--test", action="store", dest="test_data",
                    default="data/full_future_dataset.csv.tar.bz2")

# parser.add_argument("--h5data", action="store", dest="h5data",
#                     default="")
#
parser.add_argument("--models", action="store", dest="model_path",
                    default="models/models_x")

# parser.add_argument("--pred", action="store", dest="pred",
#                     default="")
#
parser.add_argument("--type", action="store", dest="model_type",
                    default="paul")

parser.add_argument("--batch", action="store", dest="batch_size",
                    default=64, type=int)

parser.add_argument("--epochs", action="store", dest="epochs",
                    default=10, type=int)

# parser.add_argument("--samples", action="store", dest="samples",
#                     default=100000, type=int)
#
# parser.add_argument("--samples_val", action="store", dest="samples_val",
#                     default=10000, type=int)
#
parser.add_argument("--embd", action="store", dest="embedding",
                    default=128, type=int)

parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
                    default=256, type=int)

parser.add_argument("--window", action="store", dest="window",
                    default=10, type=int)

parser.add_argument("--domain_length", action="store", dest="domain_length",
                    default=40, type=int)

parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
                    default=512, type=int)


# parser.add_argument("--queue", action="store", dest="queue_size",
#                     default=50, type=int)
#
# parser.add_argument("--p", action="store", dest="p_train",
#                     default=0.5, type=float)
#
# parser.add_argument("--p_val", action="store", dest="p_val",
#                     default=0.01, type=float)
#
# parser.add_argument("--gpu", action="store", dest="gpu",
#                     default=0, type=int)
#
# parser.add_argument("--tmp", action="store_true", dest="tmp")
#
# parser.add_argument("--test", action="store_true", dest="test")


def parse():
    args = parser.parse_args()
    args.embedding_model = os.path.join(args.model_path, "embd.h5")
    args.clf_model = os.path.join(args.model_path, "clf.h5")
    args.train_log = os.path.join(args.model_path, "train.log")
    args.h5data = args.train_data + ".h5"
    return args
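Note: the new module builds the parser at import time and exposes a single parse() helper that derives the model and log paths from --models and --train. A minimal usage sketch of how a caller is expected to consume it (main.py below does exactly this); the command-line values shown are only an illustration:

    import arguments

    # parse() wraps parser.parse_args(), so it reads sys.argv,
    # e.g. `python3 main.py --modes train --batch 128`
    args = arguments.parse()

    print(args.modes)            # ['train'] in the example call above
    print(args.batch_size)       # 128
    print(args.embedding_model)  # <--models value>/embd.h5, joined by parse()
    print(args.h5data)           # <--train value> + ".h5"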
dataset.py (16 changed lines)
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import logging
import string

import h5py
@@ -7,6 +8,8 @@ import pandas as pd
from keras.utils import np_utils
from tqdm import tqdm

logger = logging.getLogger('logger')

chars = dict((char, idx + 1) for (idx, char) in
             enumerate(string.ascii_lowercase + string.punctuation + string.digits))

@@ -36,7 +39,7 @@ def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
    userIDs = np.arange(len(dataFrame))
    for blockID in np.arange(numBlocks):
        curIDs = userIDs[(blockID * windowSize):((blockID + 1) * windowSize)]
        # print(curIDs)
        # logger.info(curIDs)
        useData = dataFrame.iloc[curIDs]
        curDomains = useData['domain']
        if maxLengthInSeconds != -1:
@@ -88,7 +91,7 @@ def get_all_flow_features(features):
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
    domains = []
    features = []
    print("get chunks from user data frames")
    logger.info("get chunks from user data frames")
    for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))):
        (domain_windows, feature_windows) = get_user_chunks(user_flow,
                                                            windowSize=window_size,
@@ -97,7 +100,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
        domains += domain_windows
        features += feature_windows

    print("create training dataset")
    logger.info("create training dataset")
    domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
                                                                                           flows=features,
                                                                                           vocab=char_dict,
@@ -150,13 +153,8 @@ def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
    :param window_size: size of the flow window
    :return:
    """
    # sample_size = len(domains)

    # domain_features = np.zeros((sample_size, window_size, max_len))
    flow_features = get_all_flow_features(flows)

    domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])

    flow_features = get_all_flow_features(flows)
    hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
    names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
    servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
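dataset.py (and hyperband.py below) only request the shared logger by name; the level, formatter, and handlers are attached once in main.py further down. A minimal sketch of that pattern, with the helper function and message text as illustrative assumptions:

    import logging

    # in a library module such as dataset.py: request the named logger, attach nothing
    logger = logging.getLogger('logger')

    def do_work():
        logger.info("get chunks from user data frames")

    # in the entry point (main.py): configure the same named logger once
    app_logger = logging.getLogger('logger')
    app_logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    app_logger.addHandler(handler)

    do_work()  # the record emitted in the library module now reaches the console handler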
hyperband.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# implementation of hyperband:
# https://arxiv.org/pdf/1603.06560.pdf
import logging
import random
from math import log, ceil
from random import random as rng
@@ -10,6 +11,8 @@ import numpy as np

import models

logger = logging.getLogger('logger')


def sample_params(param_distribution: dict):
    p = {}
@@ -75,7 +78,7 @@ class Hyperband:
            n_configs = n * self.eta ** (-i)
            n_iterations = r * self.eta ** (i)

            print("\n*** {} configurations x {:.1f} iterations each".format(
            logger.info("\n*** {} configurations x {:.1f} iterations each".format(
                n_configs, n_iterations))

            val_losses = []
@@ -84,7 +87,7 @@ class Hyperband:
            for t in T:

                self.counter += 1
                print("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
                logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
                    self.counter, ctime(), self.best_loss, self.best_counter))

                start_time = time()
@@ -98,7 +101,7 @@ class Hyperband:
                assert ('loss' in result)

                seconds = int(round(time() - start_time))
                print("\n{} seconds.".format(seconds))
                logger.info("\n{} seconds.".format(seconds))

                loss = result['loss']
                val_losses.append(loss)
main.py (138 changed lines)
@@ -1,85 +1,46 @@
import argparse
import json
import logging
import os

import numpy as np
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model

import arguments
import dataset
import hyperband
import models

parser = argparse.ArgumentParser()
# create logger
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)

parser.add_argument("--modes", action="store", dest="modes", nargs="+",
                    default=[])
# create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

parser.add_argument("--train", action="store", dest="train_data",
                    default="data/full_dataset.csv.tar.bz2")
# create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

parser.add_argument("--test", action="store", dest="test_data",
                    default="data/full_future_dataset.csv.tar.bz2")
# add formatter to ch
ch.setFormatter(formatter)

# parser.add_argument("--h5data", action="store", dest="h5data",
#                     default="")
#
parser.add_argument("--models", action="store", dest="model_path",
                    default="models/models_x")
# add ch to logger
logger.addHandler(ch)

# parser.add_argument("--pred", action="store", dest="pred",
#                     default="")
#
parser.add_argument("--type", action="store", dest="model_type",
                    default="paul")
ch = logging.FileHandler("info.log")
ch.setLevel(logging.DEBUG)

parser.add_argument("--batch", action="store", dest="batch_size",
                    default=64, type=int)
# create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

parser.add_argument("--epochs", action="store", dest="epochs",
                    default=10, type=int)
# add formatter to ch
ch.setFormatter(formatter)

# parser.add_argument("--samples", action="store", dest="samples",
#                     default=100000, type=int)
#
# parser.add_argument("--samples_val", action="store", dest="samples_val",
#                     default=10000, type=int)
#
parser.add_argument("--embd", action="store", dest="embedding",
                    default=128, type=int)
# add ch to logger
logger.addHandler(ch)

parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
                    default=256, type=int)

parser.add_argument("--window", action="store", dest="window",
                    default=10, type=int)

parser.add_argument("--domain_length", action="store", dest="domain_length",
                    default=40, type=int)

parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
                    default=512, type=int)

# parser.add_argument("--queue", action="store", dest="queue_size",
#                     default=50, type=int)
#
# parser.add_argument("--p", action="store", dest="p_train",
#                     default=0.5, type=float)
#
# parser.add_argument("--p_val", action="store", dest="p_val",
#                     default=0.01, type=float)
#
# parser.add_argument("--gpu", action="store", dest="gpu",
#                     default=0, type=int)
#
# parser.add_argument("--tmp", action="store_true", dest="tmp")
#
# parser.add_argument("--test", action="store_true", dest="test")

args = parser.parse_args()

args.embedding_model = os.path.join(args.model_path, "embd.h5")
args.clf_model = os.path.join(args.model_path, "clf.h5")
args.train_log = os.path.join(args.model_path, "train.log")
args.h5data = args.train_data + ".h5"
args = arguments.parse()


# config = tf.ConfigProto(log_device_placement=True)
@@ -125,7 +86,7 @@ def main_hyperband():
    params = {
        # static params
        "type": ["paul"],
        "batch_size": [64],
        "batch_size": [args.batch_size],
        "vocab_size": [len(char_dict) + 1],
        "window_size": [10],
        "domain_length": [40],
@@ -143,32 +104,35 @@ def main_hyperband():
        "dense_main": [16, 32, 64, 128, 256, 512],
    }
    param = hyperband.sample_params(params)
    print(param)
    logger.info(param)

    print("create training dataset")
    logger.info("create training dataset")
    domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
                                                                                 max_len=args.domain_length,
                                                                                 window_size=args.window)

    hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr])
    hp.run()
    hp = hyperband.Hyperband(params,
                             [domain_tr, flow_tr],
                             [client_tr, server_tr])
    results = hp.run()
    json.dump(results, open("hyperband.json"))


def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
    char_dict = dataset.get_character_dict()
    print("check for h5data", h5data)
    logger.info(f"check for h5data {h5data}")
    try:
        open(h5data, "r")
    except FileNotFoundError:
        print("h5 data not found - load csv file")
        logger.info("h5 data not found - load csv file")
        user_flow_df = dataset.get_user_flow_data(train_data)
        print("create training dataset")
        logger.info("create training dataset")
        domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
                                                                                     max_len=domain_length,
                                                                                     window_size=window_size)
        print("store training dataset as h5 file")
        logger.info("store training dataset as h5 file")
        dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
    print("load h5 dataset")
    logger.info("load h5 dataset")
    return dataset.load_h5dataset(h5data)


@@ -204,7 +168,7 @@ def main_train():
    embedding, model = models.get_models_by_params(param)
    embedding.summary()
    model.summary()
    print("define callbacks")
    logger.info("define callbacks")
    cp = ModelCheckpoint(filepath=args.clf_model,
                         monitor='val_loss',
                         verbose=False,
@@ -213,11 +177,11 @@ def main_train():
    early = EarlyStopping(monitor='val_loss',
                          patience=5,
                          verbose=False)
    print("compile model")
    logger.info("compile model")
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    print("start training")
    logger.info("start training")
    model.fit([domain_tr, flow_tr],
              [client_tr, server_tr],
              batch_size=args.batch_size,
@@ -225,40 +189,40 @@ def main_train():
              callbacks=[cp, csv, early],
              shuffle=True,
              validation_split=0.2)
    print("save embedding")
    logger.info("save embedding")
    embedding.save(args.embedding_model)


def main_test():
    domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
                                                                           args.domain_length, args.window)
    # embedding = load_model(args.embedding_model)
    clf = load_model(args.clf_model)

    loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
                                                      [client_val, server_val],
                                                      batch_size=args.batch_size)

    print(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
    logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
    y_pred = clf.predict([domain_val, flow_val],
                         batch_size=args.batch_size)
    np.save(os.path.join(args.model_path, "future_predict.npy"), y_pred)


def main_visualization():
    mask = dataset.load_mask_eval(args.data, args.test_image)
    y_pred_path = args.model_path + "pred.npy"
    print("plot model")
    logger.info("plot model")
    model = load_model(args.model_path + "model.h5",
                       custom_objects=evaluation.get_metrics())
    visualize.plot_model(model, args.model_path + "model.png")
    print("plot training curve")
    logger.info("plot training curve")
    logs = pd.read_csv(args.model_path + "train.log")
    visualize.plot_training_curve(logs, "{}/train.png".format(args.model_path))
    pred = np.load(y_pred_path)
    print("plot pr curve")
    logger.info("plot pr curve")
    visualize.plot_precision_recall(mask, pred, "{}/prc.png".format(args.model_path))
    visualize.plot_precision_recall_curves(mask, pred, "{}/prc2.png".format(args.model_path))
    print("plot roc curve")
    logger.info("plot roc curve")
    visualize.plot_roc_curve(mask, pred, "{}/roc.png".format(args.model_path))
    print("store prediction image")
    logger.info("store prediction image")
    visualize.save_image_as(pred, "{}/pred.png".format(args.model_path))
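One caveat in the main_hyperband change above: json.dump(results, open("hyperband.json")) opens the target file in the default read-only mode, so the dump raises at runtime. A minimal sketch of a working variant, assuming the results returned by Hyperband.run() are a JSON-serializable list of dicts:

    import json

    results = [{"loss": 0.42, "params": {"batch_size": 64}}]  # illustrative stand-in

    # "w" is required here; open() defaults to read-only and json.dump would fail
    with open("hyperband.json", "w") as fh:
        json.dump(results, fh)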