refactor visualization, change arguments for model type and its depth
This commit is contained in:
parent
933eaae04a
commit
dc9180da10
31
Makefile
31
Makefile
@ -1,16 +1,33 @@
|
|||||||
run:
|
run:
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test1 --epochs 10 --depth small \
|
||||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights
|
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||||
|
|
||||||
run_new:
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 --depth small \
|
||||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
|
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model
|
|
||||||
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test3 --epochs 10 --depth medium \
|
||||||
|
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type final
|
||||||
|
|
||||||
|
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test4 --epochs 10 --depth medium \
|
||||||
|
--hidden_char_dims 16 --domain_embd 8 --batch 64 --balanced_weights --type inter
|
||||||
|
|
||||||
test:
|
test:
|
||||||
python3 main.py --mode test --batch 128 --model results/test --test data/rk_mini.csv.gz
|
python3 main.py --mode test --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
fancy:
|
fancy:
|
||||||
python3 main.py --mode fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
|
python3 main.py --mode fancy --batch 128 --model results/test1 --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
|
python3 main.py --mode fancy --batch 128 --model results/test2 --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
|
python3 main.py --mode fancy --batch 128 --model results/test3 --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
|
python3 main.py --mode fancy --batch 128 --model results/test4 --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
|
all-fancy:
|
||||||
|
python3 main.py --mode all_fancy --batch 128 --models results/test* --test data/rk_mini.csv.gz
|
||||||
|
|
||||||
hyper:
|
hyper:
|
||||||
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
|
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -r results/test*
|
19
arguments.py
19
arguments.py
@ -19,8 +19,14 @@ parser.add_argument("--test", action="store", dest="test_data",
|
|||||||
parser.add_argument("--model", action="store", dest="model_path",
|
parser.add_argument("--model", action="store", dest="model_path",
|
||||||
default="results/model_x")
|
default="results/model_x")
|
||||||
|
|
||||||
|
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
|
||||||
|
default=[])
|
||||||
|
|
||||||
parser.add_argument("--type", action="store", dest="model_type",
|
parser.add_argument("--type", action="store", dest="model_type",
|
||||||
default="paul")
|
default="final") # inter, final, staggered
|
||||||
|
|
||||||
|
parser.add_argument("--depth", action="store", dest="model_depth",
|
||||||
|
default="small") # small, medium
|
||||||
|
|
||||||
parser.add_argument("--model_output", action="store", dest="model_output",
|
parser.add_argument("--model_output", action="store", dest="model_output",
|
||||||
default="both")
|
default="both")
|
||||||
@ -74,6 +80,17 @@ parser.add_argument("--new_model", action="store_true", dest="new_model")
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_args(args):
|
||||||
|
return [{
|
||||||
|
"model_path": model_path,
|
||||||
|
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||||
|
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||||
|
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||||
|
"train_h5data": args.train_data + ".h5",
|
||||||
|
"test_h5data": args.test_data + ".h5",
|
||||||
|
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred.h5")
|
||||||
|
} for model_path in args.model_paths]
|
||||||
|
|
||||||
def parse():
|
def parse():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||||
|
@ -233,11 +233,11 @@ def load_or_generate_h5data(h5data, train_data, domain_length, window_size):
|
|||||||
|
|
||||||
def load_or_generate_domains(train_data, domain_length):
|
def load_or_generate_domains(train_data, domain_length):
|
||||||
fn = f"{train_data}_domains.gz"
|
fn = f"{train_data}_domains.gz"
|
||||||
|
char_dict = get_character_dict()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_flow_df = pd.read_csv(fn)
|
user_flow_df = pd.read_csv(fn)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
char_dict = get_character_dict()
|
|
||||||
user_flow_df = get_user_flow_data(train_data)
|
user_flow_df = get_user_flow_data(train_data)
|
||||||
# user_flow_df.reset_index(inplace=True)
|
# user_flow_df.reset_index(inplace=True)
|
||||||
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
user_flow_df = user_flow_df[["domain", "serverLabel", "trustedHits", "virusTotalHits"]].dropna(axis=0,
|
||||||
|
67
main.py
67
main.py
@ -16,6 +16,7 @@ import models
|
|||||||
import visualize
|
import visualize
|
||||||
from dataset import load_or_generate_h5data
|
from dataset import load_or_generate_h5data
|
||||||
from utils import exists_or_make_path, get_custom_class_weights
|
from utils import exists_or_make_path, get_custom_class_weights
|
||||||
|
from arguments import get_model_args
|
||||||
|
|
||||||
logger = logging.getLogger('logger')
|
logger = logging.getLogger('logger')
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@ -56,6 +57,7 @@ if args.gpu:
|
|||||||
# default parameter
|
# default parameter
|
||||||
PARAMS = {
|
PARAMS = {
|
||||||
"type": args.model_type,
|
"type": args.model_type,
|
||||||
|
"depth": args.model_depth,
|
||||||
"batch_size": 64,
|
"batch_size": 64,
|
||||||
"window_size": args.window,
|
"window_size": args.window,
|
||||||
"domain_length": args.domain_length,
|
"domain_length": args.domain_length,
|
||||||
@ -149,8 +151,8 @@ def main_train(param=None):
|
|||||||
logger.info("class weights: set default")
|
logger.info("class weights: set default")
|
||||||
custom_class_weights = None
|
custom_class_weights = None
|
||||||
|
|
||||||
logger.info(f"select model: {'new' if args.new_model else 'old'}")
|
logger.info(f"select model: {args.model_type}")
|
||||||
if args.new_model:
|
if args.model_type == "inter":
|
||||||
server_tr = np.expand_dims(server_windows_tr, 2)
|
server_tr = np.expand_dims(server_windows_tr, 2)
|
||||||
model = new_model
|
model = new_model
|
||||||
logger.info("compile and train model")
|
logger.info("compile and train model")
|
||||||
@ -181,12 +183,19 @@ def main_train(param=None):
|
|||||||
|
|
||||||
|
|
||||||
def main_test():
|
def main_test():
|
||||||
|
logger.info("start test: load data")
|
||||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
clf = load_model(args.clf_model, custom_objects=models.get_metrics())
|
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||||
pred = clf.predict([domain_val, flow_val],
|
|
||||||
|
for model_args in get_model_args(args):
|
||||||
|
logger.info(f"process model {model_args['model_path']}")
|
||||||
|
clf_model = load_model(model_args["clf_model"], custom_objects=models.get_metrics())
|
||||||
|
|
||||||
|
pred = clf_model.predict([domain_val, flow_val],
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
verbose=1)
|
verbose=1)
|
||||||
|
|
||||||
if args.model_output == "both":
|
if args.model_output == "both":
|
||||||
c_pred, s_pred = pred
|
c_pred, s_pred = pred
|
||||||
elif args.model_output == "client":
|
elif args.model_output == "client":
|
||||||
@ -195,21 +204,23 @@ def main_test():
|
|||||||
else:
|
else:
|
||||||
c_pred = np.zeros(0)
|
c_pred = np.zeros(0)
|
||||||
s_pred = pred
|
s_pred = pred
|
||||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
|
dataset.save_predictions(model_args["future_prediction"], c_pred, s_pred)
|
||||||
|
|
||||||
model = load_model(args.embedding_model)
|
embd_model = load_model(model_args["embedding_model"])
|
||||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
domain_embeddings = embd_model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
np.save(model_args["model_path"] + "/domain_embds.npy", domain_embeddings)
|
||||||
np.save(args.model_path + "/domain_embds.npy", domain_embedding)
|
|
||||||
|
|
||||||
|
|
||||||
def main_visualization():
|
def main_visualization():
|
||||||
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
args.domain_length, args.window)
|
args.domain_length, args.window)
|
||||||
client_val, server_val = client_val.value, server_val.value
|
# client_val, server_val = client_val.value, server_val.value
|
||||||
logger.info("plot model")
|
client_val = client_val.value
|
||||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
|
||||||
visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
# logger.info("plot model")
|
||||||
|
# model = load_model(model_args.clf_model, custom_objects=models.get_metrics())
|
||||||
|
# visualize.plot_model(model, os.path.join(args.model_path, "model.png"))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("plot training curve")
|
logger.info("plot training curve")
|
||||||
logs = pd.read_csv(args.train_log)
|
logs = pd.read_csv(args.train_log)
|
||||||
@ -224,12 +235,16 @@ def main_visualization():
|
|||||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||||
client_pred, server_pred = client_pred.value, server_pred.value
|
client_pred, server_pred = client_pred.value, server_pred.value
|
||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
visualize.plot_clf()
|
||||||
|
visualize.plot_precision_recall(client_val, client_pred)
|
||||||
|
visualize.plot_save("{}/client_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
# visualize.plot_precision_recall(server_val, server_pred, "{}/server_prc.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
# visualize.plot_precision_recall_curves(client_val, client_pred, "{}/client_prc2.png".format(args.model_path))
|
||||||
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
# visualize.plot_precision_recall_curves(server_val, server_pred, "{}/server_prc2.png".format(args.model_path))
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_roc_curve(client_val, client_pred.flatten(), "{}/client_roc.png".format(args.model_path))
|
visualize.plot_clf()
|
||||||
|
visualize.plot_roc_curve(client_val, client_pred)
|
||||||
|
visualize.plot_save("{}/client_roc.png".format(args.model_path))
|
||||||
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
# visualize.plot_roc_curve(server_val, server_pred, "{}/server_roc.png".format(args.model_path))
|
||||||
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
visualize.plot_confusion_matrix(client_val, client_pred.flatten().round(),
|
||||||
"{}/client_cov.png".format(args.model_path),
|
"{}/client_cov.png".format(args.model_path),
|
||||||
@ -243,6 +258,26 @@ def main_visualization():
|
|||||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||||
|
|
||||||
|
|
||||||
|
def main_visualize_all():
|
||||||
|
domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data,
|
||||||
|
args.domain_length, args.window)
|
||||||
|
logger.info("plot pr curves")
|
||||||
|
visualize.plot_clf()
|
||||||
|
for model_args in get_model_args(args):
|
||||||
|
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||||
|
visualize.plot_precision_recall(client_val.value, client_pred.value, model_args["model_path"])
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save("all_client_prc.png")
|
||||||
|
|
||||||
|
logger.info("plot roc curves")
|
||||||
|
visualize.plot_clf()
|
||||||
|
for model_args in get_model_args(args):
|
||||||
|
client_pred, server_pred = dataset.load_predictions(model_args["future_prediction"])
|
||||||
|
visualize.plot_roc_curve(client_val.value, client_pred.value, model_args["model_path"])
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save("all_client_roc.png")
|
||||||
|
|
||||||
|
|
||||||
def main_data():
|
def main_data():
|
||||||
char_dict = dataset.get_character_dict()
|
char_dict = dataset.get_character_dict()
|
||||||
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
user_flow_df = dataset.get_user_flow_data(args.train_data)
|
||||||
@ -265,6 +300,8 @@ def main():
|
|||||||
main_test()
|
main_test()
|
||||||
if "fancy" == args.mode:
|
if "fancy" == args.mode:
|
||||||
main_visualization()
|
main_visualization()
|
||||||
|
if "all_fancy" == args.mode:
|
||||||
|
main_visualize_all()
|
||||||
if "paul" == args.mode:
|
if "paul" == args.mode:
|
||||||
main_paul_best()
|
main_paul_best()
|
||||||
if "data" == args.mode:
|
if "data" == args.mode:
|
||||||
|
@ -8,6 +8,7 @@ def get_models_by_params(params: dict):
|
|||||||
# decomposing param section
|
# decomposing param section
|
||||||
# mainly embedding model
|
# mainly embedding model
|
||||||
network_type = params.get("type")
|
network_type = params.get("type")
|
||||||
|
network_depth = params.get("depth")
|
||||||
embedding_size = params.get("embedding_size")
|
embedding_size = params.get("embedding_size")
|
||||||
input_length = params.get("input_length")
|
input_length = params.get("input_length")
|
||||||
filter_embedding = params.get("filter_embedding")
|
filter_embedding = params.get("filter_embedding")
|
||||||
@ -24,7 +25,12 @@ def get_models_by_params(params: dict):
|
|||||||
dense_dim = params.get("dense_main")
|
dense_dim = params.get("dense_main")
|
||||||
model_output = params.get("model_output", "both")
|
model_output = params.get("model_output", "both")
|
||||||
# create models
|
# create models
|
||||||
networks = renes_networks if network_type == "rene" else pauls_networks
|
if network_depth == "small":
|
||||||
|
networks = pauls_networks
|
||||||
|
elif network_depth == "medium":
|
||||||
|
networks = renes_networks
|
||||||
|
else:
|
||||||
|
raise Exception("network not found")
|
||||||
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
|
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
|
||||||
hidden_embedding, dropout)
|
hidden_embedding, dropout)
|
||||||
|
|
||||||
|
106
run.sh
106
run.sh
@ -1,98 +1,26 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
python main.py --mode train \
|
|
||||||
|
for output in client both
|
||||||
|
do
|
||||||
|
for depth in small medium
|
||||||
|
do
|
||||||
|
for mtype in inter final staggered
|
||||||
|
do
|
||||||
|
|
||||||
|
python main.py --mode train \
|
||||||
--train /tmp/rk/currentData.csv \
|
--train /tmp/rk/currentData.csv \
|
||||||
--model /tmp/rk/results/small_both \
|
--model /tmp/rk/results/${output}_${depth}_${mtype} \
|
||||||
--epochs 25 \
|
--epochs 50 \
|
||||||
--embd 64 \
|
--embd 64 \
|
||||||
--hidden_chaar_dims 128 \
|
--hidden_chaar_dims 128 \
|
||||||
--domain_embd 32 \
|
--domain_embd 32 \
|
||||||
--batch 256 \
|
--batch 256 \
|
||||||
--balanced_weights \
|
--balanced_weights \
|
||||||
--model_output both
|
--model_output ${output} \
|
||||||
|
--type ${mtype} \
|
||||||
|
--depth ${depth}
|
||||||
|
|
||||||
python main.py --mode train \
|
done
|
||||||
--train /tmp/rk/currentData.csv \
|
done
|
||||||
--model /tmp/rk/results/small_client \
|
done
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output client
|
|
||||||
|
|
||||||
python main.py --mode train \
|
|
||||||
--train /tmp/rk/currentData.csv \
|
|
||||||
--model /tmp/rk/results/small_new_both \
|
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output both \
|
|
||||||
--new_model
|
|
||||||
|
|
||||||
python main.py --mode train \
|
|
||||||
--train /tmp/rk/currentData.csv \
|
|
||||||
--model /tmp/rk/results/small_new_client \
|
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output client \
|
|
||||||
--new_model
|
|
||||||
##
|
|
||||||
##
|
|
||||||
python main.py --mode train \
|
|
||||||
--train /tmp/rk/currentData.csv \
|
|
||||||
--model /tmp/rk/results/medium_both \
|
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output both \
|
|
||||||
--type rene
|
|
||||||
|
|
||||||
python main.py --mode train \
|
|
||||||
--train /tmp/rk/currentData.csv \
|
|
||||||
--model /tmp/rk/results/medium_client \
|
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output client \
|
|
||||||
--type rene
|
|
||||||
|
|
||||||
python main.py --mode train \
|
|
||||||
--train /tmp/rk/currentData.csv \
|
|
||||||
--model /tmp/rk/results/medium_new_both \
|
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output both \
|
|
||||||
--new_model \
|
|
||||||
--type rene
|
|
||||||
|
|
||||||
python main.py --mode train \
|
|
||||||
--train /tmp/rk/currentData.csv \
|
|
||||||
--model /tmp/rk/results/medium_new_client \
|
|
||||||
--epochs 25 \
|
|
||||||
--embd 64 \
|
|
||||||
--hidden_char_dims 128 \
|
|
||||||
--domain_embd 32 \
|
|
||||||
--batch 256 \
|
|
||||||
--balanced_weights \
|
|
||||||
--model_output client \
|
|
||||||
--new_model \
|
|
||||||
--type rene
|
|
||||||
|
45
visualize.py
45
visualize.py
@ -2,7 +2,6 @@ import os
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from keras.utils.vis_utils import plot_model
|
|
||||||
from sklearn.decomposition import TruncatedSVD
|
from sklearn.decomposition import TruncatedSVD
|
||||||
from sklearn.metrics import (
|
from sklearn.metrics import (
|
||||||
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
||||||
@ -32,38 +31,43 @@ def scores(y_true, y_pred):
|
|||||||
print(" f0.5 score:", f05_score)
|
print(" f0.5 score:", f05_score)
|
||||||
|
|
||||||
|
|
||||||
def plot_precision_recall(mask, prediction, path):
|
def plot_clf():
|
||||||
y = mask.flatten()
|
plt.clf()
|
||||||
y_pred = prediction.flatten()
|
|
||||||
|
|
||||||
|
def plot_save(path, dpi=600):
|
||||||
|
plt.savefig(path, dpi=dpi)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_legend():
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_precision_recall(y, y_pred, label=""):
|
||||||
|
y = y.flatten()
|
||||||
|
y_pred = y_pred.flatten()
|
||||||
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||||
decreasing_max_precision = np.maximum.accumulate(precision)[::-1]
|
decreasing_max_precision = np.maximum.accumulate(precision)[::-1]
|
||||||
|
|
||||||
plt.clf()
|
|
||||||
# fig, ax = plt.subplots(1, 1)
|
# fig, ax = plt.subplots(1, 1)
|
||||||
# ax.hold(True)
|
# ax.hold(True)
|
||||||
plt.plot(recall, precision, '--b')
|
plt.plot(recall, precision, '--', label=label)
|
||||||
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
plt.ylabel('Precision')
|
plt.ylabel('Precision')
|
||||||
|
|
||||||
plt.savefig(path, dpi=600)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
def plot_precision_recall_curves(y, y_pred):
|
||||||
def plot_precision_recall_curves(mask, prediction, path):
|
y = y.flatten()
|
||||||
y = mask.flatten()
|
y_pred = y_pred.flatten()
|
||||||
y_pred = prediction.flatten()
|
|
||||||
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
precision, recall, thresholds = precision_recall_curve(y, y_pred)
|
||||||
|
|
||||||
plt.clf()
|
|
||||||
plt.plot(recall, label="Recall")
|
plt.plot(recall, label="Recall")
|
||||||
plt.plot(precision, label="Precision")
|
plt.plot(precision, label="Precision")
|
||||||
plt.xlabel('Threshold')
|
plt.xlabel('Threshold')
|
||||||
plt.ylabel('Score')
|
plt.ylabel('Score')
|
||||||
|
|
||||||
plt.savefig(path, dpi=600)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
|
||||||
def score_model(y, prediction):
|
def score_model(y, prediction):
|
||||||
y = y.flatten()
|
y = y.flatten()
|
||||||
@ -78,16 +82,12 @@ def score_model(y, prediction):
|
|||||||
print("F0.5 Score", fbeta_score(y, y_pred.round(), 0.5))
|
print("F0.5 Score", fbeta_score(y, y_pred.round(), 0.5))
|
||||||
|
|
||||||
|
|
||||||
def plot_roc_curve(mask, prediction, path):
|
def plot_roc_curve(mask, prediction, label=""):
|
||||||
y = mask.flatten()
|
y = mask.flatten()
|
||||||
y_pred = prediction.flatten()
|
y_pred = prediction.flatten()
|
||||||
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
||||||
roc_auc = auc(fpr, tpr)
|
roc_auc = auc(fpr, tpr)
|
||||||
plt.clf()
|
plt.plot(fpr, tpr, label=label)
|
||||||
plt.plot(fpr, tpr)
|
|
||||||
plt.savefig(path, dpi=600)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
print("roc_auc", roc_auc)
|
print("roc_auc", roc_auc)
|
||||||
|
|
||||||
|
|
||||||
@ -161,4 +161,5 @@ def plot_embedding(domain_embedding, labels, path, dpi=600):
|
|||||||
|
|
||||||
|
|
||||||
def plot_model_as(model, path):
|
def plot_model_as(model, path):
|
||||||
|
from keras.utils.vis_utils import plot_model
|
||||||
plot_model(model, to_file=path, show_shapes=True, show_layer_names=True)
|
plot_model(model, to_file=path, show_shapes=True, show_layer_names=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user