add output for main_test
This commit is contained in:
parent
21b9d7be73
commit
4a9f94a029
@ -3,6 +3,7 @@ import string
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from keras.utils import np_utils
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
chars = dict((char, idx + 1) for (idx, char) in
|
chars = dict((char, idx + 1) for (idx, char) in
|
||||||
@ -137,6 +138,9 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10,
|
|||||||
client_tr[:pos_idx.shape[-1]] = 1.0
|
client_tr[:pos_idx.shape[-1]] = 1.0
|
||||||
server_tr = server_tr[idx]
|
server_tr = server_tr[idx]
|
||||||
|
|
||||||
|
client_tr = np_utils.to_categorical(client_tr, 2)
|
||||||
|
server_tr = np_utils.to_categorical(server_tr, 2)
|
||||||
|
|
||||||
return domain_tr, flow_tr, client_tr, server_tr
|
return domain_tr, flow_tr, client_tr, server_tr
|
||||||
|
|
||||||
|
|
||||||
|
41
main.py
41
main.py
@ -73,6 +73,9 @@ parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.embedding_model = args.models + "_embd.h5"
|
||||||
|
args.clf_model = args.models + "_clf.h5"
|
||||||
|
|
||||||
|
|
||||||
# config = tf.ConfigProto(log_device_placement=True)
|
# config = tf.ConfigProto(log_device_placement=True)
|
||||||
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
|
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
|
||||||
@ -109,8 +112,8 @@ def main_paul_best():
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.2)
|
validation_split=0.2)
|
||||||
|
|
||||||
embedding.save(args.models + "_embd.h5")
|
embedding.save(args.embedding_model)
|
||||||
model.save(args.models + "_clf.h5")
|
model.save(args.clf_model)
|
||||||
|
|
||||||
|
|
||||||
def main_hyperband():
|
def main_hyperband():
|
||||||
@ -145,8 +148,6 @@ def main_hyperband():
|
|||||||
user_flow_df, char_dict,
|
user_flow_df, char_dict,
|
||||||
max_len=args.domain_length,
|
max_len=args.domain_length,
|
||||||
window_size=args.window)
|
window_size=args.window)
|
||||||
client_tr = np_utils.to_categorical(client_tr, 2)
|
|
||||||
server_tr = np_utils.to_categorical(server_tr, 2)
|
|
||||||
|
|
||||||
hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr])
|
hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr])
|
||||||
hp.run()
|
hp.run()
|
||||||
@ -154,10 +155,10 @@ def main_hyperband():
|
|||||||
|
|
||||||
def main_train():
|
def main_train():
|
||||||
# parameter
|
# parameter
|
||||||
cnnDropout = 0.5
|
dropout_main = 0.5
|
||||||
cnnHiddenDims = 512
|
dense_main = 512
|
||||||
kernel_size = 3
|
kernel_main = 3
|
||||||
filters = 128
|
filter_main = 128
|
||||||
network = models.pauls_networks if args.model_type == "paul" else models.renes_networks
|
network = models.pauls_networks if args.model_type == "paul" else models.renes_networks
|
||||||
|
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
@ -167,16 +168,14 @@ def main_train():
|
|||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
|
||||||
user_flow_df, char_dict,
|
user_flow_df, char_dict,
|
||||||
max_len=args.domain_length, window_size=args.window)
|
max_len=args.domain_length, window_size=args.window)
|
||||||
client_tr = np_utils.to_categorical(client_tr, 2)
|
|
||||||
server_tr = np_utils.to_categorical(server_tr, 2)
|
|
||||||
|
|
||||||
embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
|
embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
|
||||||
args.hidden_char_dims, kernel_size, args.domain_embedding, 0.5)
|
args.hidden_char_dims, kernel_main, args.domain_embedding, 0.5)
|
||||||
embedding.summary()
|
embedding.summary()
|
||||||
|
|
||||||
model = network.get_model(cnnDropout, flow_tr.shape[-1], args.domain_embedding,
|
model = network.get_model(dropout_main, flow_tr.shape[-1], args.domain_embedding,
|
||||||
args.window, args.domain_length, filters, kernel_size,
|
args.window, args.domain_length, filter_main, kernel_main,
|
||||||
cnnHiddenDims, embedding)
|
dense_main, embedding)
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
@ -190,8 +189,11 @@ def main_train():
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.2)
|
validation_split=0.2)
|
||||||
|
|
||||||
embedding.save(args.models + "_embd.h5")
|
embedding.save(args.embedding_model)
|
||||||
model.save(args.models + "_clf.h5")
|
model.save(args.clf_model)
|
||||||
|
|
||||||
|
|
||||||
|
from keras.models import load_model
|
||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
@ -200,7 +202,12 @@ def main_test():
|
|||||||
domain_val, flow_val, client_val, server_val = dataset.create_dataset_from_flows(
|
domain_val, flow_val, client_val, server_val = dataset.create_dataset_from_flows(
|
||||||
user_flow_df, char_dict,
|
user_flow_df, char_dict,
|
||||||
max_len=args.domain_length, window_size=args.window)
|
max_len=args.domain_length, window_size=args.window)
|
||||||
# TODO: get model and exec model.evaluate(...)
|
# embedding = load_model(args.embedding_model)
|
||||||
|
clf = load_model(args.clf_model)
|
||||||
|
|
||||||
|
print(clf.evaluate([domain_val, flow_val],
|
||||||
|
[client_val, server_val],
|
||||||
|
batch_size=args.batch_size))
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
|
Loading…
Reference in New Issue
Block a user