From 4fc2f0c925e46b079aeda34895ab7addb1926593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Fri, 10 Nov 2017 10:18:13 +0100 Subject: [PATCH] extract weighting function --- dataset.py | 4 ++++ main.py | 51 ++++++++++++++++++++++++++++------------------ models/__init__.py | 10 +++++++++ 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/dataset.py b/dataset.py index 20482fe..5912a9d 100644 --- a/dataset.py +++ b/dataset.py @@ -263,8 +263,11 @@ def load_or_generate_domains(train_data, domain_length): fn = f"{train_data}_domains.gz" try: + logger.info(f"Load file {fn}.") user_flow_df = pd.read_csv(fn) + logger.info(f"File successfully loaded.") except FileNotFoundError: + logger.info(f"File {fn} not found, recreate.") user_flow_df = get_user_flow_data(train_data) # user_flow_df.reset_index(inplace=True) user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0, @@ -279,6 +282,7 @@ def load_or_generate_domains(train_data, domain_length): user_flow_df.to_csv(fn, compression="gzip") + logger.info(f"Extract features from domains") domain_encs = user_flow_df.domain.apply(lambda d: get_domain_features(d, domain_length)) domain_encs = np.stack(domain_encs) diff --git a/main.py b/main.py index 430bf95..ee9c7f8 100644 --- a/main.py +++ b/main.py @@ -197,6 +197,26 @@ def load_data(data, domain_length, window_size, model_type): return domain_tr, flow_tr, client_tr, server_tr +def get_weighting(class_weights, sample_weights, client, server): + if class_weights: + logger.info("class weights: compute custom weights") + custom_class_weights = get_custom_class_weights(client, server) + logger.info(custom_class_weights) + else: + logger.info("class weights: set default") + custom_class_weights = None + + if sample_weights: + logger.info("class weights: compute custom weights") + custom_sample_weights = get_custom_sample_weights(client, server) + logger.info(custom_sample_weights) + else: + logger.info("class weights: set default") + custom_sample_weights = None + + return custom_class_weights, custom_sample_weights + + def main_train(param=None): logger.info(f"Create model path {args.model_path}") exists_or_make_path(args.model_path) @@ -220,6 +240,10 @@ def main_train(param=None): if not param: param = PARAMS + # custom class or sample weights + custom_class_weights, custom_sample_weights = get_weighting(args.class_weights, args.sample_weights, + client_tr.value, server_tr) + for i in range(args.runs): model_path = os.path.join(args.model_path, f"clf_{i}.h5") train_log_path = os.path.join(args.model_path, f"train_{i}.log.csv") @@ -238,23 +262,6 @@ def main_train(param=None): verbose=False)) custom_metrics = models.get_metric_functions() - # custom class or sample weights - if args.class_weights: - logger.info("class weights: compute custom weights") - custom_class_weights = get_custom_class_weights(client_tr.value, server_tr) - logger.info(custom_class_weights) - else: - logger.info("class weights: set default") - custom_class_weights = None - - if args.sample_weights: - logger.info("class weights: compute custom weights") - custom_sample_weights = get_custom_sample_weights(client_tr.value, server_tr) - logger.info(custom_class_weights) - else: - logger.info("class weights: set default") - custom_sample_weights = None - logger.info(f"Generator model with params: {param}") model = models.get_models_by_params(param) @@ -372,9 +379,13 @@ def main_retrain(): def main_test(): - logger.info("start test: load data") + logger.info("load test data") domain_val, flow_val, _, _, _, _ = dataset.load_or_generate_raw_h5data(args.data, args.domain_length, args.window) + logger.info("load test domains") domain_encs, _, _ = dataset.load_or_generate_domains(args.data, args.domain_length) + + def get_dir(path): + return os.path.split(os.path.normpath(path)) results = {} for model_path in args.model_paths: @@ -398,8 +409,8 @@ def main_test(): domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1) results["domain_embds"] = domain_embeddings - - dataset.save_predictions(get_dir(model_path)[0], results) + # store results every round - safety first! + dataset.save_predictions(get_dir(model_path)[0], results) def main_visualization(): diff --git a/models/__init__.py b/models/__init__.py index 317bb0d..925c276 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,8 +1,17 @@ +from collections import namedtuple + from keras.models import Model from . import networks from .metrics import * +NetworkParameters = namedtuple("NetworkParameters", [ + "type", "flow_features", "window_size", "domain_length", "output", + "embedding_size", + "domain_filter", "domain_kernel", "domain_dense", "domain_dropout", + "main_filter", "main_kernel", "main_dense", "main_dropout", +]) + def create_model(model, output_type): if output_type == "both": @@ -14,6 +23,7 @@ def create_model(model, output_type): def get_models_by_params(params: dict): + K.clear_session() # decomposing param section # mainly embedding model network_type = params.get("type")