add parameter for hyper band iteration, use hyperband results in new runs
This commit is contained in:
parent
903e81c931
commit
9b8ca8abab
@ -52,6 +52,8 @@ parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
|
|||||||
parser.add_argument("--runs", action="store", dest="runs",
|
parser.add_argument("--runs", action="store", dest="runs",
|
||||||
default=20, type=int)
|
default=20, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--hyper_max_iter", action="store", dest="hyper_max_iter",
|
||||||
|
default=81, type=int)
|
||||||
|
|
||||||
# parser.add_argument("--samples", action="store", dest="samples",
|
# parser.add_argument("--samples", action="store", dest="samples",
|
||||||
# default=100000, type=int)
|
# default=100000, type=int)
|
||||||
|
16
main.py
16
main.py
@ -155,7 +155,7 @@ def main_paul_best():
|
|||||||
main_train(pauls_best_params)
|
main_train(pauls_best_params)
|
||||||
|
|
||||||
|
|
||||||
def main_hyperband(data, domain_length, window_size, model_type, result_file, dist_size="small"):
|
def main_hyperband(data, domain_length, window_size, model_type, result_file, max_iter, dist_size="small"):
|
||||||
param_dist = get_param_dist(dist_size)
|
param_dist = get_param_dist(dist_size)
|
||||||
|
|
||||||
logger.info("create training dataset")
|
logger.info("create training dataset")
|
||||||
@ -168,7 +168,7 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, di
|
|||||||
|
|
||||||
domain_tr, flow_tr, client_tr, server_tr = shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr)
|
domain_tr, flow_tr, client_tr, server_tr = shuffle_training_data(domain_tr, flow_tr, client_tr, server_tr)
|
||||||
|
|
||||||
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, 81, result_file)
|
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
|
||||||
|
|
||||||
|
|
||||||
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
|
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
|
||||||
@ -208,9 +208,14 @@ def main_train(param=None):
|
|||||||
|
|
||||||
# call hyperband if used
|
# call hyperband if used
|
||||||
if args.hyperband_results:
|
if args.hyperband_results:
|
||||||
logger.info("start hyperband parameter search")
|
try:
|
||||||
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, 81, args.hyperband_results)
|
hyper_results = joblib.load(args.hyperband_results)
|
||||||
|
except Exception:
|
||||||
|
logger.info("start hyperband parameter search")
|
||||||
|
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter,
|
||||||
|
args.hyperband_results)
|
||||||
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
||||||
|
param["type"] = args.model_type
|
||||||
logger.info(f"select params from result: {param}")
|
logger.info(f"select params from result: {param}")
|
||||||
if not param:
|
if not param:
|
||||||
param = PARAMS
|
param = PARAMS
|
||||||
@ -815,7 +820,8 @@ def main():
|
|||||||
if "retrain" == args.mode:
|
if "retrain" == args.mode:
|
||||||
main_retrain()
|
main_retrain()
|
||||||
if "hyperband" == args.mode:
|
if "hyperband" == args.mode:
|
||||||
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results)
|
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
|
||||||
|
arg.hyper_max_iter)
|
||||||
if "test" == args.mode:
|
if "test" == args.mode:
|
||||||
main_test()
|
main_test()
|
||||||
if "fancy" == args.mode:
|
if "fancy" == args.mode:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user