refactor hyperband implementation
This commit is contained in:
		
							
								
								
									
										42
									
								
								hyperband.py
									
									
									
									
									
								
							
							
						
						
									
										42
									
								
								hyperband.py
									
									
									
									
									
								
							@@ -3,13 +3,15 @@
 | 
			
		||||
# https://arxiv.org/pdf/1603.06560.pdf
 | 
			
		||||
import logging
 | 
			
		||||
import random
 | 
			
		||||
from math import log, ceil
 | 
			
		||||
from math import ceil, log
 | 
			
		||||
from random import random as rng
 | 
			
		||||
from time import time, ctime
 | 
			
		||||
from time import ctime, time
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from keras.callbacks import EarlyStopping
 | 
			
		||||
 | 
			
		||||
import models
 | 
			
		||||
from main import create_model
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger('logger')
 | 
			
		||||
 | 
			
		||||
@@ -22,10 +24,10 @@ def sample_params(param_distribution: dict):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Hyperband:
 | 
			
		||||
    def __init__(self, param_distribution, X, y):
 | 
			
		||||
    def __init__(self, param_distribution, X, y, max_iter=81):
 | 
			
		||||
        self.get_params = lambda: sample_params(param_distribution)
 | 
			
		||||
 | 
			
		||||
        self.max_iter = 81  # maximum iterations per configuration
 | 
			
		||||
        self.max_iter = max_iter  # maximum iterations per configuration
 | 
			
		||||
        self.eta = 3  # defines configuration downsampling rate (default = 3)
 | 
			
		||||
 | 
			
		||||
        self.logeta = lambda x: log(x) / log(self.eta)
 | 
			
		||||
@@ -42,19 +44,31 @@ class Hyperband:
 | 
			
		||||
    
 | 
			
		||||
    def try_params(self, n_iterations, params):
 | 
			
		||||
        n_iterations = int(round(n_iterations))
 | 
			
		||||
        embedding, model = models.get_models_by_params(params)
 | 
			
		||||
        embedding, model, new_model = models.get_models_by_params(params)
 | 
			
		||||
 | 
			
		||||
        model = create_model(model, params["output"])
 | 
			
		||||
        new_model = create_model(new_model, params["output"])
 | 
			
		||||
 | 
			
		||||
        if params["type"] in ("inter", "staggered"):
 | 
			
		||||
            model = new_model
 | 
			
		||||
 | 
			
		||||
        callbacks = [EarlyStopping(monitor='val_loss',
 | 
			
		||||
                                   patience=5,
 | 
			
		||||
                                   verbose=False)]
 | 
			
		||||
        
 | 
			
		||||
        model.compile(optimizer='adam',
 | 
			
		||||
                      loss='categorical_crossentropy',
 | 
			
		||||
                      loss='binary_crossentropy',
 | 
			
		||||
                      metrics=['accuracy'])
 | 
			
		||||
 | 
			
		||||
        history = model.fit(self.X,
 | 
			
		||||
                            self.y,
 | 
			
		||||
                            batch_size=params["batch_size"],
 | 
			
		||||
                            epochs=n_iterations,
 | 
			
		||||
                            callbacks=callbacks,
 | 
			
		||||
                            shuffle=True,
 | 
			
		||||
                            validation_split=0.2)
 | 
			
		||||
                            validation_split=0.4)
 | 
			
		||||
 | 
			
		||||
        return {"loss": history.history['loss'][-1]}
 | 
			
		||||
        return {"loss": history.history['val_loss'][-1], "early_stop": True}
 | 
			
		||||
    
 | 
			
		||||
    # can be called multiple times
 | 
			
		||||
    def run(self, skip_last=0, dry_run=False):
 | 
			
		||||
@@ -68,7 +82,7 @@ class Hyperband:
 | 
			
		||||
            r = self.max_iter * self.eta ** (-s)
 | 
			
		||||
        
 | 
			
		||||
            # n random configurations
 | 
			
		||||
            T = [self.get_params() for _ in range(n)]
 | 
			
		||||
            random_configs = [self.get_params() for _ in range(n)]
 | 
			
		||||
            
 | 
			
		||||
            for i in range((s + 1) - int(skip_last)):  # changed from s + 1
 | 
			
		||||
    
 | 
			
		||||
@@ -79,16 +93,16 @@ class Hyperband:
 | 
			
		||||
                n_iterations = r * self.eta ** (i)
 | 
			
		||||
    
 | 
			
		||||
                logger.info("\n*** {} configurations x {:.1f} iterations each".format(
 | 
			
		||||
                    n_configs, n_iterations))
 | 
			
		||||
                        n_configs, n_iterations))
 | 
			
		||||
                
 | 
			
		||||
                val_losses = []
 | 
			
		||||
                early_stops = []
 | 
			
		||||
    
 | 
			
		||||
                for t in T:
 | 
			
		||||
                for t in random_configs:
 | 
			
		||||
                    
 | 
			
		||||
                    self.counter += 1
 | 
			
		||||
                    logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
 | 
			
		||||
                        self.counter, ctime(), self.best_loss, self.best_counter))
 | 
			
		||||
                            self.counter, ctime(), self.best_loss, self.best_counter))
 | 
			
		||||
                    
 | 
			
		||||
                    start_time = time()
 | 
			
		||||
 | 
			
		||||
@@ -125,7 +139,7 @@ class Hyperband:
 | 
			
		||||
                # select a number of best configurations for the next loop
 | 
			
		||||
                # filter out early stops, if any
 | 
			
		||||
                indices = np.argsort(val_losses)
 | 
			
		||||
                T = [T[i] for i in indices if not early_stops[i]]
 | 
			
		||||
                T = T[0:int(n_configs / self.eta)]
 | 
			
		||||
                random_configs = [random_configs[i] for i in indices if not early_stops[i]]
 | 
			
		||||
                random_configs = random_configs[0:int(n_configs / self.eta)]
 | 
			
		||||
        
 | 
			
		||||
        return self.results
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										33
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								main.py
									
									
									
									
									
								
							@@ -1,4 +1,3 @@
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
@@ -100,33 +99,39 @@ def main_paul_best():
 | 
			
		||||
def main_hyperband():
 | 
			
		||||
    params = {
 | 
			
		||||
        # static params
 | 
			
		||||
        "type": ["paul"],
 | 
			
		||||
        "type": [args.model_type],
 | 
			
		||||
        "depth": [args.model_depth],
 | 
			
		||||
        "output": [args.model_output],
 | 
			
		||||
        "batch_size": [args.batch_size],
 | 
			
		||||
        "window_size": [10],
 | 
			
		||||
        "domain_length": [40],
 | 
			
		||||
        "flow_features": [3],
 | 
			
		||||
        "input_length": [40],
 | 
			
		||||
        # model params
 | 
			
		||||
        "embedding_size": [8, 16, 32, 64, 128, 256],
 | 
			
		||||
        "filter_embedding": [8, 16, 32, 64, 128, 256],
 | 
			
		||||
        "embedding_size": [2 ** x for x in range(3, 7)],
 | 
			
		||||
        "filter_embedding": [2 ** x for x in range(1, 10)],
 | 
			
		||||
        "kernel_embedding": [1, 3, 5, 7, 9],
 | 
			
		||||
        "hidden_embedding": [8, 16, 32, 64, 128, 256],
 | 
			
		||||
        "dense_embedding": [2 ** x for x in range(4, 10)],
 | 
			
		||||
        "dropout": [0.5],
 | 
			
		||||
        "domain_features": [8, 16, 32, 64, 128, 256],
 | 
			
		||||
        "filter_main": [8, 16, 32, 64, 128, 256],
 | 
			
		||||
        "kernels_main": [1, 3, 5, 7, 9],
 | 
			
		||||
        "dense_main": [8, 16, 32, 64, 128, 256],
 | 
			
		||||
        "filter_main": [2 ** x for x in range(1, 10)],
 | 
			
		||||
        "kernel_main": [1, 3, 5, 7, 9],
 | 
			
		||||
        "dense_main": [2 ** x for x in range(1, 12)],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    logger.info("create training dataset")
 | 
			
		||||
    domain_tr, flow_tr, name_tr, client_tr, server_tr = dataset.load_or_generate_h5data(args.train_h5data,
 | 
			
		||||
                                                                                        args.train_data,
 | 
			
		||||
                                                                                        args.domain_length, args.window)
 | 
			
		||||
    domain_tr, flow_tr, name_tr, client_tr, server_windows_tr = dataset.load_or_generate_h5data(args.train_h5data,
 | 
			
		||||
                                                                                                args.train_data,
 | 
			
		||||
                                                                                                args.domain_length,
 | 
			
		||||
                                                                                                args.window)
 | 
			
		||||
    server_tr = np.max(server_windows_tr, axis=1)
 | 
			
		||||
 | 
			
		||||
    if args.model_type in ("inter", "staggered"):
 | 
			
		||||
        server_tr = np.expand_dims(server_windows_tr, 2)
 | 
			
		||||
 | 
			
		||||
    hp = hyperband.Hyperband(params,
 | 
			
		||||
                             [domain_tr, flow_tr],
 | 
			
		||||
                             [client_tr, server_tr])
 | 
			
		||||
    results = hp.run()
 | 
			
		||||
    json.dump(results, open("hyperband.json"))
 | 
			
		||||
    joblib.dump(results, "hyperband.joblib")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main_train(param=None):
 | 
			
		||||
 
 | 
			
		||||
@@ -17,7 +17,6 @@ def get_models_by_params(params: dict):
 | 
			
		||||
    dropout = params.get("dropout")
 | 
			
		||||
    # mainly prediction model
 | 
			
		||||
    flow_features = params.get("flow_features")
 | 
			
		||||
    domain_features = params.get("domain_features")
 | 
			
		||||
    window_size = params.get("window_size")
 | 
			
		||||
    domain_length = params.get("domain_length")
 | 
			
		||||
    filter_main = params.get("filter_main")
 | 
			
		||||
@@ -36,10 +35,10 @@ def get_models_by_params(params: dict):
 | 
			
		||||
    embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
 | 
			
		||||
                                             hidden_embedding, 0.5)
 | 
			
		||||
 | 
			
		||||
    old_model = networks.get_model(0.25, flow_features, domain_features, window_size, domain_length,
 | 
			
		||||
    old_model = networks.get_model(0.25, flow_features, hidden_embedding, window_size, domain_length,
 | 
			
		||||
                                   filter_main, kernel_main, dense_dim, embedding_model, model_output)
 | 
			
		||||
 | 
			
		||||
    new_model = networks.get_new_model(0.25, flow_features, domain_features, window_size, domain_length,
 | 
			
		||||
    new_model = networks.get_new_model(0.25, flow_features, hidden_embedding, window_size, domain_length,
 | 
			
		||||
                                       filter_main, kernel_main, dense_dim, embedding_model, model_output)
 | 
			
		||||
 | 
			
		||||
    return embedding_model, old_model, new_model
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										16
									
								
								rerun_models.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								rerun_models.sh
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
#!/usr/bin/env bash
#
# Retrain every model directory found under a source prefix.
#
# Usage: rerun_models.sh SRC DEST DATADIR INIT EPOCHS
#   SRC      path prefix; every directory matching "${SRC}*/" is retrained
#   DEST     base directory; each model is written to ${DEST}/<source dir name>
#   DATADIR  value forwarded to main.py as --train
#   INIT     value forwarded to main.py as --init_epoch
#   EPOCHS   value forwarded to main.py as --epochs
#
# The batch size is fixed at 128 (BS below).

# All five positional arguments are consumed below; refuse to run with fewer,
# because an empty SRC/DEST would retrain every directory in the current
# working directory into "/<name>".
if [ "$#" -lt 5 ]; then
    echo "usage: $0 SRC DEST DATADIR INIT EPOCHS" >&2
    exit 1
fi

SRC=$1
DEST=$2
DATADIR=$3
INIT=$4
EPOCHS=$5
BS=128

# Glob directly instead of parsing `ls` output so that directory names
# containing whitespace or glob characters are passed through intact.
for i in "${SRC}"*/
do
    # When nothing matches, bash leaves the pattern as a literal string;
    # skip anything that is not an existing directory.
    [ -d "${i}" ] || continue

    echo "retrain model in ${i}"
    name=$(basename "${i}")
    python3 main.py --mode retrain --model_src "${i}" --model_dest "${DEST}/${name}" --init_epoch "${INIT}" --epochs "${EPOCHS}" --batch "${BS}" --train "${DATADIR}"

done
 | 
			
		||||
		Reference in New Issue
	
	Block a user