move utils functions to new file
This commit is contained in:
parent
820a5d1a4d
commit
b0da2de0ea
@ -29,7 +29,7 @@ def encode_char(c):
|
|||||||
encode_char = np.vectorize(encode_char)
|
encode_char = np.vectorize(encode_char)
|
||||||
|
|
||||||
|
|
||||||
# TODO: refactor
|
# TODO: ask for correct refactoring
|
||||||
def get_user_chunks(user_flow, window=10):
|
def get_user_chunks(user_flow, window=10):
|
||||||
# TODO: what is maxLengthInSeconds for?!?
|
# TODO: what is maxLengthInSeconds for?!?
|
||||||
# maxMilliSeconds = maxLengthInSeconds * 1000
|
# maxMilliSeconds = maxLengthInSeconds * 1000
|
||||||
@ -149,7 +149,6 @@ def create_dataset_from_lists(chunks, vocab, max_len):
|
|||||||
:param max_len:
|
:param max_len:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_domain_features_reduced(d):
|
def get_domain_features_reduced(d):
|
||||||
return get_domain_features(d[0], vocab, max_len)
|
return get_domain_features(d[0], vocab, max_len)
|
||||||
|
|
||||||
|
18
main.py
18
main.py
@ -17,6 +17,7 @@ import models
|
|||||||
# create logger
|
# create logger
|
||||||
import visualize
|
import visualize
|
||||||
from dataset import load_or_generate_h5data
|
from dataset import load_or_generate_h5data
|
||||||
|
from utils import exists_or_make_path
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('logger')
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@ -55,11 +56,6 @@ if args.gpu:
|
|||||||
session = tf.Session(config=config)
|
session = tf.Session(config=config)
|
||||||
|
|
||||||
|
|
||||||
def exists_or_make_path(p):
|
|
||||||
if not os.path.exists(p):
|
|
||||||
os.makedirs(p)
|
|
||||||
|
|
||||||
|
|
||||||
def main_paul_best():
|
def main_paul_best():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
pauls_best_params = models.pauls_networks.best_config
|
pauls_best_params = models.pauls_networks.best_config
|
||||||
@ -80,15 +76,15 @@ def main_hyperband():
|
|||||||
"flow_features": [3],
|
"flow_features": [3],
|
||||||
"input_length": [40],
|
"input_length": [40],
|
||||||
# model params
|
# model params
|
||||||
"embedding_size": [16, 32, 64, 128, 256, 512],
|
"embedding_size": [8, 16, 32, 64, 128, 256],
|
||||||
"filter_embedding": [16, 32, 64, 128, 256, 512],
|
"filter_embedding": [8, 16, 32, 64, 128, 256],
|
||||||
"kernel_embedding": [1, 3, 5, 7, 9],
|
"kernel_embedding": [1, 3, 5, 7, 9],
|
||||||
"hidden_embedding": [16, 32, 64, 128, 256, 512],
|
"hidden_embedding": [8, 16, 32, 64, 128, 256],
|
||||||
"dropout": [0.5],
|
"dropout": [0.5],
|
||||||
"domain_features": [16, 32, 64, 128, 256, 512],
|
"domain_features": [8, 16, 32, 64, 128, 256],
|
||||||
"filter_main": [16, 32, 64, 128, 256, 512],
|
"filter_main": [8, 16, 32, 64, 128, 256],
|
||||||
"kernels_main": [1, 3, 5, 7, 9],
|
"kernels_main": [1, 3, 5, 7, 9],
|
||||||
"dense_main": [16, 32, 64, 128, 256, 512],
|
"dense_main": [8, 16, 32, 64, 128, 256],
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
|
Loading…
Reference in New Issue
Block a user