add parser argument for naming in multi model modes, minor fixes, re-set fix vals for network - need to make them flexible
This commit is contained in:
parent
ed4f478bad
commit
2080444fb7
56
main.py
56
main.py
@ -66,13 +66,12 @@ PARAMS = {
|
||||
'dropout': 0.5,
|
||||
'domain_features': args.domain_embedding,
|
||||
'embedding_size': args.embedding,
|
||||
'filter_main': 64,
|
||||
'flow_features': 3,
|
||||
# 'dense_main': 512,
|
||||
'dense_main': 64,
|
||||
'filter_embedding': args.hidden_char_dims,
|
||||
'hidden_embedding': args.domain_embedding,
|
||||
'kernel_embedding': 3,
|
||||
'filter_main': 128,
|
||||
'dense_main': 128,
|
||||
'kernels_main': 3,
|
||||
'input_length': 40,
|
||||
'model_output': args.model_output
|
||||
@ -154,6 +153,35 @@ def main_train(param=None):
|
||||
custom_class_weights = None
|
||||
|
||||
logger.info(f"select model: {args.model_type}")
|
||||
if args.model_type == "staggered":
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
model = new_model
|
||||
logger.info("compile and train model")
|
||||
embedding.summary()
|
||||
model.summary()
|
||||
logger.info(model.get_config())
|
||||
|
||||
model.outputs
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
loss='binary_crossentropy',
|
||||
metrics=['accuracy'] + custom_metrics)
|
||||
|
||||
if args.model_output == "both":
|
||||
labels = [client_tr, server_tr]
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
|
||||
model.fit([domain_tr, flow_tr],
|
||||
labels,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs,
|
||||
callbacks=callbacks,
|
||||
shuffle=True,
|
||||
validation_split=0.2,
|
||||
class_weight=custom_class_weights)
|
||||
|
||||
else:
|
||||
if args.model_type == "inter":
|
||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||
model = new_model
|
||||
@ -225,9 +253,9 @@ def main_visualization():
|
||||
# client_val, server_val = client_val.value, server_val.value
|
||||
client_val = client_val.value
|
||||
|
||||
# logger.info("plot model")
|
||||
# model = load_model(model_args.clf_model, custom_objects=models.get_metrics())
|
||||
# visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
||||
logger.info("plot model")
|
||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||
|
||||
try:
|
||||
logger.info("plot training curve")
|
||||
@ -276,10 +304,10 @@ def main_visualization():
|
||||
# visualize.plot_confusion_matrix(server_val.argmax(1), server_pred.argmax(1),
|
||||
# "{}/server_cov.png".format(args.model_path),
|
||||
# normalize=False, title="Server Confusion Matrix")
|
||||
# logger.info("visualize embedding")
|
||||
# domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
# domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
||||
# visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
logger.info("visualize embedding")
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = np.load(args.model_path + "/domain_embds.npy")
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
|
||||
|
||||
def main_visualize_all():
|
||||
@ -293,7 +321,7 @@ def main_visualize_all():
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_client_prc.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_client_prc.png")
|
||||
|
||||
logger.info("plot roc curves")
|
||||
visualize.plot_clf()
|
||||
@ -301,7 +329,7 @@ def main_visualize_all():
|
||||
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_client_roc.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_client_roc.png")
|
||||
|
||||
df_val = pd.DataFrame(data={"names": name_val, "client_val": client_val})
|
||||
user_vals = df_val.groupby(df_val.names).max().client_val.as_matrix().astype(float)
|
||||
@ -314,7 +342,7 @@ def main_visualize_all():
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
visualize.plot_precision_recall(user_vals, user_preds, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_user_client_prc.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
||||
|
||||
logger.info("plot user roc curves")
|
||||
visualize.plot_clf()
|
||||
@ -324,7 +352,7 @@ def main_visualize_all():
|
||||
user_preds = df_pred.groupby(df_pred.names).max().client_val.as_matrix().astype(float)
|
||||
visualize.plot_roc_curve(user_vals, user_preds, model_args["model_path"])
|
||||
visualize.plot_legend()
|
||||
visualize.plot_save("all_user_client_roc.png")
|
||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
||||
|
||||
|
||||
def main_data():
|
||||
|
12
run.sh
12
run.sh
@ -1,6 +1,10 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
|
||||
RESDIR=$1
|
||||
mkdir -p /tmp/rk/RESDIR
|
||||
DATADIR=$2
|
||||
|
||||
for output in client both
|
||||
do
|
||||
for depth in small medium
|
||||
@ -9,8 +13,8 @@ do
|
||||
do
|
||||
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/${output}_${depth}_${mtype} \
|
||||
--train ${DATADIR}/currentData.csv \
|
||||
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
@ -28,8 +32,8 @@ done
|
||||
for depth in small medium
|
||||
do
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/both_${depth}_inter \
|
||||
--train ${DATADIR}/currentData.csv \
|
||||
--model ${RESDIR}/both_${depth}_inter \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
|
6
test.sh
6
test.sh
@ -1,10 +1,12 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
RESDIR=$1
|
||||
DATADIR=$2
|
||||
|
||||
for output in client both
|
||||
do
|
||||
python3 main.py --mode test --batch 1024 \
|
||||
--models tm/rk/${output}_* \
|
||||
--test data/futureData.csv \
|
||||
--models ${RESDIR}/${output}_*/ \
|
||||
--test ${DATADIR}/futureData.csv \
|
||||
--model_output ${output}
|
||||
done
|
||||
|
Loading…
Reference in New Issue
Block a user