add pauls config test (TMP)

This commit is contained in:
René Knaebel 2017-07-08 11:53:03 +02:00
parent be56112b33
commit 36cdba3fdf
2 changed files with 51 additions and 8 deletions

48
main.py
View File

@ -19,8 +19,8 @@ parser.add_argument("--test", action="store", dest="test_data",
# parser.add_argument("--h5data", action="store", dest="h5data", # parser.add_argument("--h5data", action="store", dest="h5data",
# default="") # default="")
# #
parser.add_argument("--model", action="store", dest="model", parser.add_argument("--models", action="store", dest="models",
default="model_x") default="models/model_x")
# parser.add_argument("--pred", action="store", dest="pred", # parser.add_argument("--pred", action="store", dest="pred",
# default="") # default="")
@ -80,6 +80,39 @@ args = parser.parse_args()
# session = tf.Session(config=config) # session = tf.Session(config=config)
def main_paul_best():
char_dict = dataset.get_character_dict()
user_flow_df = dataset.get_user_flow_data(args.train_data)
param = models.pauls_networks.best_config
param["vocab_size"] = len(char_dict) + 1
print(param)
print("create training dataset")
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(
user_flow_df, char_dict,
max_len=args.domain_length,
window_size=args.window)
client_tr = np_utils.to_categorical(client_tr, 2)
server_tr = np_utils.to_categorical(server_tr, 2)
embedding, model = models.get_models_by_params(param)
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit([domain_tr, flow_tr],
[client_tr, server_tr],
batch_size=args.batch_size,
epochs=args.epochs,
shuffle=True,
validation_split=0.2)
embedding.save(args.models + "_embd.h5")
model.save(args.models + "_clf.h5")
def main_hyperband(): def main_hyperband():
char_dict = dataset.get_character_dict() char_dict = dataset.get_character_dict()
user_flow_df = dataset.get_user_flow_data(args.train_data) user_flow_df = dataset.get_user_flow_data(args.train_data)
@ -137,13 +170,13 @@ def main_train():
client_tr = np_utils.to_categorical(client_tr, 2) client_tr = np_utils.to_categorical(client_tr, 2)
server_tr = np_utils.to_categorical(server_tr, 2) server_tr = np_utils.to_categorical(server_tr, 2)
shared_cnn = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length, embedding = network.get_embedding(len(char_dict) + 1, args.embedding, args.domain_length,
args.hidden_char_dims, kernel_size, args.domain_embedding, 0.5) args.hidden_char_dims, kernel_size, args.domain_embedding, 0.5)
shared_cnn.summary() embedding.summary()
model = network.get_model(cnnDropout, flow_tr.shape[-1], args.domain_embedding, model = network.get_model(cnnDropout, flow_tr.shape[-1], args.domain_embedding,
args.window, args.domain_length, filters, kernel_size, args.window, args.domain_length, filters, kernel_size,
cnnHiddenDims, shared_cnn) cnnHiddenDims, embedding)
model.summary() model.summary()
model.compile(optimizer='adam', model.compile(optimizer='adam',
@ -157,7 +190,8 @@ def main_train():
shuffle=True, shuffle=True,
validation_split=0.2) validation_split=0.2)
model.save(args.model) embedding.save(args.models + "_embd.h5")
model.save(args.models + "_clf.h5")
def main_test(): def main_test():
@ -206,6 +240,8 @@ def main():
main_visualization() main_visualization()
if "score" in args.modes: if "score" in args.modes:
main_score() main_score()
if "paul" in args.modes:
main_paul_best()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -3,12 +3,19 @@ from keras.engine import Input, Model
from keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout, Activation, TimeDistributed from keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout, Activation, TimeDistributed
best_config = { best_config = {
"type": "paul",
"batch_size": 64,
"window_size": 10,
"domain_length": 40,
"flow_features": 3,
#
'dropout': 0.5,
'domain_features': 32, 'domain_features': 32,
'drop_out': 0.5, 'drop_out': 0.5,
'embedding_size': 64, 'embedding_size': 64,
'filter_main': 512, 'filter_main': 512,
'flow_features': 3, 'flow_features': 3,
'hidden_dims': 32, 'dense_main': 32,
'filter_embedding': 32, 'filter_embedding': 32,
'hidden_embedding': 32, 'hidden_embedding': 32,
'kernel_embedding': 8, 'kernel_embedding': 8,