add parameter for hyper band iteration, use hyperband results in new runs
This commit is contained in:
parent
903e81c931
commit
9b8ca8abab
@ -52,6 +52,8 @@ parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
||||
parser.add_argument("--runs", action="store", dest="runs",
|
||||
default=20, type=int)
|
||||
|
||||
parser.add_argument("--hyper_max_iter", action="store", dest="hyper_max_iter",
|
||||
default=81, type=int)
|
||||
|
||||
# parser.add_argument("--samples", action="store", dest="samples",
|
||||
# default=100000, type=int)
|
||||
|
16
main.py
16
main.py
@ -155,7 +155,7 @@ def main_paul_best():
|
||||
main_train(pauls_best_params)
|
||||
|
||||
|
||||
def main_hyperband(data, domain_length, window_size, model_type, result_file, dist_size="small"):
|
||||
def main_hyperband(data, domain_length, window_size, model_type, result_file, max_iter, dist_size="small"):
|
||||
param_dist = get_param_dist(dist_size)
|
||||
|
||||
logger.info("create training dataset")
|
||||
@ -168,7 +168,7 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di
|
||||
|
||||
domain_tr, flow_tr, client_tr, server_tr = shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr)
|
||||
|
||||
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, 81, result_file)
|
||||
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
|
||||
|
||||
|
||||
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
|
||||
@ -208,9 +208,14 @@ def main_train(param=None):
|
||||
|
||||
# call hyperband if used
|
||||
if args.hyperband_results:
|
||||
logger.info("start hyperband parameter search")
|
||||
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
|
||||
try:
|
||||
hyper_results = joblib.load(args.hyperband_results)
|
||||
except Exception:
|
||||
logger.info("start hyperband parameter search")
|
||||
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter,
|
||||
args.hyperband_results)
|
||||
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
||||
param["type"] = args.model_type
|
||||
logger.info(f"select params from result: {param}")
|
||||
if not param:
|
||||
param = PARAMS
|
||||
@ -815,7 +820,8 @@ def main():
|
||||
if "retrain" == args.mode:
|
||||
main_retrain()
|
||||
if "hyperband" == args.mode:
|
||||
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results)
|
||||
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
|
||||
arg.hyper_max_iter)
|
||||
if "test" == args.mode:
|
||||
main_test()
|
||||
if "fancy" == args.mode:
|
||||
|
Loading…
Reference in New Issue
Block a user