refactor class weights
This commit is contained in:
parent
461d4cab8f
commit
d58dbcb101
36
Makefile
36
Makefile
@ -1,27 +1,27 @@
|
||||
run:
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client --epochs 2 \
|
||||
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||
--dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output client --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_final --epochs 2 \
|
||||
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||
--dense_embd 8 --domain_embd 8 --batch 64 --type final --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_inter --epochs 2 \
|
||||
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||
--dense_embd 8 --domain_embd 8 --batch 64 --type inter --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_soft --epochs 2 \
|
||||
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||
--dense_embd 8 --domain_embd 8 --batch 64 --type soft --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_long --epochs 2 \
|
||||
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||
--dense_embd 8 --domain_embd 8 --batch 64 --type long --model_output both --runs 1
|
||||
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 --depth flat1 \
|
||||
--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
|
||||
--dense_embd 16 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
|
||||
python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_staggered --epochs 2 \
|
||||
--filter_embd 8 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 16 \
|
||||
--dense_embd 8 --domain_embd 8 --batch 64 --type staggered --model_output both --runs 1
|
||||
|
||||
|
||||
test:
|
||||
|
51
main.py
51
main.py
@ -163,11 +163,9 @@ def main_hyperband(data, domain_length, window_size, model_type, result_file, ma
|
||||
return run_hyperband(dist_size, domain_tr, flow_tr, client_tr, server_tr, max_iter, result_file)
|
||||
|
||||
|
||||
def run_hyperband(dist_size, domain, flow, client, server, max_iter, savefile):
|
||||
def run_hyperband(dist_size, features, labels, max_iter, savefile):
|
||||
param_dist = get_param_dist(dist_size)
|
||||
hp = hyperband.Hyperband(param_dist,
|
||||
[domain, flow],
|
||||
[client, server],
|
||||
hp = hyperband.Hyperband(param_dist, features, labels,
|
||||
max_iter=max_iter,
|
||||
savefile=savefile)
|
||||
results = hp.run()
|
||||
@ -191,7 +189,27 @@ def load_data(data, domain_length, window_size, model_type, shuffled=False):
|
||||
return domain_tr, flow_tr, client_tr, server_tr
|
||||
|
||||
|
||||
def get_weighting(class_weights, sample_weights, client, server):
|
||||
def load_training_data(data, model_output, domain_length, window_size, model_type, shuffled=False):
|
||||
domain_tr, flow_tr, client_tr, server_tr = load_data(data, domain_length,
|
||||
window_size, model_type, shuffled)
|
||||
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||
if model_output == "both":
|
||||
labels = {"client": client_tr.value, "server": server_tr}
|
||||
loss_weights = {"client": 1.0, "server": 1.0}
|
||||
elif model_output == "client":
|
||||
labels = {"client": client_tr.value}
|
||||
loss_weights = {"client": 1.0}
|
||||
elif model_output == "server":
|
||||
labels = {"server": server_tr}
|
||||
loss_weights = {"server": 1.0}
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
return features, labels, loss_weights
|
||||
|
||||
|
||||
def get_weighting(class_weights, sample_weights, labels):
|
||||
return None, None
|
||||
client, server = labels["client"], labels["server"]
|
||||
if class_weights:
|
||||
logger.info("class weights: compute custom weights")
|
||||
custom_class_weights = get_custom_class_weights(client, server)
|
||||
@ -217,16 +235,16 @@ def main_train(param=None):
|
||||
logger.info(f"Use command line arguments: {args}")
|
||||
|
||||
# data preparation
|
||||
domain_tr, flow_tr, client_tr, server_tr = load_data(args.data, args.domain_length,
|
||||
features, labels, loss_weights = load_training_data(args.data, args.model_output, args.domain_length,
|
||||
args.window, args.model_type)
|
||||
|
||||
# call hyperband if used
|
||||
# call hyperband if results are not accessible
|
||||
if args.hyperband_results:
|
||||
try:
|
||||
hyper_results = joblib.load(args.hyperband_results)
|
||||
except Exception:
|
||||
logger.info("start hyperband parameter search")
|
||||
hyper_results = run_hyperband("small", domain_tr, flow_tr, client_tr, server_tr, args.hyper_max_iter,
|
||||
hyper_results = run_hyperband("small", features, labels, args.hyper_max_iter,
|
||||
args.hyperband_results)
|
||||
param = sorted(hyper_results, key=operator.itemgetter("loss"))[0]["params"]
|
||||
param["type"] = args.model_type
|
||||
@ -235,8 +253,8 @@ def main_train(param=None):
|
||||
param = PARAMS
|
||||
|
||||
# custom class or sample weights
|
||||
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights,
|
||||
client_tr.value, server_tr)
|
||||
# TODO: should throw an error when using weights with only the client labels
|
||||
custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, labels)
|
||||
|
||||
for i in range(args.runs):
|
||||
model_path = os.path.join(args.model_path, f"clf_{i}.h5")
|
||||
@ -259,19 +277,6 @@ def main_train(param=None):
|
||||
logger.info(f"Generator model with params: {param}")
|
||||
model = models.get_models_by_params(param)
|
||||
|
||||
features = {"ipt_domains": domain_tr.value, "ipt_flows": flow_tr.value}
|
||||
if args.model_output == "both":
|
||||
labels = {"client": client_tr.value, "server": server_tr}
|
||||
loss_weights = {"client": 1.0, "server": 1.0}
|
||||
elif args.model_output == "client":
|
||||
labels = {"client": client_tr.value}
|
||||
loss_weights = {"client": 1.0}
|
||||
elif args.model_output == "server":
|
||||
labels = {"server": server_tr}
|
||||
loss_weights = {"server": 1.0}
|
||||
else:
|
||||
raise ValueError("unknown model output")
|
||||
|
||||
logger.info(f"select model: {args.model_type}")
|
||||
if args.model_type == "staggered":
|
||||
logger.info("compile and pre-train server model")
|
||||
|
Loading…
Reference in New Issue
Block a user