add h5 support for pauls best config main
This commit is contained in:
parent
41b38de1ab
commit
522854ee0d
@ -96,7 +96,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10,
|
||||
domains = []
|
||||
features = []
|
||||
print("get chunks from user data frames")
|
||||
for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))):
|
||||
for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))[:50]):
|
||||
(domain_windows, feature_windows) = get_user_chunks(user_flow,
|
||||
windowSize=window_size,
|
||||
overlapping=False,
|
||||
|
30
main.py
30
main.py
@ -2,7 +2,6 @@ import argparse
|
||||
import os
|
||||
|
||||
from keras.models import load_model
|
||||
from keras.utils import np_utils
|
||||
|
||||
import dataset
|
||||
import hyperband
|
||||
@ -94,19 +93,24 @@ def exists_or_make_path(p):
|
||||
|
||||
def main_paul_best():
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
print("check for h5data")
|
||||
try:
|
||||
open(args.h5data, "r")
|
||||
raise FileNotFoundError()
|
||||
except FileNotFoundError:
|
||||
print("h5 data not found - load csv file")
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length, window_size=args.window)
|
||||
print("store training dataset as h5 file")
|
||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||
print("load h5 dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
||||
|
||||
param = models.pauls_networks.best_config
|
||||
param["vocab_size"] = len(char_dict) + 1
|
||||
print(param)
|
||||
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length,
|
||||
window_size=args.window)
|
||||
client_tr = np_utils.to_categorical(client_tr, 2)
|
||||
server_tr = np_utils.to_categorical(server_tr, 2)
|
||||
|
||||
embedding, model = models.get_models_by_params(param)
|
||||
|
||||
@ -163,12 +167,14 @@ def main_hyperband():
|
||||
|
||||
|
||||
def main_train():
|
||||
exists_or_make_path(args.clf_model)
|
||||
# exists_or_make_path(args.clf_model)
|
||||
char_dict = dataset.get_character_dict()
|
||||
print("check for h5data")
|
||||
try:
|
||||
open(args.h5data, "r")
|
||||
raise FileNotFoundError()
|
||||
except FileNotFoundError:
|
||||
print("h5 data not found - load csv file")
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
|
@ -10,9 +10,8 @@ df.reset_index(inplace=True)
|
||||
df.dropna(axis=0, how="any", inplace=True)
|
||||
df[["duration", "bytes_down", "bytes_up"]] = df[["duration", "bytes_down", "bytes_up"]].astype(np.int)
|
||||
df[["domain", "server_ip"]] = df[["domain", "server_ip"]].astype(str)
|
||||
df[["server_label"]] = df[["server_label"]].astype(np.bool)
|
||||
df.serverLabel = df.serverLabel.astype(np.bool)
|
||||
df.virusTotalHits = df.virusTotalHits.astype(np.int)
|
||||
df.trustedHits = df.trustedHits.astype(np.int)
|
||||
df.virusTotalHits = df.virusTotalHits.astype(np.int8)
|
||||
df.trustedHits = df.trustedHits.astype(np.int8)
|
||||
|
||||
df.to_csv("/tmp/rk/full_future_dataset.csv.gz", compression="gzip")
|
||||
|
Loading…
x
Reference in New Issue
Block a user