add hyperband savefile config, minor change of parameter name
This commit is contained in:
parent
68254d6629
commit
371a1dad05
@ -15,6 +15,9 @@ parser.add_argument("--data", action="store", dest="train_data",
|
|||||||
parser.add_argument("--test", action="store", dest="test_data",
|
parser.add_argument("--test", action="store", dest="test_data",
|
||||||
default="data/full_future_dataset.csv.tar.gz")
|
default="data/full_future_dataset.csv.tar.gz")
|
||||||
|
|
||||||
|
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
|
||||||
|
default="")
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--model", action="store", dest="model_path",
|
parser.add_argument("--model", action="store", dest="model_path",
|
||||||
default="results/model_x")
|
default="results/model_x")
|
||||||
|
@ -7,6 +7,7 @@ from math import ceil, log
|
|||||||
from random import random as rng
|
from random import random as rng
|
||||||
from time import ctime, time
|
from time import ctime, time
|
||||||
|
|
||||||
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from keras.callbacks import EarlyStopping
|
from keras.callbacks import EarlyStopping
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ def sample_params(param_distribution: dict):
|
|||||||
|
|
||||||
|
|
||||||
class Hyperband:
|
class Hyperband:
|
||||||
def __init__(self, param_distribution, X, y, max_iter=81):
|
def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
|
||||||
self.get_params = lambda: sample_params(param_distribution)
|
self.get_params = lambda: sample_params(param_distribution)
|
||||||
|
|
||||||
self.max_iter = max_iter # maximum iterations per configuration
|
self.max_iter = max_iter # maximum iterations per configuration
|
||||||
@ -39,6 +40,8 @@ class Hyperband:
|
|||||||
self.best_loss = np.inf
|
self.best_loss = np.inf
|
||||||
self.best_counter = -1
|
self.best_counter = -1
|
||||||
|
|
||||||
|
self.savefile = savefile
|
||||||
|
|
||||||
self.X = X
|
self.X = X
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
@ -143,4 +146,7 @@ class Hyperband:
|
|||||||
random_configs = [random_configs[i] for i in indices if not early_stops[i]]
|
random_configs = [random_configs[i] for i in indices if not early_stops[i]]
|
||||||
random_configs = random_configs[0:int(n_configs / self.eta)]
|
random_configs = random_configs[0:int(n_configs / self.eta)]
|
||||||
|
|
||||||
|
if self.savefile:
|
||||||
|
joblib.dump(self.results, self.savefile)
|
||||||
|
|
||||||
return self.results
|
return self.results
|
||||||
|
4
main.py
4
main.py
@ -63,7 +63,7 @@ PARAMS = {
|
|||||||
#
|
#
|
||||||
'dropout': 0.5, # currently fix
|
'dropout': 0.5, # currently fix
|
||||||
'domain_features': args.domain_embedding,
|
'domain_features': args.domain_embedding,
|
||||||
'embedding_size': args.embedding,
|
'embedding': args.embedding,
|
||||||
'flow_features': 3,
|
'flow_features': 3,
|
||||||
'filter_embedding': args.filter_embedding,
|
'filter_embedding': args.filter_embedding,
|
||||||
'dense_embedding': args.dense_embedding,
|
'dense_embedding': args.dense_embedding,
|
||||||
@ -132,7 +132,7 @@ def main_hyperband():
|
|||||||
[domain_tr, flow_tr],
|
[domain_tr, flow_tr],
|
||||||
[client_tr, server_tr])
|
[client_tr, server_tr])
|
||||||
results = hp.run()
|
results = hp.run()
|
||||||
joblib.dump(results, "hyperband.joblib")
|
joblib.dump(results, args.hyperband_results)
|
||||||
|
|
||||||
|
|
||||||
def main_train(param=None):
|
def main_train(param=None):
|
||||||
|
@ -9,7 +9,7 @@ def get_models_by_params(params: dict):
|
|||||||
# mainly embedding model
|
# mainly embedding model
|
||||||
network_type = params.get("type")
|
network_type = params.get("type")
|
||||||
network_depth = params.get("depth")
|
network_depth = params.get("depth")
|
||||||
embedding_size = params.get("embedding_size")
|
embedding_size = params.get("embedding")
|
||||||
input_length = params.get("input_length")
|
input_length = params.get("input_length")
|
||||||
filter_embedding = params.get("filter_embedding")
|
filter_embedding = params.get("filter_embedding")
|
||||||
kernel_embedding = params.get("kernel_embedding")
|
kernel_embedding = params.get("kernel_embedding")
|
||||||
|
Loading…
Reference in New Issue
Block a user