add staggered model training for intermediate sever prediction; refactor model return values

This commit is contained in:
2017-09-07 14:24:55 +02:00
parent 2080444fb7
commit 5bd8e41711
6 changed files with 92 additions and 70 deletions

View File

@@ -1,16 +1,19 @@
run:
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 10 --depth small \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 2 --depth small \
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 --depth small \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 2 --depth small \
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 10 --depth medium \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 2 --depth medium \
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 10 --depth medium \
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 2 --depth medium \
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test5 --epochs 2 --depth small \
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered
test:
python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz