add parser argument for naming in multi-model modes, minor fixes, reset fixed values for network - need to make them flexible

René Knaebel 2017-09-05 17:40:57 +02:00
parent ed4f478bad
commit 2080444fb7
3 changed files with 77 additions and 43 deletions

main.py (56 lines changed)

@@ -66,13 +66,12 @@ PARAMS = {
'dropout': 0.5,
'domain_features': args.domain_embedding,
'embedding_size': args.embedding,
'filter_main': 64,
'flow_features': 3,
# 'dense_main': 512,
'dense_main': 64,
'filter_embedding': args.hidden_char_dims,
'hidden_embedding': args.domain_embedding,
'kernel_embedding': 3,
'filter_main': 128,
'dense_main': 128,
'kernels_main': 3,
'input_length': 40,
'model_output': args.model_output
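
The fixed values above (filter_main, dense_main, kernels_main, input_length, embedding_size) parameterize the main network; the commit message flags that they still need to be made flexible. As a rough orientation only, here is a minimal Keras sketch of how such values could wire into a small 1-D CNN; the layer names and topology are assumptions, not the repository's actual models module.

# Illustrative sketch only: the real network is built elsewhere in the
# repository and may differ. Shows where the PARAMS values would plug in.
from keras.layers import Conv1D, Dense, GlobalMaxPooling1D, Input
from keras.models import Model

def build_main_sketch(params):
    # sequence of embedded characters/flows: (input_length, embedding_size)
    inp = Input(shape=(params["input_length"], params["embedding_size"]))
    x = Conv1D(filters=params["filter_main"],
               kernel_size=params["kernels_main"],
               activation="relu")(inp)
    x = GlobalMaxPooling1D()(x)
    x = Dense(params["dense_main"], activation="relu")(x)
    out = Dense(1, activation="sigmoid")(x)  # binary label per sample
    return Model(inputs=inp, outputs=out)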
@@ -154,6 +153,35 @@ def main_train(param=None):
custom_class_weights = None
logger.info(f"select model: {args.model_type}")
if args.model_type == "staggered":
server_tr = np.expand_dims(server_windows_tr, 2)
model = new_model
logger.info("compile and train model")
embedding.summary()
model.summary()
logger.info(model.get_config())
model.outputs
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'] + custom_metrics)
if args.model_output == "both":
labels = [client_tr, server_tr]
else:
raise ValueError("unknown model output")
model.fit([domain_tr, flow_tr],
labels,
batch_size=args.batch_size,
epochs=args.epochs,
callbacks=callbacks,
shuffle=True,
validation_split=0.2,
class_weight=custom_class_weights)
else:
if args.model_type == "inter":
server_tr = np.expand_dims(server_windows_tr, 2)
model = new_model
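
For context on np.expand_dims(server_windows_tr, 2) in both branches above: the per-window server labels come out of the loader as a 2-D array of shape (samples, windows), while a Keras output that predicts one binary value per window expects (samples, windows, 1). A tiny, purely illustrative example with toy shapes, not the real loader:

import numpy as np

server_windows_tr = np.zeros((1000, 10))   # assumed shape: (samples, windows)
server_tr = np.expand_dims(server_windows_tr, 2)
print(server_tr.shape)                     # (1000, 10, 1)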
@@ -225,9 +253,9 @@ def main_visualization():
# client_val, server_val = client_val.value, server_val.value
client_val = client_val.value
# logger.info("plot model")
# model = load_model(model_args.clf_model, custom_objects=models.get_metrics())
# visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
logger.info("plot model")
model = load_model(args.clf_model, custom_objects=models.get_metrics())
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
try:
logger.info("plot training curve")
@@ -276,10 +304,10 @@ def main_visualization():
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
# "{}/server_cov.png".format(args.model_path),
# normalize=False, title="Server Confusion Matrix")
# logger.info("visualize embedding")
# domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
# domain_embedding = np.load(args.model_path + "/domain_embds.npy")
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
logger.info("visualize embedding")
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
domain_embedding = np.load(args.model_path + "/domain_embds.npy")
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
def main_visualize_all():
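
The reactivated embedding visualisation reads domain_embds.npy from the model directory and plots it alongside the labels from load_or_generate_domains. How that file gets written is not part of this diff; a plausible helper, stated purely as an assumption, would run the trained embedding sub-model over the encoded domains and dump the vectors:

import numpy as np

def save_domain_embeddings(embedding_model, domain_encs, model_path, batch_size=1024):
    # Assumption about how domain_embds.npy could be produced; the repository
    # may generate it differently.
    vecs = embedding_model.predict(domain_encs, batch_size=batch_size)
    np.save(model_path + "/domain_embds.npy", vecs)
    return vecs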
@@ -293,7 +321,7 @@ def main_visualize_all():
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save("all_client_prc.png")
visualize.plot_save(f"{args.output_prefix}_client_prc.png")
logger.info("plot roc curves")
visualize.plot_clf()
@@ -301,7 +329,7 @@ def main_visualize_all():
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save("all_client_roc.png")
visualize.plot_save(f"{args.output_prefix}_client_roc.png")
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
@@ -314,7 +342,7 @@ def main_visualize_all():
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save("all_user_client_prc.png")
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
logger.info("plot user roc curves")
visualize.plot_clf()
@@ -324,7 +352,7 @@ def main_visualize_all():
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
visualize.plot_legend()
visualize.plot_save("all_user_client_roc.png")
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
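
The user-level curves above collapse the per-window scores to a single score per user by grouping on the name column and taking the maximum. A toy illustration of that pattern (using .values, the modern equivalent of the as_matrix() call in the code):

import pandas as pd

df = pd.DataFrame({"names": ["u1", "u1", "u2"],
                   "client_val": [0.2, 0.9, 0.4]})
# one score per user: the maximum over that user's windows
user_scores = df.groupby(df.names).max().client_val.values.astype(float)
print(user_scores)  # [0.9 0.4]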
def main_data():

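The f-strings in main_visualize_all rely on args.output_prefix, the naming argument referenced in the commit message. Its registration with the parser is not visible in the hunks shown here; a hypothetical argparse addition could look like the following (flag spelling and default are assumptions):

import argparse

parser = argparse.ArgumentParser()
# Assumed definition of the naming argument consumed as args.output_prefix;
# the actual flag name and default may differ in the repository.
parser.add_argument("--output_prefix", dest="output_prefix", default="all",
                    help="prefix for plot files written in the multi-model visualisation mode")
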
run.sh (12 lines changed)

@@ -1,6 +1,10 @@
#!/usr/bin/env bash
RESDIR=$1
mkdir -p ${RESDIR}
DATADIR=$2
for output in client both
do
for depth in small medium
@@ -9,8 +13,8 @@ do
do
python main.py --mode train \
--train /tmp/rk/currentData.csv \
--model /tmp/rk/results/${output}_${depth}_${mtype} \
--train ${DATADIR}/currentData.csv \
--model ${RESDIR}/${output}_${depth}_${mtype} \
--epochs 50 \
--embd 64 \
--hidden_char_dims 128 \
@@ -28,8 +32,8 @@ done
for depth in small medium
do
python main.py --mode train \
--train /tmp/rk/currentData.csv \
--model /tmp/rk/results/both_${depth}_inter \
--train ${DATADIR}/currentData.csv \
--model ${RESDIR}/both_${depth}_inter \
--epochs 50 \
--embd 64 \
--hidden_char_dims 128 \
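
With RESDIR and DATADIR taken as positional parameters, the training sweep is no longer tied to hard-coded locations; invoking it as, for example, bash run.sh /tmp/rk/results /tmp/rk reproduces the previous paths.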

(third changed file, name not shown)

@@ -1,10 +1,12 @@
#!/usr/bin/env bash
RESDIR=$1
DATADIR=$2
for output in client both
do
python3 main.py --mode test --batch 1024 \
--models tm/rk/${output}_* \
--test data/futureData.csv \
--models ${RESDIR}/${output}_*/ \
--test ${DATADIR}/futureData.csv \
--model_output ${output}
done
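
The test script follows the same convention: the first positional argument points at the directory of trained ${output}_* models and the second at the directory containing futureData.csv, so evaluation can run against the same RESDIR/DATADIR pair used for training.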