From b0da2de0eab5e14e887464886518c2fe284aa79a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Sat, 29 Jul 2017 19:47:02 +0200 Subject: [PATCH] move utils functions to new file --- dataset.py | 3 +-- main.py | 18 +++++++----------- utils.py | 6 ++++++ 3 files changed, 14 insertions(+), 13 deletions(-) create mode 100644 utils.py diff --git a/dataset.py b/dataset.py index 2b28aab..ec6a0d4 100644 --- a/dataset.py +++ b/dataset.py @@ -29,7 +29,7 @@ def encode_char(c): encode_char = np.vectorize(encode_char) -# TODO: refactor +# TODO: ask for correct refactoring def get_user_chunks(user_flow, window=10): # TODO: what is maxLengthInSeconds for?!? # maxMilliSeconds = maxLengthInSeconds * 1000 @@ -149,7 +149,6 @@ def create_dataset_from_lists(chunks, vocab, max_len): :param max_len: :return: """ - def get_domain_features_reduced(d): return get_domain_features(d[0], vocab, max_len) diff --git a/main.py b/main.py index bb251e4..278b1ff 100644 --- a/main.py +++ b/main.py @@ -17,6 +17,7 @@ import models # create logger import visualize from dataset import load_or_generate_h5data +from utils import exists_or_make_path logger = logging.getLogger('logger') logger.setLevel(logging.DEBUG) @@ -55,11 +56,6 @@ if args.gpu: session = tf.Session(config=config) -def exists_or_make_path(p): - if not os.path.exists(p): - os.makedirs(p) - - def main_paul_best(): char_dict = dataset.get_character_dict() pauls_best_params = models.pauls_networks.best_config @@ -80,15 +76,15 @@ def main_hyperband(): "flow_features": [3], "input_length": [40], # model params - "embedding_size": [16, 32, 64, 128, 256, 512], - "filter_embedding": [16, 32, 64, 128, 256, 512], + "embedding_size": [8, 16, 32, 64, 128, 256], + "filter_embedding": [8, 16, 32, 64, 128, 256], "kernel_embedding": [1, 3, 5, 7, 9], - "hidden_embedding": [16, 32, 64, 128, 256, 512], + "hidden_embedding": [8, 16, 32, 64, 128, 256], "dropout": [0.5], - "domain_features": [16, 32, 64, 128, 256, 512], - "filter_main": [16, 32, 64, 128, 256, 512], + "domain_features": [8, 16, 32, 64, 128, 256], + "filter_main": [8, 16, 32, 64, 128, 256], "kernels_main": [1, 3, 5, 7, 9], - "dense_main": [16, 32, 64, 128, 256, 512], + "dense_main": [8, 16, 32, 64, 128, 256], } logger.info("create training dataset") diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..fb9b783 --- /dev/null +++ b/utils.py @@ -0,0 +1,6 @@ +import os + + +def exists_or_make_path(p): + if not os.path.exists(p): + os.makedirs(p)