From d33c9f44ecdf264856c4eb2de0d005dab46743cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Sun, 16 Jul 2017 18:49:14 +0200 Subject: [PATCH] fix chunks per user function bug caused by numpy version of array_split --- dataset.py | 14 ++++++++++---- hyperband.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dataset.py b/dataset.py index 732da4b..c4d3e62 100644 --- a/dataset.py +++ b/dataset.py @@ -71,9 +71,13 @@ def get_user_chunks(user_flow, window=10): # domains.pop(-1) # flows.pop(-1) # return domains, flows + result = [] chunk_size = (len(user_flow) // window) - last_inrange = chunk_size * window - return np.split(user_flow.head(last_inrange), chunk_size) if chunk_size else [] + for i in range(chunk_size): + result.append(user_flow.iloc[i * window:(i + 1) * window]) + if result and len(result[-1]) != window: + result.pop() + return result def get_domain_features(domain, vocab: dict, max_length=40): @@ -153,7 +157,9 @@ def create_dataset_from_lists(chunks, vocab, max_len): logger.info(" compute domain features") domain_features = [] for ds in tqdm(map(lambda f: f.domain, chunks)): - assert min(np.atleast_3d(ds).shape) > 0, f"shape of 0 for {ds}" + # TODO: fix this correctly + # assert min(np.atleast_3d(ds).shape) > 0, f"shape of 0 for {ds}" + if not ds: continue domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds))) domain_features = np.concatenate(domain_features, 0) logger.info(" compute flow features") @@ -161,7 +167,7 @@ def create_dataset_from_lists(chunks, vocab, max_len): logger.info(" select hits") hits = np.max(np.stack(map(lambda f: f.virusTotalHits, chunks)), axis=1) logger.info(" select names") - names = np.unique(np.stack(map(lambda f: f.user_hash, chunks)), axis=1) + names = np.unique(np.stack(map(lambda f: f.user_hash, chunks))) logger.info(" select servers") servers = np.max(np.stack(map(lambda f: f.serverLabel, chunks)), axis=1) logger.info(" select trusted hits") diff --git a/hyperband.py b/hyperband.py index c9f30c2..d01ce11 100644 --- a/hyperband.py +++ b/hyperband.py @@ -68,7 +68,7 @@ class Hyperband: r = self.max_iter * self.eta ** (-s) # n random configurations - T = [self.get_params() for i in range(n)] + T = [self.get_params() for _ in range(n)] for i in range((s + 1) - int(skip_last)): # changed from s + 1