add sample weight metrics to fit function
This commit is contained in:
parent
e8473048cb
commit
33063f3081
19
main.py
19
main.py
@ -15,7 +15,7 @@ import models
|
|||||||
# create logger
|
# create logger
|
||||||
import visualize
|
import visualize
|
||||||
from arguments import get_model_args
|
from arguments import get_model_args
|
||||||
from utils import exists_or_make_path, get_custom_class_weights, load_model
|
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('logger')
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@ -166,6 +166,14 @@ def main_train(param=None):
|
|||||||
logger.info("class weights: set default")
|
logger.info("class weights: set default")
|
||||||
custom_class_weights = None
|
custom_class_weights = None
|
||||||
|
|
||||||
|
if args.sample_weights:
|
||||||
|
logger.info("class weights: compute custom weights")
|
||||||
|
custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr)
|
||||||
|
logger.info(custom_class_weights)
|
||||||
|
else:
|
||||||
|
logger.info("class weights: set default")
|
||||||
|
custom_sample_weights = None
|
||||||
|
|
||||||
if not param:
|
if not param:
|
||||||
param = PARAMS
|
param = PARAMS
|
||||||
logger.info(f"Generator model with params: {param}")
|
logger.info(f"Generator model with params: {param}")
|
||||||
@ -205,7 +213,8 @@ def main_train(param=None):
|
|||||||
model.fit(features, labels,
|
model.fit(features, labels,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
class_weight=custom_class_weights)
|
class_weight=custom_class_weights,
|
||||||
|
sample_weight=custom_sample_weights)
|
||||||
|
|
||||||
logger.info("fix server model")
|
logger.info("fix server model")
|
||||||
model.get_layer("domain_cnn").trainable = False
|
model.get_layer("domain_cnn").trainable = False
|
||||||
@ -227,7 +236,8 @@ def main_train(param=None):
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
class_weight=custom_class_weights)
|
class_weight=custom_class_weights,
|
||||||
|
sample_weight=custom_sample_weights)
|
||||||
|
|
||||||
|
|
||||||
def main_retrain():
|
def main_retrain():
|
||||||
@ -406,7 +416,7 @@ def main_visualization():
|
|||||||
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
def plot_embedding(model_path, domain_embedding, data, domain_length):
|
||||||
logger.info("visualize embedding")
|
logger.info("visualize embedding")
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
|
||||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.pdf".format(model_path), method="svd")
|
visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")
|
||||||
|
|
||||||
|
|
||||||
def main_visualize_all():
|
def main_visualize_all():
|
||||||
@ -641,7 +651,6 @@ def train_server_only():
|
|||||||
model.fit(features, labels,
|
model.fit(features, labels,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
validation_split=0.2,
|
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user