fix chunks per user function bug caused by numpy version of array_split
This commit is contained in:
parent
844494eca9
commit
d33c9f44ec
14
dataset.py
14
dataset.py
@ -71,9 +71,13 @@ def get_user_chunks(user_flow, window=10):
|
|||||||
# domains.pop(-1)
|
# domains.pop(-1)
|
||||||
# flows.pop(-1)
|
# flows.pop(-1)
|
||||||
# return domains, flows
|
# return domains, flows
|
||||||
|
result = []
|
||||||
chunk_size = (len(user_flow) // window)
|
chunk_size = (len(user_flow) // window)
|
||||||
last_inrange = chunk_size * window
|
for i in range(chunk_size):
|
||||||
return np.split(user_flow.head(last_inrange), chunk_size) if chunk_size else []
|
result.append(user_flow.iloc[i * window:(i + 1) * window])
|
||||||
|
if result and len(result[-1]) != window:
|
||||||
|
result.pop()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_domain_features(domain, vocab: dict, max_length=40):
|
def get_domain_features(domain, vocab: dict, max_length=40):
|
||||||
@ -153,7 +157,9 @@ def create_dataset_from_lists(chunks, vocab, max_len):
|
|||||||
logger.info(" compute domain features")
|
logger.info(" compute domain features")
|
||||||
domain_features = []
|
domain_features = []
|
||||||
for ds in tqdm(map(lambda f: f.domain, chunks)):
|
for ds in tqdm(map(lambda f: f.domain, chunks)):
|
||||||
assert min(np.atleast_3d(ds).shape) > 0, f"shape of 0 for {ds}"
|
# TODO: fix this correctly
|
||||||
|
# assert min(np.atleast_3d(ds).shape) > 0, f"shape of 0 for {ds}"
|
||||||
|
if not ds: continue
|
||||||
domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
|
domain_features.append(np.apply_along_axis(get_domain_features_reduced, 2, np.atleast_3d(ds)))
|
||||||
domain_features = np.concatenate(domain_features, 0)
|
domain_features = np.concatenate(domain_features, 0)
|
||||||
logger.info(" compute flow features")
|
logger.info(" compute flow features")
|
||||||
@ -161,7 +167,7 @@ def create_dataset_from_lists(chunks, vocab, max_len):
|
|||||||
logger.info(" select hits")
|
logger.info(" select hits")
|
||||||
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, chunks)), axis=1)
|
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, chunks)), axis=1)
|
||||||
logger.info(" select names")
|
logger.info(" select names")
|
||||||
names = np.unique(np.stack(map(lambda f: f.user_hash, chunks)), axis=1)
|
names = np.unique(np.stack(map(lambda f: f.user_hash, chunks)))
|
||||||
logger.info(" select servers")
|
logger.info(" select servers")
|
||||||
servers = np.max(np.stack(map(lambda f: f.serverLabel, chunks)), axis=1)
|
servers = np.max(np.stack(map(lambda f: f.serverLabel, chunks)), axis=1)
|
||||||
logger.info(" select trusted hits")
|
logger.info(" select trusted hits")
|
||||||
|
@ -68,7 +68,7 @@ class Hyperband:
|
|||||||
r = self.max_iter * self.eta ** (-s)
|
r = self.max_iter * self.eta ** (-s)
|
||||||
|
|
||||||
# n random configurations
|
# n random configurations
|
||||||
T = [self.get_params() for i in range(n)]
|
T = [self.get_params() for _ in range(n)]
|
||||||
|
|
||||||
for i in range((s + 1) - int(skip_last)): # changed from s + 1
|
for i in range((s + 1) - int(skip_last)): # changed from s + 1
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user