From 71f218888d9ebfab432221019baa9d482d0122ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Thu, 7 Sep 2017 15:31:04 +0200 Subject: [PATCH] set server to be not trainable too; refactor visualization script --- fancy.sh | 23 +++++++++++++++++++++++ main.py | 1 + 2 files changed, 24 insertions(+) create mode 100644 fancy.sh diff --git a/fancy.sh b/fancy.sh new file mode 100644 index 0000000..d604926 --- /dev/null +++ b/fancy.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +RESDIR=$1 +DATADIR=$2 + +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_final --test ${DATADIR}/futureData.csv --model_output both +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_inter --test ${DATADIR}/futureData.csv --model_output both +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_small_final --test ${DATADIR}/futureData.csv --model_output client +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_small_inter --test ${DATADIR}/futureData.csv --model_output client + +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --test ${DATADIR}/futureData.csv --model_output both +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --test ${DATADIR}/futureData.csv --model_output both +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --test ${DATADIR}/futureData.csv --model_output client +python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --test ${DATADIR}/futureData.csv --model_output client + +python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \ + --models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small + +python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \ + --models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium + +python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \ + --models ${RESDIR}/*/ --out-prefix ${RESDIR}/all \ No newline at end of file diff --git a/main.py b/main.py index 43247c5..7bbe486 100644 --- a/main.py +++ b/main.py @@ -180,6 +180,7 @@ def main_train(param=None): class_weight=custom_class_weights) model.get_layer("dense_server").trainable = False + model.get_layer("server").trainable = False model.compile(optimizer='adam', loss='binary_crossentropy', loss_weights={"client": 1.0, "server": 0.0},