# -*- coding: utf-8 -*-
# implementation of hyperband:
# https://arxiv.org/pdf/1603.06560.pdf
import logging
import random
from math import ceil, log
from random import random as rng
from time import ctime, time

import joblib
import numpy as np
from keras.callbacks import EarlyStopping

import models

logger = logging.getLogger('cisco_logger')


def sample_params(param_distribution: dict):
    # draw one value uniformly at random for every hyperparameter
    p = {}
    for key, val in param_distribution.items():
        p[key] = random.choice(val)
    return p
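
# Example (hypothetical keys): sample_params({"batch_size": [32, 64],
# "model_output": ["client", "both"]}) could return
# {"batch_size": 64, "model_output": "client"}.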


class Hyperband:
    def __init__(self, param_distribution, X, y, max_iter=81, savefile=None):
        self.get_params = lambda: sample_params(param_distribution)
        self.max_iter = max_iter  # maximum iterations per configuration
        self.eta = 3  # defines configuration downsampling rate (default = 3)
        self.logeta = lambda x: log(x) / log(self.eta)
        self.s_max = int(self.logeta(self.max_iter))  # number of brackets - 1
        self.B = (self.s_max + 1) * self.max_iter  # per-bracket budget
        self.results = []  # list of dicts
        self.counter = 0
        self.best_loss = np.inf
        self.best_counter = -1
        self.savefile = savefile
        self.X = X
        self.y = y
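
    # With the defaults (max_iter=81, eta=3): s_max = int(log_3(81)) = 4,
    # so run() executes 5 brackets (s = 4, 3, ..., 0), and B = 5 * 81 = 405
    # is the per-bracket budget used to size the initial n in run() below.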

    def try_params(self, n_iterations, params):
        n_iterations = int(round(n_iterations))
        model = models.get_models_by_params(params)
        callbacks = [EarlyStopping(monitor='val_loss',
                                   patience=5,
                                   verbose=False)]
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
        history = model.fit(self.X,
                            self.y[0] if params["model_output"] == "client" else self.y,
                            batch_size=params["batch_size"],
                            epochs=n_iterations,
                            callbacks=callbacks,
                            shuffle=True,
                            validation_split=0.4)
        # "early_stop" is True when EarlyStopping cut training short,
        # i.e. fewer epochs ran than were requested
        return {"loss": np.min(history.history['val_loss']),
                "early_stop": len(history.history["loss"]) < n_iterations,
                "stop_after": len(history.history["val_loss"])}

    # can be called multiple times
    def run(self, skip_last=0, dry_run=False):
        for s in reversed(range(self.s_max + 1)):
            # initial number of configurations
            n = int(ceil(self.B / self.max_iter / (s + 1) * self.eta ** s))
            # initial number of iterations per config
            r = self.max_iter * self.eta ** (-s)
            # n random configurations
            random_configs = [self.get_params() for _ in range(n)]
            for i in range((s + 1) - int(skip_last)):  # changed from s + 1
                # Run each of the n configs for <iterations>
                # and keep best (n_configs / eta) configurations
                n_configs = n * self.eta ** (-i)
                n_iterations = r * self.eta ** i
                logger.info("*** {} configurations x {:.1f} iterations each".format(
                    n_configs, n_iterations))
                val_losses = []
                early_stops = []
                for t in random_configs:
                    self.counter += 1
                    logger.info("Config {} | {} | lowest loss so far: {:.4f} (run {})".format(
                        self.counter, ctime(), self.best_loss, self.best_counter))
                    start_time = time()
                    if dry_run:
                        result = {'loss': rng(), 'log_loss': rng(), 'auc': rng()}
                    else:
                        result = self.try_params(n_iterations, t)  # train this config
                    assert isinstance(result, dict)
                    assert 'loss' in result
                    seconds = int(round(time() - start_time))
                    logger.info("{} seconds.".format(seconds))
                    loss = result['loss']
                    val_losses.append(loss)
                    early_stop = result.get('early_stop', False)
                    early_stops.append(early_stop)
                    # keeping track of the best result so far (for display only)
                    # could do it by checking results each time, but hey
                    if loss < self.best_loss:
                        self.best_loss = loss
                        self.best_counter = self.counter
                    result['counter'] = self.counter
                    result['seconds'] = seconds
                    result['params'] = t
                    result['iterations'] = n_iterations
                    self.results.append(result)
                # select a number of best configurations for the next loop
                # filter out early stops, if any
                indices = np.argsort(val_losses)
                random_configs = [random_configs[i] for i in indices if not early_stops[i]]
                random_configs = random_configs[0:int(n_configs / self.eta)]
                if self.savefile:
                    joblib.dump(self.results, self.savefile)
        return self.results
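

# Minimal usage sketch, not part of the original module: the keys in
# param_distribution below are illustrative assumptions, and in a real run
# models.get_models_by_params must know how to build a model from them.
# With dry_run=True, X and y are never touched, so the search loop can be
# exercised without any data.
if __name__ == "__main__":
    param_distribution = {
        "model_output": ["client", "both"],  # assumed values
        "batch_size": [32, 64, 128],         # assumed values
    }
    hb = Hyperband(param_distribution, X=None, y=None, max_iter=81)
    results = hb.run(dry_run=True)
    best = min(results, key=lambda r: r["loss"])
    print("best dry-run config:", best["params"])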