change hyperband to score runs by the minimal val_loss over all epochs

René Knaebel 2017-10-05 12:55:46 +02:00
parent 371a1dad05
commit b24fa770f9
5 changed files with 121 additions and 110 deletions

Makefile

@@ -1,65 +1,65 @@
 run:
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_both_1 --epochs 2 --depth flat1 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_both_2 --epochs 2 --depth flat1 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_both_3 --epochs 2 --depth deep1 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output both
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_both_4 --epochs 2 --depth deep1 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output both
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_both_5 --epochs 2 --depth flat2 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type staggered --model_output both
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client_1 --epochs 2 --depth flat2 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type final --model_output client
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client_2 --epochs 2 --depth flat2 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --type inter --model_output client
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client_3 --epochs 2 --depth deep1 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --type final --model_output client
-	python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \
+	python3 main.py --mode train --data data/rk_mini.csv.gz --model results/test/test_client_4 --epochs 2 --depth deep1 \
 	--filter_embd 32 --kernel_embd 3 --filter_main 16 --kernel_main 3 --dense_main 32 \
 	--dense_embd 16 --domain_embd 8 --batch 64 --balanced_weights --type inter --model_output client
 test:
-	python3 main.py --mode test --batch 128 --models results/test/test_both_* --test data/rk_mini.csv.gz --model_output both
+	python3 main.py --mode test --batch 128 --models results/test/test_both_* --data data/rk_mini.csv.gz --model_output both
-	python3 main.py --mode test --batch 128 --models results/test/test_client_* --test data/rk_mini.csv.gz --model_output client
+	python3 main.py --mode test --batch 128 --models results/test/test_client_* --data data/rk_mini.csv.gz --model_output client
 fancy:
-	python3 main.py --mode fancy --batch 128 --model results/test/test_both_1 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_1 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_both_2 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_2 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_both_3 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_3 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_both_4 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_4 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_both_5 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_both_5 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_client_1 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_1 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_client_2 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_2 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_client_3 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_3 --data data/rk_mini.csv.gz
-	python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --test data/rk_mini.csv.gz
+	python3 main.py --mode fancy --batch 128 --model results/test/test_client_4 --data data/rk_mini.csv.gz
 all-fancy:
-	python3 main.py --mode all_fancy --batch 128 --models results/test/test* --test data/rk_mini.csv.gz \
+	python3 main.py --mode all_fancy --batch 128 --models results/test/test* --data data/rk_mini.csv.gz \
 	--out-prefix results/test/
 hyper:

hyperband.py

@@ -71,7 +71,7 @@ class Hyperband:
                             shuffle=True,
                             validation_split=0.4)
-        return {"loss": history.history['val_loss'][-1],
+        return {"loss": np.min(history.history['val_loss']),
                 "early_stop": len(history.history["loss"]) < n_iterations}

     # can be called multiple times
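
Note on the scoring change above: with early stopping, the last entry of history.history['val_loss'] is whatever epoch training happened to stop on, not the best loss the configuration reached. Taking np.min credits each configuration with its best epoch. A minimal sketch with hypothetical loss values:

    import numpy as np

    # hypothetical per-epoch validation losses for one hyperband run
    val_loss = [0.90, 0.55, 0.48, 0.52, 0.61]  # starts overfitting after epoch 3

    val_loss[-1]      # 0.61 -- old score, penalizes the config for its last epoch
    np.min(val_loss)  # 0.48 -- new score, the minimal val_loss over all epochs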

main.py

@@ -5,7 +5,8 @@ import numpy as np
 import pandas as pd
 import tensorflow as tf
 from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
-from keras.models import Model, load_model as load_keras_model
+from keras.models import Model
+from sklearn.metrics import confusion_matrix

 import arguments
 import dataset
@@ -14,7 +15,7 @@ import models
 # create logger
 import visualize
 from arguments import get_model_args
-from utils import exists_or_make_path, get_custom_class_weights
+from utils import exists_or_make_path, get_custom_class_weights, load_model

 logger = logging.getLogger('logger')
 logger.setLevel(logging.DEBUG)
@@ -85,19 +86,13 @@ def create_model(model, output_type):
         raise Exception("unknown model output")


-def load_model(path, custom_objects=None):
-    clf = load_keras_model(path, custom_objects)
-    embd = clf.get_layer("domain_cnn").layer
-    return embd, clf
-
-
 def main_paul_best():
     pauls_best_params = models.pauls_networks.best_config
     main_train(pauls_best_params)


 def main_hyperband():
-    params = {
+    param_dist = {
         # static params
         "type": [args.model_type],
         "depth": [args.model_depth],
@@ -119,8 +114,8 @@ def main_hyperband():
     }

     logger.info("create training dataset")
-    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
-                                                                                                args.train_data,
+    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
+                                                                                                args.data,
                                                                                                 args.domain_length,
                                                                                                 args.window)
     server_tr = np.max(server_windows_tr, axis=1)
@@ -128,11 +123,14 @@ def main_hyperband():
     if args.model_type in ("inter", "staggered"):
         server_tr = np.expand_dims(server_windows_tr, 2)

-    hp = hyperband.Hyperband(params,
+    hp = hyperband.Hyperband(param_dist,
                              [domain_tr, flow_tr],
-                             [client_tr, server_tr])
+                             [client_tr, server_tr],
+                             max_iter=81,
+                             savefile=args.hyperband_results)
     results = hp.run()
-    joblib.dump(results, args.hyperband_results)
+    return results


 def main_train(param=None):
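
Since Hyperband now receives savefile=args.hyperband_results and persists its own results, the joblib.dump in the caller is gone and main_hyperband simply returns the result list. A hedged sketch of reading the saved results back, assuming (as load_ordered_hyperband_results in utils.py suggests) a list of dicts keyed by "loss":

    import joblib
    from operator import itemgetter

    results = joblib.load(args.hyperband_results)  # list of {"loss": ..., "early_stop": ...} dicts
    best = min(results, key=itemgetter("loss"))    # smallest minimal val_loss wins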
@@ -140,8 +138,8 @@ def main_train(param=None):
     exists_or_make_path(args.model_path)
     logger.info(f"Use command line arguments: {args}")

-    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
-                                                                                                args.train_data,
+    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
+                                                                                                args.data,
                                                                                                 args.domain_length,
                                                                                                 args.window)
     logger.info("define callbacks")
@@ -237,8 +235,8 @@ def main_retrain():
     logger.info(f"Use command line arguments: {args}")
     exists_or_make_path(args.model_destination)

-    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
-                                                                                                args.train_data,
+    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.data,
+                                                                                                args.data,
                                                                                                 args.domain_length,
                                                                                                 args.window)
     logger.info("define callbacks")
@@ -265,7 +263,7 @@ def main_retrain():
         custom_class_weights = None

     logger.info(f"Load pretrained model")
-    embedding, model = load_model(source, custom_objects=models.get_metrics())
+    embedding, model = load_model(source, custom_objects=models.get_custom_objects())

     if args.model_type in ("inter", "staggered"):
         server_tr = np.expand_dims(server_windows_tr, 2)
@@ -293,16 +291,16 @@ def main_retrain():

 def main_test():
     logger.info("start test: load data")
-    domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.test_h5data,
-                                                                           args.test_data,
+    domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data,
+                                                                           args.data,
                                                                            args.domain_length,
                                                                            args.window)
-    domain_encs, _ = dataset.load_or_generate_domains(args.test_data, args.domain_length)
+    domain_encs, _ = dataset.load_or_generate_domains(args.data, args.domain_length)

     for model_args in get_model_args(args):
         results = {}
         logger.info(f"process model {model_args['model_path']}")
-        embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
+        embd_model, clf_model = load_model(model_args["clf_model"], custom_objects=models.get_custom_objects())

         pred = clf_model.predict([domain_val, flow_val],
                                  batch_size=args.batch_size,
@@ -324,8 +322,28 @@ def main_test():

 def main_visualization():
-    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
-                                                                                            args.test_data,
+    def plot_model(clf_model, path):
+        embd, model = load_model(clf_model, custom_objects=models.get_custom_objects())
+        visualize.plot_model_as(embd, os.path.join(path, "model_embd.pdf"))
+        visualize.plot_model_as(model, os.path.join(path, "model_clf.pdf"))
+
+    def vis(model_name, model_path, df, df_paul, aggregation, curve):
+        visualize.plot_clf()
+        if aggregation == "user":
+            df = df.groupby(df.names).max()
+            df_paul = df_paul.groupby(df_paul.names).max()
+        if curve == "prc":
+            visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
+            visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
+        elif curve == "roc":
+            visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_name)
+            visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
+        visualize.plot_legend()
+        visualize.plot_save("{}/{}_{}.png".format(model_path, aggregation, curve))
+
+    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
+                                                                                            args.data,
                                                                                             args.domain_length,
                                                                                             args.window)
@@ -343,11 +361,9 @@ def main_visualization():
         "hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
     })
     df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
-    df_paul_user = df_paul.groupby(df_paul.names).max()

     logger.info("plot model")
-    embd, model = load_model(args.clf_model, custom_objects=models.get_metrics())
-    visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
+    plot_model(args.clf_model, args.model_path)

     # logger.info("plot training curve")
     # logs = pd.read_csv(args.train_log)
@@ -359,31 +375,15 @@ def main_visualization():
     # else:
     #     logger.warning("Error while plotting training curves")

-    logger.info("plot pr curve")
-    visualize.plot_clf()
-    visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_name)
-    visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
-    visualize.plot_legend()
-    visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
-
-    logger.info("plot roc curve")
-    visualize.plot_clf()
-    visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_name)
-    visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
-    visualize.plot_legend()
-    visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
-
-    visualize.plot_clf()
-    visualize.plot_precision_recall(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_name)
-    visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
-    visualize.plot_legend()
-    visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
-
-    visualize.plot_clf()
-    visualize.plot_roc_curve(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_name)
-    visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
-    visualize.plot_legend()
-    visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
+    logger.info("plot window prc")
+    vis(args.model_name, args.model_path, df, df_paul, "window", "prc")
+    logger.info("plot window roc")
+    vis(args.model_name, args.model_path, df, df_paul, "window", "roc")
+    logger.info("plot user prc")
+    vis(args.model_name, args.model_path, df, df_paul, "user", "prc")
+    logger.info("plot user roc")
+    vis(args.model_name, args.model_path, df, df_paul, "user", "roc")

     # absolute values
     visualize.plot_confusion_matrix(df.client_val.as_matrix(), df.client_pred.as_matrix().round(),
                                     "{}/client_cov.png".format(args.model_path),
@@ -398,25 +398,18 @@ def main_visualization():
     visualize.plot_confusion_matrix(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix().round(),
                                     "{}/user_cov_norm.png".format(args.model_path),
                                     normalize=True, title="User Confusion Matrix")
-    logger.info("visualize embedding")
-    domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
-    domain_embedding = results["domain_embds"]
-    visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(args.model_path), method="svd")
-    visualize.plot_embedding(domain_embedding, labels, path="{}/embd_tsne.png".format(args.model_path), method="tsne")
+    plot_embedding(args.model_path, results["domain_embds"], args.data, args.domain_length)


-def plot_embedding():
+def plot_embedding(model_path, domain_embedding, data, domain_length):
     logger.info("visualize embedding")
-    results = dataset.load_predictions(args.model_path)
-    domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
-    domain_embedding = results["domain_embds"]
-    visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(args.model_path), method="svd")
-    visualize.plot_embedding(domain_embedding, labels, path="{}/embd_tsne.png".format(args.model_path), method="tsne")
+    domain_encs, labels = dataset.load_or_generate_domains(data, domain_length)
+    visualize.plot_embedding(domain_embedding, labels, path="{}/embd_svd.png".format(model_path), method="svd")


 def main_visualize_all():
-    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
-                                                                                            args.test_data,
+    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
+                                                                                            args.data,
                                                                                             args.domain_length,
                                                                                             args.window)
@@ -480,8 +473,8 @@ import joblib

 def main_beta():
-    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.test_h5data,
-                                                                                            args.test_data,
+    _, _, name_val, hits_vt, hits_trusted, server_val = dataset.load_or_generate_raw_h5data(args.data,
+                                                                                            args.data,
                                                                                             args.domain_length,
                                                                                             args.window)
     path, model_prefix = os.path.split(os.path.normpath(args.output_prefix))
@@ -489,7 +482,7 @@ def main_beta():
         results = joblib.load(f"{path}/curves.joblib")
     except Exception:
         results = {}
-    results[model_prefix] = {}
+    results[model_prefix] = {"all": {}}

     def load_df(path):
         res = dataset.load_predictions(path)
@@ -514,7 +507,9 @@ def main_beta():
     for model_args in get_model_args(args):
         df = load_df(model_args["model_path"])
         predictions.append(df.client_pred.as_matrix())
-    results[model_prefix]["window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
+        results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
+                                                                           df.client_pred.as_matrix().round())
+    results[model_prefix]["all"]["window_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
     visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
     visualize.plot_pr_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
     visualize.plot_legend()
@@ -526,7 +521,9 @@ def main_beta():
     for model_args in get_model_args(args):
         df = load_df(model_args["model_path"])
         predictions.append(df.client_pred.as_matrix())
-    results[model_prefix]["window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
+        results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
+                                                                           df.client_pred.as_matrix().round())
+    results[model_prefix]["all"]["window_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
     visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
     visualize.plot_roc_mean(df_paul.client_val.as_matrix(), [df_paul.client_pred.as_matrix()], "paul")
     visualize.plot_legend()
@@ -539,7 +536,9 @@ def main_beta():
         df = load_df(model_args["model_path"])
         df = df.groupby(df.names).max()
         predictions.append(df.client_pred.as_matrix())
-    results[model_prefix]["user_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
+        results[model_prefix][model_args["model_name"]] = confusion_matrix(df.client_val.as_matrix(),
+                                                                           df.client_pred.as_matrix().round())
+    results[model_prefix]["all"]["user_prc"] = visualize.calc_pr_mean(df.client_val.as_matrix(), predictions)
     visualize.plot_pr_mean(df.client_val.as_matrix(), predictions, "mean")
     visualize.plot_pr_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
     visualize.plot_legend()
@@ -552,7 +551,7 @@ def main_beta():
         df = load_df(model_args["model_path"])
         df = df.groupby(df.names).max()
         predictions.append(df.client_pred.as_matrix())
-    results[model_prefix]["user_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
+    results[model_prefix]["all"]["user_roc"] = visualize.calc_roc_mean(df.client_val.as_matrix(), predictions)
     visualize.plot_roc_mean(df.client_val.as_matrix(), predictions, "mean")
     visualize.plot_roc_mean(df_paul_user.client_val.as_matrix(), [df_paul_user.client_pred.as_matrix()], "paul")
     visualize.plot_legend()
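
The four hunks above change the layout of curves.joblib: the mean curves move under an "all" key, and each model name now maps to its confusion matrix. A sketch of the resulting structure with placeholder values (the prefix and model names are hypothetical):

    import numpy as np

    ys_mean, ys_std, score = np.zeros(100), np.zeros(100), 0.0  # placeholder curve data
    cm = np.array([[90, 10], [5, 95]])                          # placeholder confusion matrix

    results = {
        "cnn_final": {                # model_prefix
            "all": {                  # mean curves, read by plot_overall_result
                "window_prc": (ys_mean, ys_std, score),
                "window_roc": (ys_mean, ys_std, score),
                "user_prc": (ys_mean, ys_std, score),
                "user_roc": (ys_mean, ys_std, score),
            },
            "test_both_1": cm,        # one confusion matrix per model_name
        },
    }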
@@ -576,7 +575,7 @@ def plot_overall_result():
         logger.info(f"plot {vis}")
         visualize.plot_clf()
         for model_key in results.keys():
-            ys_mean, ys_std, score = results[model_key][vis]
+            ys_mean, ys_std, score = results[model_key]["all"][vis]
             plt.plot(x, ys_mean, label=f"{model_key} - {score:5.4}")
             plt.fill_between(x, ys_mean - ys_std, ys_mean + ys_std, alpha=0.2)
         if vis.endswith("prc"):
@@ -604,14 +603,8 @@ def main():
         main_visualization()
     if "all_fancy" == args.mode:
         main_visualize_all()
-    if "embd" == args.mode:
-        plot_embedding()
-    if "paul" == args.mode:
-        main_paul_best()
     if "beta" == args.mode:
         main_beta()
-    if "beta_all" == args.mode:
-        plot_overall_result()


 if __name__ == "__main__":

utils.py

@@ -3,6 +3,7 @@ from operator import itemgetter

 import joblib
 import numpy as np
+from keras.models import load_model as load_keras_model
 from sklearn.utils import class_weight
@@ -27,3 +28,15 @@ def get_custom_sample_weights(client, server):

 def load_ordered_hyperband_results(path):
     results = joblib.load(path)
     return sorted(results, key=itemgetter("loss"))


+def load_model(path, custom_objects=None):
+    clf = load_keras_model(path, custom_objects)
+    try:
+        embd = clf.get_layer("domain_cnn").layer
+    except Exception:
+        # some versions did not name the embedding layer "domain_cnn";
+        # fall back to its position for backwards compatibility
+        embd = clf.layers[1].layer
+    return embd, clf
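
A short usage sketch for the relocated helper; the model path is illustrative, and domain_val/flow_val stand in for encoded inputs as prepared in main_test (main.py calls it with model_args["clf_model"] and models.get_custom_objects()):

    import models
    from utils import load_model

    # returns the inner domain-embedding network and the full classifier
    embd_model, clf_model = load_model("results/test/test_both_1",
                                       custom_objects=models.get_custom_objects())
    domain_vecs = embd_model.predict(domain_val)       # domain embeddings only
    preds = clf_model.predict([domain_val, flow_val])  # full classifier output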

visualize.py

@@ -38,6 +38,7 @@ def plot_clf():

 def plot_save(path, dpi=300):
+    plt.title(path)
     fig = plt.gcf()
     fig.set_size_inches(18.5, 10.5)
     fig.savefig(path, dpi=dpi)
@@ -48,6 +49,10 @@ def plot_legend():
     plt.legend()


+def mathews_correlation_curve(y, y_pred):
+    pass
+
+
 def plot_precision_recall(y, y_pred, label=""):
     y = y.flatten()
     y_pred = y_pred.flatten()
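
mathews_correlation_curve is added above as an empty stub. One possible implementation, sketched under the assumption that it should mirror the threshold-sweeping curve helpers around it, using sklearn's matthews_corrcoef:

    import numpy as np
    from sklearn.metrics import matthews_corrcoef

    def mathews_correlation_curve(y, y_pred):
        # sweep decision thresholds and compute the MCC at each one
        y = y.flatten()
        y_pred = y_pred.flatten()
        thresholds = np.linspace(0.0, 1.0, 101)
        mcc = np.array([matthews_corrcoef(y, y_pred >= t) for t in thresholds])
        return thresholds, mcc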