add feature: generate and use h5 data
This commit is contained in:
parent
fdc03c9922
commit
41b38de1ab
@ -130,8 +130,8 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10,
|
||||
return domain_tr, flow_tr, client_tr, server_tr
|
||||
|
||||
|
||||
def store_h5dataset(domain_tr, flow_tr, client_tr, server_tr):
|
||||
f = h5py.File("data/full_dataset.h5", "w")
|
||||
def store_h5dataset(path, domain_tr, flow_tr, client_tr, server_tr):
|
||||
f = h5py.File(path, "w")
|
||||
domain_tr = domain_tr.astype(np.int8)
|
||||
f.create_dataset("domain", data=domain_tr)
|
||||
f.create_dataset("flow", data=flow_tr)
|
||||
@ -142,6 +142,11 @@ def store_h5dataset(domain_tr, flow_tr, client_tr, server_tr):
|
||||
f.close()
|
||||
|
||||
|
||||
def load_h5dataset(path):
|
||||
data = h5py.File(path, "r")
|
||||
return data["domain"], data["flow"], data["client"], data["server"]
|
||||
|
||||
|
||||
def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
|
||||
"""
|
||||
combines domain and feature windows to sequential training data
|
||||
|
97
main.py
97
main.py
@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import h5py
|
||||
from keras.models import load_model
|
||||
from keras.utils import np_utils
|
||||
|
||||
@ -78,6 +78,7 @@ args = parser.parse_args()
|
||||
|
||||
args.embedding_model = args.models + "_embd.h5"
|
||||
args.clf_model = args.models + "_clf.h5"
|
||||
args.h5data = args.train_data + ".h5"
|
||||
|
||||
|
||||
# config = tf.ConfigProto(log_device_placement=True)
|
||||
@ -86,6 +87,11 @@ args.clf_model = args.models + "_clf.h5"
|
||||
# session = tf.Session(config=config)
|
||||
|
||||
|
||||
def exists_or_make_path(p):
|
||||
if not os.path.exists(p):
|
||||
os.makedirs(p)
|
||||
|
||||
|
||||
def main_paul_best():
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
@ -157,28 +163,46 @@ def main_hyperband():
|
||||
|
||||
|
||||
def main_train():
|
||||
# parameter
|
||||
dropout_main = 0.5
|
||||
dense_main = 512
|
||||
kernel_main = 3
|
||||
filter_main = 128
|
||||
network = models.pauls_networks if args.model_type == "paul" else models.renes_networks
|
||||
|
||||
exists_or_make_path(args.clf_model)
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
print("check for h5data")
|
||||
try:
|
||||
open(args.h5data, "r")
|
||||
except FileNotFoundError:
|
||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length, window_size=args.window)
|
||||
print("store training dataset as h5 file")
|
||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||
print("load h5 dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
||||
|
||||
print("create training dataset")
|
||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||
user_flow_df, char_dict,
|
||||
max_len=args.domain_length, window_size=args.window)
|
||||
# parameter
|
||||
param = {
|
||||
"type": "paul",
|
||||
"batch_size": 64,
|
||||
"window_size": args.window,
|
||||
"domain_length": args.domain_length,
|
||||
"flow_features": 3,
|
||||
"vocab_size": len(char_dict) + 1,
|
||||
#
|
||||
'dropout': 0.5,
|
||||
'domain_features': args.domain_embedding,
|
||||
'embedding_size': args.embedding,
|
||||
'filter_main': 128,
|
||||
'flow_features': 3,
|
||||
'dense_main': 512,
|
||||
'filter_embedding': args.hidden_char_dims,
|
||||
'hidden_embedding': args.domain_embedding,
|
||||
'kernel_embedding': 3,
|
||||
'kernels_main': 3,
|
||||
'input_length': 40
|
||||
}
|
||||
|
||||
embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
|
||||
args.hidden_char_dims, kernel_main, args.domain_embedding, 0.5)
|
||||
embedding, model = models.get_models_by_params(param)
|
||||
embedding.summary()
|
||||
|
||||
model = network.get_model(dropout_main, flow_tr.shape[-1], args.domain_embedding,
|
||||
args.window, args.domain_length, filter_main, kernel_main,
|
||||
dense_main, embedding)
|
||||
model.summary()
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
@ -196,41 +220,6 @@ def main_train():
|
||||
model.save(args.clf_model)
|
||||
|
||||
|
||||
def main_train_h5():
|
||||
# parameter
|
||||
dropout_main = 0.5
|
||||
dense_main = 512
|
||||
kernel_main = 3
|
||||
filter_main = 128
|
||||
network = models.pauls_networks if args.model_type == "paul" else models.renes_networks
|
||||
|
||||
char_dict = dataset.get_character_dict()
|
||||
data = h5py.File("data/full_dataset.h5", "r")
|
||||
|
||||
embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
|
||||
args.hidden_char_dims, kernel_main, args.domain_embedding, 0.5)
|
||||
embedding.summary()
|
||||
|
||||
model = network.get_model(dropout_main, data["flow"].shape[-1], args.domain_embedding,
|
||||
args.window, args.domain_length, filter_main, kernel_main,
|
||||
dense_main, embedding)
|
||||
model.summary()
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
loss='categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
|
||||
model.fit([data["domain"], data["flow"]],
|
||||
[data["client"], data["server"]],
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
shuffle=True,
|
||||
validation_split=0.2)
|
||||
|
||||
embedding.save(args.embedding_model)
|
||||
model.save(args.clf_model)
|
||||
|
||||
|
||||
def main_test():
|
||||
char_dict = dataset.get_character_dict()
|
||||
user_flow_df = dataset.get_user_flow_data(args.test_data)
|
||||
|
Loading…
x
Reference in New Issue
Block a user