add feature: generate and use h5 data
This commit is contained in:
parent
fdc03c9922
commit
41b38de1ab
@ -130,8 +130,8 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10,
|
|||||||
return domain_tr, flow_tr, client_tr, server_tr
|
return domain_tr, flow_tr, client_tr, server_tr
|
||||||
|
|
||||||
|
|
||||||
def store_h5dataset(domain_tr, flow_tr, client_tr, server_tr):
|
def store_h5dataset(path, domain_tr, flow_tr, client_tr, server_tr):
|
||||||
f = h5py.File("data/full_dataset.h5", "w")
|
f = h5py.File(path, "w")
|
||||||
domain_tr = domain_tr.astype(np.int8)
|
domain_tr = domain_tr.astype(np.int8)
|
||||||
f.create_dataset("domain", data=domain_tr)
|
f.create_dataset("domain", data=domain_tr)
|
||||||
f.create_dataset("flow", data=flow_tr)
|
f.create_dataset("flow", data=flow_tr)
|
||||||
@ -142,6 +142,11 @@ def store_h5dataset(domain_tr, flow_tr, client_tr, server_tr):
|
|||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
|
def load_h5dataset(path):
|
||||||
|
data = h5py.File(path, "r")
|
||||||
|
return data["domain"], data["flow"], data["client"], data["server"]
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
|
def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
|
||||||
"""
|
"""
|
||||||
combines domain and feature windows to sequential training data
|
combines domain and feature windows to sequential training data
|
||||||
|
89
main.py
89
main.py
@ -1,6 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
import h5py
|
|
||||||
from keras.models import load_model
|
from keras.models import load_model
|
||||||
from keras.utils import np_utils
|
from keras.utils import np_utils
|
||||||
|
|
||||||
@ -78,6 +78,7 @@ args = parser.parse_args()
|
|||||||
|
|
||||||
args.embedding_model = args.models + "_embd.h5"
|
args.embedding_model = args.models + "_embd.h5"
|
||||||
args.clf_model = args.models + "_clf.h5"
|
args.clf_model = args.models + "_clf.h5"
|
||||||
|
args.h5data = args.train_data + ".h5"
|
||||||
|
|
||||||
|
|
||||||
# config = tf.ConfigProto(log_device_placement=True)
|
# config = tf.ConfigProto(log_device_placement=True)
|
||||||
@ -86,6 +87,11 @@ args.clf_model = args.models + "_clf.h5"
|
|||||||
# session = tf.Session(config=config)
|
# session = tf.Session(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
def exists_or_make_path(p):
|
||||||
|
if not os.path.exists(p):
|
||||||
|
os.makedirs(p)
|
||||||
|
|
||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||||
@ -157,28 +163,46 @@ def main_hyperband():
|
|||||||
|
|
||||||
|
|
||||||
def main_train():
|
def main_train():
|
||||||
# parameter
|
exists_or_make_path(args.clf_model)
|
||||||
dropout_main = 0.5
|
|
||||||
dense_main = 512
|
|
||||||
kernel_main = 3
|
|
||||||
filter_main = 128
|
|
||||||
network = models.pauls_networks if args.model_type == "paul" else models.renes_networks
|
|
||||||
|
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
|
print("check for h5data")
|
||||||
|
try:
|
||||||
|
open(args.h5data, "r")
|
||||||
|
except FileNotFoundError:
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||||
|
|
||||||
print("create training dataset")
|
print("create training dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||||
user_flow_df, char_dict,
|
user_flow_df, char_dict,
|
||||||
max_len=args.domain_length, window_size=args.window)
|
max_len=args.domain_length, window_size=args.window)
|
||||||
|
print("store training dataset as h5 file")
|
||||||
|
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||||
|
print("load h5 dataset")
|
||||||
|
domain_tr, flow_tr, client_tr, server_tr = dataset.load_h5dataset(args.h5data)
|
||||||
|
|
||||||
embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
|
# parameter
|
||||||
args.hidden_char_dims, kernel_main, args.domain_embedding, 0.5)
|
param = {
|
||||||
|
"type": "paul",
|
||||||
|
"batch_size": 64,
|
||||||
|
"window_size": args.window,
|
||||||
|
"domain_length": args.domain_length,
|
||||||
|
"flow_features": 3,
|
||||||
|
"vocab_size": len(char_dict) + 1,
|
||||||
|
#
|
||||||
|
'dropout': 0.5,
|
||||||
|
'domain_features': args.domain_embedding,
|
||||||
|
'embedding_size': args.embedding,
|
||||||
|
'filter_main': 128,
|
||||||
|
'flow_features': 3,
|
||||||
|
'dense_main': 512,
|
||||||
|
'filter_embedding': args.hidden_char_dims,
|
||||||
|
'hidden_embedding': args.domain_embedding,
|
||||||
|
'kernel_embedding': 3,
|
||||||
|
'kernels_main': 3,
|
||||||
|
'input_length': 40
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding, model = models.get_models_by_params(param)
|
||||||
embedding.summary()
|
embedding.summary()
|
||||||
|
|
||||||
model = network.get_model(dropout_main, flow_tr.shape[-1], args.domain_embedding,
|
|
||||||
args.window, args.domain_length, filter_main, kernel_main,
|
|
||||||
dense_main, embedding)
|
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
@ -196,41 +220,6 @@ def main_train():
|
|||||||
model.save(args.clf_model)
|
model.save(args.clf_model)
|
||||||
|
|
||||||
|
|
||||||
def main_train_h5():
|
|
||||||
# parameter
|
|
||||||
dropout_main = 0.5
|
|
||||||
dense_main = 512
|
|
||||||
kernel_main = 3
|
|
||||||
filter_main = 128
|
|
||||||
network = models.pauls_networks if args.model_type == "paul" else models.renes_networks
|
|
||||||
|
|
||||||
char_dict = dataset.get_character_dict()
|
|
||||||
data = h5py.File("data/full_dataset.h5", "r")
|
|
||||||
|
|
||||||
embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
|
|
||||||
args.hidden_char_dims, kernel_main, args.domain_embedding, 0.5)
|
|
||||||
embedding.summary()
|
|
||||||
|
|
||||||
model = network.get_model(dropout_main, data["flow"].shape[-1], args.domain_embedding,
|
|
||||||
args.window, args.domain_length, filter_main, kernel_main,
|
|
||||||
dense_main, embedding)
|
|
||||||
model.summary()
|
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss='categorical_crossentropy',
|
|
||||||
metrics=['accuracy'])
|
|
||||||
|
|
||||||
model.fit([data["domain"], data["flow"]],
|
|
||||||
[data["client"], data["server"]],
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
epochs=args.epochs,
|
|
||||||
shuffle=True,
|
|
||||||
validation_split=0.2)
|
|
||||||
|
|
||||||
embedding.save(args.embedding_model)
|
|
||||||
model.save(args.clf_model)
|
|
||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
user_flow_df = dataset.get_user_flow_data(args.test_data)
|
user_flow_df = dataset.get_user_flow_data(args.test_data)
|
||||||
|
Loading…
Reference in New Issue
Block a user