refactor argparser into separate file, add logger
This commit is contained in:
parent
9f0bae33d5
commit
2afaccc84b
2
Makefile
2
Makefile
@ -1,5 +1,5 @@
|
|||||||
test:
|
test:
|
||||||
python3 main.py --modes train --epochs 1 --batch 64 --train data/rk_data.csv.gz
|
python3 main.py --modes train --epochs 1 --batch 128 --train data/rk_mini.csv.gz
|
||||||
|
|
||||||
hyper:
|
hyper:
|
||||||
python3 main.py --modes hyperband --epochs 1 --batch 64 --train data/rk_data.csv.gz
|
python3 main.py --modes hyperband --epochs 1 --batch 64 --train data/rk_data.csv.gz
|
||||||
|
78
arguments.py
Normal file
78
arguments.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--modes", action="store", dest="modes", nargs="+",
|
||||||
|
default=[])
|
||||||
|
|
||||||
|
parser.add_argument("--train", action="store", dest="train_data",
|
||||||
|
default="data/full_dataset.csv.tar.bz2")
|
||||||
|
|
||||||
|
parser.add_argument("--test", action="store", dest="test_data",
|
||||||
|
default="data/full_future_dataset.csv.tar.bz2")
|
||||||
|
|
||||||
|
# parser.add_argument("--h5data", action="store", dest="h5data",
|
||||||
|
# default="")
|
||||||
|
#
|
||||||
|
parser.add_argument("--models", action="store", dest="model_path",
|
||||||
|
default="models/models_x")
|
||||||
|
|
||||||
|
# parser.add_argument("--pred", action="store", dest="pred",
|
||||||
|
# default="")
|
||||||
|
#
|
||||||
|
parser.add_argument("--type", action="store", dest="model_type",
|
||||||
|
default="paul")
|
||||||
|
|
||||||
|
parser.add_argument("--batch", action="store", dest="batch_size",
|
||||||
|
default=64, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--epochs", action="store", dest="epochs",
|
||||||
|
default=10, type=int)
|
||||||
|
|
||||||
|
# parser.add_argument("--samples", action="store", dest="samples",
|
||||||
|
# default=100000, type=int)
|
||||||
|
#
|
||||||
|
# parser.add_argument("--samples_val", action="store", dest="samples_val",
|
||||||
|
# default=10000, type=int)
|
||||||
|
#
|
||||||
|
parser.add_argument("--embd", action="store", dest="embedding",
|
||||||
|
default=128, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
|
||||||
|
default=256, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--window", action="store", dest="window",
|
||||||
|
default=10, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--domain_length", action="store", dest="domain_length",
|
||||||
|
default=40, type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
||||||
|
default=512, type=int)
|
||||||
|
|
||||||
|
|
||||||
|
# parser.add_argument("--queue", action="store", dest="queue_size",
|
||||||
|
# default=50, type=int)
|
||||||
|
#
|
||||||
|
# parser.add_argument("--p", action="store", dest="p_train",
|
||||||
|
# default=0.5, type=float)
|
||||||
|
#
|
||||||
|
# parser.add_argument("--p_val", action="store", dest="p_val",
|
||||||
|
# default=0.01, type=float)
|
||||||
|
#
|
||||||
|
# parser.add_argument("--gpu", action="store", dest="gpu",
|
||||||
|
# default=0, type=int)
|
||||||
|
#
|
||||||
|
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
||||||
|
#
|
||||||
|
# parser.add_argument("--test", action="store_true", dest="test")
|
||||||
|
|
||||||
|
|
||||||
|
def parse():
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||||
|
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||||
|
args.train_log = os.path.join(args.model_path, "train.log")
|
||||||
|
args.h5data = args.train_data + ".h5"
|
||||||
|
return args
|
16
dataset.py
16
dataset.py
@ -1,4 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import logging
|
||||||
import string
|
import string
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
@ -7,6 +8,8 @@ import pandas as pd
|
|||||||
from keras.utils import np_utils
|
from keras.utils import np_utils
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logger = logging.getLogger('logger')
|
||||||
|
|
||||||
chars = dict((char, idx + 1) for (idx, char) in
|
chars = dict((char, idx + 1) for (idx, char) in
|
||||||
enumerate(string.ascii_lowercase + string.punctuation + string.digits))
|
enumerate(string.ascii_lowercase + string.punctuation + string.digits))
|
||||||
|
|
||||||
@ -36,7 +39,7 @@ def get_user_chunks(dataFrame, windowSize=10, overlapping=False,
|
|||||||
userIDs = np.arange(len(dataFrame))
|
userIDs = np.arange(len(dataFrame))
|
||||||
for blockID in np.arange(numBlocks):
|
for blockID in np.arange(numBlocks):
|
||||||
curIDs = userIDs[(blockID * windowSize):((blockID + 1) * windowSize)]
|
curIDs = userIDs[(blockID * windowSize):((blockID + 1) * windowSize)]
|
||||||
# print(curIDs)
|
# logger.info(curIDs)
|
||||||
useData = dataFrame.iloc[curIDs]
|
useData = dataFrame.iloc[curIDs]
|
||||||
curDomains = useData['domain']
|
curDomains = useData['domain']
|
||||||
if maxLengthInSeconds != -1:
|
if maxLengthInSeconds != -1:
|
||||||
@ -88,7 +91,7 @@ def get_all_flow_features(features):
|
|||||||
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
||||||
domains = []
|
domains = []
|
||||||
features = []
|
features = []
|
||||||
print("get chunks from user data frames")
|
logger.info("get chunks from user data frames")
|
||||||
for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))):
|
for i, user_flow in tqdm(list(enumerate(get_flow_per_user(user_flow_df)))):
|
||||||
(domain_windows, feature_windows) = get_user_chunks(user_flow,
|
(domain_windows, feature_windows) = get_user_chunks(user_flow,
|
||||||
windowSize=window_size,
|
windowSize=window_size,
|
||||||
@ -97,7 +100,7 @@ def create_dataset_from_flows(user_flow_df, char_dict, max_len, window_size=10):
|
|||||||
domains += domain_windows
|
domains += domain_windows
|
||||||
features += feature_windows
|
features += feature_windows
|
||||||
|
|
||||||
print("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
|
domain_tr, flow_tr, hits_tr, _, server_tr, trusted_hits_tr = create_dataset_from_lists(domains=domains,
|
||||||
flows=features,
|
flows=features,
|
||||||
vocab=char_dict,
|
vocab=char_dict,
|
||||||
@ -150,13 +153,8 @@ def create_dataset_from_lists(domains, flows, vocab, max_len, window_size=10):
|
|||||||
:param window_size: size of the flow window
|
:param window_size: size of the flow window
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# sample_size = len(domains)
|
|
||||||
|
|
||||||
# domain_features = np.zeros((sample_size, window_size, max_len))
|
|
||||||
flow_features = get_all_flow_features(flows)
|
|
||||||
|
|
||||||
domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])
|
domain_features = np.array([[get_domain_features(d, vocab, max_len) for d in x] for x in domains])
|
||||||
|
flow_features = get_all_flow_features(flows)
|
||||||
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
|
hits = np.max(np.stack(map(lambda f: f.virusTotalHits, flows)), axis=1)
|
||||||
names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
|
names = np.unique(np.stack(map(lambda f: f.user_hash, flows)), axis=1)
|
||||||
servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
|
servers = np.max(np.stack(map(lambda f: f.serverLabel, flows)), axis=1)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# implementation of hyperband:
|
# implementation of hyperband:
|
||||||
# https://arxiv.org/pdf/1603.06560.pdf
|
# https://arxiv.org/pdf/1603.06560.pdf
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
from math import log, ceil
|
from math import log, ceil
|
||||||
from random import random as rng
|
from random import random as rng
|
||||||
@ -10,6 +11,8 @@ import numpy as np
|
|||||||
|
|
||||||
import models
|
import models
|
||||||
|
|
||||||
|
logger = logging.getLogger('logger')
|
||||||
|
|
||||||
|
|
||||||
def sample_params(param_distribution: dict):
|
def sample_params(param_distribution: dict):
|
||||||
p = {}
|
p = {}
|
||||||
@ -75,7 +78,7 @@ class Hyperband:
|
|||||||
n_configs = n * self.eta ** (-i)
|
n_configs = n * self.eta ** (-i)
|
||||||
n_iterations = r * self.eta ** (i)
|
n_iterations = r * self.eta ** (i)
|
||||||
|
|
||||||
print("\n*** {} configurations x {:.1f} iterations each".format(
|
logger.info("\n*** {} configurations x {:.1f} iterations each".format(
|
||||||
n_configs, n_iterations))
|
n_configs, n_iterations))
|
||||||
|
|
||||||
val_losses = []
|
val_losses = []
|
||||||
@ -84,7 +87,7 @@ class Hyperband:
|
|||||||
for t in T:
|
for t in T:
|
||||||
|
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
print("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
|
logger.info("\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
|
||||||
self.counter, ctime(), self.best_loss, self.best_counter))
|
self.counter, ctime(), self.best_loss, self.best_counter))
|
||||||
|
|
||||||
start_time = time()
|
start_time = time()
|
||||||
@ -98,7 +101,7 @@ class Hyperband:
|
|||||||
assert ('loss' in result)
|
assert ('loss' in result)
|
||||||
|
|
||||||
seconds = int(round(time() - start_time))
|
seconds = int(round(time() - start_time))
|
||||||
print("\n{} seconds.".format(seconds))
|
logger.info("\n{} seconds.".format(seconds))
|
||||||
|
|
||||||
loss = result['loss']
|
loss = result['loss']
|
||||||
val_losses.append(loss)
|
val_losses.append(loss)
|
||||||
|
138
main.py
138
main.py
@ -1,85 +1,46 @@
|
|||||||
import argparse
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
|
||||||
from keras.models import load_model
|
from keras.models import load_model
|
||||||
|
|
||||||
|
import arguments
|
||||||
import dataset
|
import dataset
|
||||||
import hyperband
|
import hyperband
|
||||||
import models
|
import models
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
# create logger
|
||||||
|
logger = logging.getLogger('logger')
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
parser.add_argument("--modes", action="store", dest="modes", nargs="+",
|
# create console handler and set level to debug
|
||||||
default=[])
|
ch = logging.StreamHandler()
|
||||||
|
ch.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
parser.add_argument("--train", action="store", dest="train_data",
|
# create formatter
|
||||||
default="data/full_dataset.csv.tar.bz2")
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
parser.add_argument("--test", action="store", dest="test_data",
|
# add formatter to ch
|
||||||
default="data/full_future_dataset.csv.tar.bz2")
|
ch.setFormatter(formatter)
|
||||||
|
|
||||||
# parser.add_argument("--h5data", action="store", dest="h5data",
|
# add ch to logger
|
||||||
# default="")
|
logger.addHandler(ch)
|
||||||
#
|
|
||||||
parser.add_argument("--models", action="store", dest="model_path",
|
|
||||||
default="models/models_x")
|
|
||||||
|
|
||||||
# parser.add_argument("--pred", action="store", dest="pred",
|
ch = logging.FileHandler("info.log")
|
||||||
# default="")
|
ch.setLevel(logging.DEBUG)
|
||||||
#
|
|
||||||
parser.add_argument("--type", action="store", dest="model_type",
|
|
||||||
default="paul")
|
|
||||||
|
|
||||||
parser.add_argument("--batch", action="store", dest="batch_size",
|
# create formatter
|
||||||
default=64, type=int)
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
parser.add_argument("--epochs", action="store", dest="epochs",
|
# add formatter to ch
|
||||||
default=10, type=int)
|
ch.setFormatter(formatter)
|
||||||
|
|
||||||
# parser.add_argument("--samples", action="store", dest="samples",
|
# add ch to logger
|
||||||
# default=100000, type=int)
|
logger.addHandler(ch)
|
||||||
#
|
|
||||||
# parser.add_argument("--samples_val", action="store", dest="samples_val",
|
|
||||||
# default=10000, type=int)
|
|
||||||
#
|
|
||||||
parser.add_argument("--embd", action="store", dest="embedding",
|
|
||||||
default=128, type=int)
|
|
||||||
|
|
||||||
parser.add_argument("--hidden_char_dims", action="store", dest="hidden_char_dims",
|
args = arguments.parse()
|
||||||
default=256, type=int)
|
|
||||||
|
|
||||||
parser.add_argument("--window", action="store", dest="window",
|
|
||||||
default=10, type=int)
|
|
||||||
|
|
||||||
parser.add_argument("--domain_length", action="store", dest="domain_length",
|
|
||||||
default=40, type=int)
|
|
||||||
|
|
||||||
parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
|
||||||
default=512, type=int)
|
|
||||||
|
|
||||||
# parser.add_argument("--queue", action="store", dest="queue_size",
|
|
||||||
# default=50, type=int)
|
|
||||||
#
|
|
||||||
# parser.add_argument("--p", action="store", dest="p_train",
|
|
||||||
# default=0.5, type=float)
|
|
||||||
#
|
|
||||||
# parser.add_argument("--p_val", action="store", dest="p_val",
|
|
||||||
# default=0.01, type=float)
|
|
||||||
#
|
|
||||||
# parser.add_argument("--gpu", action="store", dest="gpu",
|
|
||||||
# default=0, type=int)
|
|
||||||
#
|
|
||||||
# parser.add_argument("--tmp", action="store_true", dest="tmp")
|
|
||||||
#
|
|
||||||
# parser.add_argument("--test", action="store_true", dest="test")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
|
||||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
|
||||||
args.train_log = os.path.join(args.model_path, "train.log")
|
|
||||||
args.h5data = args.train_data + ".h5"
|
|
||||||
|
|
||||||
|
|
||||||
# config = tf.ConfigProto(log_device_placement=True)
|
# config = tf.ConfigProto(log_device_placement=True)
|
||||||
@ -125,7 +86,7 @@ def main_hyperband():
|
|||||||
params = {
|
params = {
|
||||||
# static params
|
# static params
|
||||||
"type": ["paul"],
|
"type": ["paul"],
|
||||||
"batch_size": [64],
|
"batch_size": [args.batch_size],
|
||||||
"vocab_size": [len(char_dict) + 1],
|
"vocab_size": [len(char_dict) + 1],
|
||||||
"window_size": [10],
|
"window_size": [10],
|
||||||
"domain_length": [40],
|
"domain_length": [40],
|
||||||
@ -143,32 +104,35 @@ def main_hyperband():
|
|||||||
"dense_main": [16, 32, 64, 128, 256, 512],
|
"dense_main": [16, 32, 64, 128, 256, 512],
|
||||||
}
|
}
|
||||||
param = hyperband.sample_params(params)
|
param = hyperband.sample_params(params)
|
||||||
print(param)
|
logger.info(param)
|
||||||
|
|
||||||
print("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
||||||
max_len=args.domain_length,
|
max_len=args.domain_length,
|
||||||
window_size=args.window)
|
window_size=args.window)
|
||||||
|
|
||||||
hp = hyperband.Hyperband(params, [domain_tr, flow_tr], [client_tr, server_tr])
|
hp = hyperband.Hyperband(params,
|
||||||
hp.run()
|
[domain_tr, flow_tr],
|
||||||
|
[client_tr, server_tr])
|
||||||
|
results = hp.run()
|
||||||
|
json.dump(results, open("hyperband.json"))
|
||||||
|
|
||||||
|
|
||||||
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
print("check for h5data", h5data)
|
logger.info(f"check for h5data {h5data}")
|
||||||
try:
|
try:
|
||||||
open(h5data, "r")
|
open(h5data, "r")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("h5 data not found - load csv file")
|
logger.info("h5 data not found - load csv file")
|
||||||
user_flow_df = dataset.get_user_flow_data(train_data)
|
user_flow_df = dataset.get_user_flow_data(train_data)
|
||||||
print("create training dataset")
|
logger.info("create training dataset")
|
||||||
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
domain_tr, flow_tr, client_tr, server_tr = dataset.create_dataset_from_flows(user_flow_df, char_dict,
|
||||||
max_len=domain_length,
|
max_len=domain_length,
|
||||||
window_size=window_size)
|
window_size=window_size)
|
||||||
print("store training dataset as h5 file")
|
logger.info("store training dataset as h5 file")
|
||||||
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
dataset.store_h5dataset(args.h5data, domain_tr, flow_tr, client_tr, server_tr)
|
||||||
print("load h5 dataset")
|
logger.info("load h5 dataset")
|
||||||
return dataset.load_h5dataset(h5data)
|
return dataset.load_h5dataset(h5data)
|
||||||
|
|
||||||
|
|
||||||
@ -204,7 +168,7 @@ def main_train():
|
|||||||
embedding, model = models.get_models_by_params(param)
|
embedding, model = models.get_models_by_params(param)
|
||||||
embedding.summary()
|
embedding.summary()
|
||||||
model.summary()
|
model.summary()
|
||||||
print("define callbacks")
|
logger.info("define callbacks")
|
||||||
cp = ModelCheckpoint(filepath=args.clf_model,
|
cp = ModelCheckpoint(filepath=args.clf_model,
|
||||||
monitor='val_loss',
|
monitor='val_loss',
|
||||||
verbose=False,
|
verbose=False,
|
||||||
@ -213,11 +177,11 @@ def main_train():
|
|||||||
early = EarlyStopping(monitor='val_loss',
|
early = EarlyStopping(monitor='val_loss',
|
||||||
patience=5,
|
patience=5,
|
||||||
verbose=False)
|
verbose=False)
|
||||||
print("compile model")
|
logger.info("compile model")
|
||||||
model.compile(optimizer='adam',
|
model.compile(optimizer='adam',
|
||||||
loss='categorical_crossentropy',
|
loss='categorical_crossentropy',
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
print("start training")
|
logger.info("start training")
|
||||||
model.fit([domain_tr, flow_tr],
|
model.fit([domain_tr, flow_tr],
|
||||||
[client_tr, server_tr],
|
[client_tr, server_tr],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
@ -225,40 +189,40 @@ def main_train():
|
|||||||
callbacks=[cp, csv, early],
|
callbacks=[cp, csv, early],
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
validation_split=0.2)
|
validation_split=0.2)
|
||||||
print("save embedding")
|
logger.info("save embedding")
|
||||||
embedding.save(args.embedding_model)
|
embedding.save(args.embedding_model)
|
||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.h5data, args.train_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
# embedding = load_model(args.embedding_model)
|
|
||||||
clf = load_model(args.clf_model)
|
clf = load_model(args.clf_model)
|
||||||
|
|
||||||
loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
|
loss, _, _, client_acc, server_acc = clf.evaluate([domain_val, flow_val],
|
||||||
[client_val, server_val],
|
[client_val, server_val],
|
||||||
batch_size=args.batch_size)
|
batch_size=args.batch_size)
|
||||||
|
logger.info(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
||||||
print(f"loss: {loss}\nclient acc: {client_acc}\nserver acc: {server_acc}")
|
y_pred = clf.predict([domain_val, flow_val],
|
||||||
|
batch_size=args.batch_size)
|
||||||
|
np.save(os.path.join(args.model_path, "future_predict.npy"), y_pred)
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
mask = dataset.load_mask_eval(args.data, args.test_image)
|
mask = dataset.load_mask_eval(args.data, args.test_image)
|
||||||
y_pred_path = args.model_path + "pred.npy"
|
y_pred_path = args.model_path + "pred.npy"
|
||||||
print("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.model_path + "model.h5",
|
model = load_model(args.model_path + "model.h5",
|
||||||
custom_objects=evaluation.get_metrics())
|
custom_objects=evaluation.get_metrics())
|
||||||
visualize.plot_model(model, args.model_path + "model.png")
|
visualize.plot_model(model, args.model_path + "model.png")
|
||||||
print("plot training curve")
|
logger.info("plot training curve")
|
||||||
logs = pd.read_csv(args.model_path + "train.log")
|
logs = pd.read_csv(args.model_path + "train.log")
|
||||||
visualize.plot_training_curve(logs, "{}/train.png".format(args.model_path))
|
visualize.plot_training_curve(logs, "{}/train.png".format(args.model_path))
|
||||||
pred = np.load(y_pred_path)
|
pred = np.load(y_pred_path)
|
||||||
print("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_precision_recall(mask, pred, "{}/prc.png".format(args.model_path))
|
visualize.plot_precision_recall(mask, pred, "{}/prc.png".format(args.model_path))
|
||||||
visualize.plot_precision_recall_curves(mask, pred, "{}/prc2.png".format(args.model_path))
|
visualize.plot_precision_recall_curves(mask, pred, "{}/prc2.png".format(args.model_path))
|
||||||
print("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_roc_curve(mask, pred, "{}/roc.png".format(args.model_path))
|
visualize.plot_roc_curve(mask, pred, "{}/roc.png".format(args.model_path))
|
||||||
print("store prediction image")
|
logger.info("store prediction image")
|
||||||
visualize.save_image_as(pred, "{}/pred.png".format(args.model_path))
|
visualize.save_image_as(pred, "{}/pred.png".format(args.model_path))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user