add hyperband savefile config, minor change of parameter name

René Knaebel 2017-10-03 18:58:54 +02:00
parent 68254d6629
commit 371a1dad05
4 changed files with 13 additions and 4 deletions

View File

@@ -15,6 +15,9 @@ parser.add_argument("--data", action="store", dest="train_data",
 parser.add_argument("--test", action="store", dest="test_data",
                     default="data/full_future_dataset.csv.tar.gz")
+parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
+                    default="")
 parser.add_argument("--model", action="store", dest="model_path",
                     default="results/model_x")

View File

@@ -7,6 +7,7 @@ from math import ceil, log
 from random import random as rng
 from time import ctime, time
+import joblib
 import numpy as np
 from keras.callbacks import EarlyStopping
@@ -24,7 +25,7 @@ def sample_params(param_distribution: dict):
 class Hyperband:
-    def __init__(self, param_distribution, X, y, max_iter=81):
+    def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
         self.get_params = lambda: sample_params(param_distribution)
         self.max_iter = max_iter  # maximum iterations per configuration
@@ -39,6 +40,8 @@ class Hyperband:
         self.best_loss = np.inf
         self.best_counter = -1
+        self.savefile = savefile
         self.X = X
         self.y = y
@@ -143,4 +146,7 @@ class Hyperband:
             random_configs = [random_configs[i] for i in indices if not early_stops[i]]
             random_configs = random_configs[0:int(n_configs / self.eta)]
+        if self.savefile:
+            joblib.dump(self.results, self.savefile)
         return self.results
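A hedged usage sketch of the new savefile argument follows; the search space, the training data, and the structure of the recorded results are assumed from context and are not defined in this diff.

import joblib

# Hypothetical inputs; param_distribution, X_train and y_train are placeholders.
hp = Hyperband(param_distribution, X_train, y_train, max_iter=81,
               savefile="results/hyperband.joblib")
results = hp.run()  # run() now also dumps self.results when savefile is set

# In a later session the persisted results can be reloaded without re-running
# the search; the "loss" key is an assumption about how entries are recorded.
saved = joblib.load("results/hyperband.joblib")
best = min(saved, key=lambda r: r["loss"])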

View File

@@ -63,7 +63,7 @@ PARAMS = {
     #
     'dropout': 0.5,  # currently fix
     'domain_features': args.domain_embedding,
-    'embedding_size': args.embedding,
+    'embedding': args.embedding,
     'flow_features': 3,
     'filter_embedding': args.filter_embedding,
     'dense_embedding': args.dense_embedding,
@@ -132,7 +132,7 @@ def main_hyperband():
                   [domain_tr, flow_tr],
                   [client_tr, server_tr])
     results = hp.run()
-    joblib.dump(results, "hyperband.joblib")
+    joblib.dump(results, args.hyperband_results)


 def main_train(param=None):
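Since --hyper_result defaults to an empty string, joblib.dump would receive an empty filename when the flag is omitted. The sketch below adds a guard purely for illustration, it is not part of the commit; hp, args and the training arrays are assumed to come from the surrounding main_hyperband().

import joblib

results = hp.run()
if args.hyperband_results:  # skip persistence when no path was supplied
    joblib.dump(results, args.hyperband_results)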

View File

@@ -9,7 +9,7 @@ def get_models_by_params(params: dict):
     # mainly embedding model
     network_type = params.get("type")
     network_depth = params.get("depth")
-    embedding_size = params.get("embedding_size")
+    embedding_size = params.get("embedding")
     input_length = params.get("input_length")
     filter_embedding = params.get("filter_embedding")
     kernel_embedding = params.get("kernel_embedding")
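Because params.get() silently returns None for a missing key, the rename has to be applied on both sides: the PARAMS dict in the entry script and the lookup in get_models_by_params. A small consistency sketch follows; the concrete values are placeholders, only the key names mirror the ones in the diff.

# Illustrative parameter dict; values are made up for the sketch.
params = {
    "type": "final",
    "depth": "small",
    "embedding": 128,          # renamed from "embedding_size" in this commit
    "input_length": 40,
    "filter_embedding": 128,
    "kernel_embedding": 3,
}

embedding_size = params.get("embedding")
assert embedding_size is not None, "a stale 'embedding_size' key would yield None"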