add regularization to small networks, fix model name in args, fix visualizations
This commit is contained in:
parent
6fef2b8b84
commit
1cf62423e1
@ -102,6 +102,7 @@ parser.add_argument("--new_model", action="store_true", dest="new_model")
|
|||||||
def get_model_args(args):
|
def get_model_args(args):
|
||||||
return [{
|
return [{
|
||||||
"model_path": model_path,
|
"model_path": model_path,
|
||||||
|
"model_name": os.path.split(os.path.normpath(model_path))[1],
|
||||||
"embedding_model": os.path.join(model_path, "embd.h5"),
|
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||||
"clf_model": os.path.join(model_path, "clf.h5"),
|
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||||
"train_log": os.path.join(model_path, "train.log.csv"),
|
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||||
|
17
fancy.sh
17
fancy.sh
@ -5,19 +5,20 @@ DATADIR=$2
|
|||||||
|
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_final --test ${DATADIR}/futureData.csv --model_output both
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_final --test ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_inter --test ${DATADIR}/futureData.csv --model_output both
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_inter --test ${DATADIR}/futureData.csv --model_output both
|
||||||
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_small_staggered --test ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_small_final --test ${DATADIR}/futureData.csv --model_output client
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_small_final --test ${DATADIR}/futureData.csv --model_output client
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_small_inter --test ${DATADIR}/futureData.csv --model_output client
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_small_inter --test ${DATADIR}/futureData.csv --model_output client
|
||||||
|
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --test ${DATADIR}/futureData.csv --model_output both
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --test ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --test ${DATADIR}/futureData.csv --model_output both
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --test ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --test ${DATADIR}/futureData.csv --model_output client
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --test ${DATADIR}/futureData.csv --model_output client
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --test ${DATADIR}/futureData.csv --model_output client
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --test ${DATADIR}/futureData.csv --model_output client
|
||||||
|
|
||||||
python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
||||||
--models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small
|
# --models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small
|
||||||
|
|
||||||
python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
||||||
--models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium
|
# --models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium
|
||||||
|
|
||||||
python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
||||||
--models ${RESDIR}/*/ --out-prefix ${RESDIR}/all
|
--models ${RESDIR}/*/ --out-prefix ${RESDIR}/all
|
32
main.py
32
main.py
@ -289,6 +289,14 @@ def main_visualization():
|
|||||||
df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3)
|
df["client_val"] = np.logical_or(df.hits_vt == 1.0, df.hits_trusted >= 3)
|
||||||
df_user = df.groupby(df.names).max()
|
df_user = df.groupby(df.names).max()
|
||||||
|
|
||||||
|
paul = dataset.load_predictions("results/paul/")
|
||||||
|
df_paul = pd.DataFrame(data={
|
||||||
|
"names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
|
||||||
|
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
|
||||||
|
})
|
||||||
|
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
|
||||||
|
df_paul_user = df_paul.groupby(df_paul.names).max()
|
||||||
|
|
||||||
logger.info("plot model")
|
logger.info("plot model")
|
||||||
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
model = load_model(args.clf_model, custom_objects=models.get_metrics())
|
||||||
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
visualize.plot_model_as(model, os.path.join(args.model_path, "model.png"))
|
||||||
@ -306,22 +314,26 @@ def main_visualization():
|
|||||||
logger.info("plot pr curve")
|
logger.info("plot pr curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
|
||||||
|
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_prc.png".format(args.model_path))
|
||||||
|
|
||||||
logger.info("plot roc curve")
|
logger.info("plot roc curve")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), args.model_path)
|
||||||
|
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/window_client_roc.png".format(args.model_path))
|
||||||
|
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_precision_recall(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
|
visualize.plot_precision_recall(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
|
||||||
|
visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
visualize.plot_save("{}/user_client_prc.png".format(args.model_path))
|
||||||
|
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
visualize.plot_roc_curve(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
|
visualize.plot_roc_curve(df_user.client_val.as_matrix(), df_user.client_pred.as_matrix(), args.model_path)
|
||||||
|
visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
visualize.plot_save("{}/user_client_roc.png".format(args.model_path))
|
||||||
|
|
||||||
@ -352,11 +364,20 @@ def main_visualize_all():
|
|||||||
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
paul = dataset.load_predictions("results/paul/")
|
||||||
|
df_paul = pd.DataFrame(data={
|
||||||
|
"names": paul["testNames"].flatten(), "client_pred": paul["testScores"].flatten(),
|
||||||
|
"hits_vt": paul["testLabel"].flatten(), "hits_trusted": paul["testHits"].flatten()
|
||||||
|
})
|
||||||
|
df_paul["client_val"] = np.logical_or(df_paul.hits_vt == 1.0, df_paul.hits_trusted >= 3)
|
||||||
|
df_paul_user = df_paul.groupby(df_paul.names).max()
|
||||||
|
|
||||||
logger.info("plot pr curves")
|
logger.info("plot pr curves")
|
||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
df = load_df(model_args["model_path"])
|
df = load_df(model_args["model_path"])
|
||||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||||
|
visualize.plot_precision_recall(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_prc.png")
|
||||||
|
|
||||||
@ -364,7 +385,8 @@ def main_visualize_all():
|
|||||||
visualize.plot_clf()
|
visualize.plot_clf()
|
||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
df = load_df(model_args["model_path"])
|
df = load_df(model_args["model_path"])
|
||||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||||
|
visualize.plot_roc_curve(df_paul.client_val.as_matrix(), df_paul.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_window_client_roc.png")
|
||||||
|
|
||||||
@ -373,7 +395,8 @@ def main_visualize_all():
|
|||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
df = load_df(model_args["model_path"])
|
df = load_df(model_args["model_path"])
|
||||||
df = df.groupby(df.names).max()
|
df = df.groupby(df.names).max()
|
||||||
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
visualize.plot_precision_recall(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||||
|
visualize.plot_precision_recall(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_prc.png")
|
||||||
|
|
||||||
@ -382,7 +405,8 @@ def main_visualize_all():
|
|||||||
for model_args in get_model_args(args):
|
for model_args in get_model_args(args):
|
||||||
df = load_df(model_args["model_path"])
|
df = load_df(model_args["model_path"])
|
||||||
df = df.groupby(df.names).max()
|
df = df.groupby(df.names).max()
|
||||||
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_path"])
|
visualize.plot_roc_curve(df.client_val.as_matrix(), df.client_pred.as_matrix(), model_args["model_name"])
|
||||||
|
visualize.plot_roc_curve(df_paul_user.client_val.as_matrix(), df_paul_user.client_pred.as_matrix(), "paul")
|
||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
visualize.plot_save(f"{args.output_prefix}_user_client_roc.png")
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import keras
|
import keras
|
||||||
from keras.engine import Input, Model as KerasModel
|
from keras.engine import Input, Model as KerasModel
|
||||||
from keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout, Activation, TimeDistributed
|
from keras.layers import Activation, Conv1D, Dense, Dropout, Embedding, GlobalMaxPooling1D, TimeDistributed
|
||||||
|
from keras.regularizers import l2
|
||||||
|
|
||||||
import dataset
|
import dataset
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
Model = namedtuple("Model", ["in_domains", "in_flows", "out_client", "out_server"])
|
Model = namedtuple("Model", ["in_domains", "in_flows", "out_client", "out_server"])
|
||||||
|
|
||||||
best_config = {
|
best_config = {
|
||||||
@ -33,10 +34,14 @@ best_config = {
|
|||||||
def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden_dims, drop_out=0.5) -> KerasModel:
|
def get_embedding(embedding_size, input_length, filter_size, kernel_size, hidden_dims, drop_out=0.5) -> KerasModel:
|
||||||
x = y = Input(shape=(input_length,))
|
x = y = Input(shape=(input_length,))
|
||||||
y = Embedding(input_dim=dataset.get_vocab_size(), output_dim=embedding_size)(y)
|
y = Embedding(input_dim=dataset.get_vocab_size(), output_dim=embedding_size)(y)
|
||||||
y = Conv1D(filter_size, kernel_size, activation='relu')(y)
|
y = Conv1D(filter_size,
|
||||||
|
kernel_size,
|
||||||
|
kernel_regularizer=l2(0.01),
|
||||||
|
activation='relu')(y)
|
||||||
y = GlobalMaxPooling1D()(y)
|
y = GlobalMaxPooling1D()(y)
|
||||||
y = Dropout(drop_out)(y)
|
y = Dropout(drop_out)(y)
|
||||||
y = Dense(hidden_dims)(y)
|
y = Dense(hidden_dims,
|
||||||
|
kernel_regularizer=l2(0.01))(y)
|
||||||
y = Activation('relu')(y)
|
y = Activation('relu')(y)
|
||||||
return KerasModel(x, y)
|
return KerasModel(x, y)
|
||||||
|
|
||||||
@ -50,12 +55,13 @@ def get_model(cnnDropout, flow_features, domain_features, window_size, domain_le
|
|||||||
# CNN processing a small slides of flow windows
|
# CNN processing a small slides of flow windows
|
||||||
y = Conv1D(cnn_dims,
|
y = Conv1D(cnn_dims,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
|
kernel_regularizer=l2(0.01),
|
||||||
activation='relu',
|
activation='relu',
|
||||||
input_shape=(window_size, domain_features + flow_features))(merged)
|
input_shape=(window_size, domain_features + flow_features))(merged)
|
||||||
# remove temporal dimension by global max pooling
|
# remove temporal dimension by global max pooling
|
||||||
y = GlobalMaxPooling1D()(y)
|
y = GlobalMaxPooling1D()(y)
|
||||||
y = Dropout(cnnDropout)(y)
|
y = Dropout(cnnDropout)(y)
|
||||||
y = Dense(dense_dim, activation='relu')(y)
|
y = Dense(dense_dim, kernel_regularizer=l2(0.01), activation='relu')(y)
|
||||||
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
||||||
out_server = Dense(1, activation='sigmoid', name="server")(y)
|
out_server = Dense(1, activation='sigmoid', name="server")(y)
|
||||||
|
|
||||||
@ -68,18 +74,25 @@ def get_new_model(dropout, flow_features, domain_features, window_size, domain_l
|
|||||||
ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows")
|
ipt_flows = Input(shape=(window_size, flow_features), name="ipt_flows")
|
||||||
encoded = TimeDistributed(cnn)(ipt_domains)
|
encoded = TimeDistributed(cnn)(ipt_domains)
|
||||||
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
|
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
|
||||||
y = Dense(dense_dim, activation="relu", name="dense_server")(merged)
|
y = Dense(dense_dim,
|
||||||
|
kernel_regularizer=l2(0.01),
|
||||||
|
activation="relu",
|
||||||
|
name="dense_server")(merged)
|
||||||
out_server = Dense(1, activation="sigmoid", name="server")(y)
|
out_server = Dense(1, activation="sigmoid", name="server")(y)
|
||||||
merged = keras.layers.concatenate([merged, y], -1)
|
merged = keras.layers.concatenate([merged, y], -1)
|
||||||
# CNN processing a small slides of flow windows
|
# CNN processing a small slides of flow windows
|
||||||
y = Conv1D(cnn_dims,
|
y = Conv1D(cnn_dims,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
|
kernel_regularizer=l2(0.01),
|
||||||
activation='relu',
|
activation='relu',
|
||||||
input_shape=(window_size, domain_features + flow_features))(merged)
|
input_shape=(window_size, domain_features + flow_features))(merged)
|
||||||
# remove temporal dimension by global max pooling
|
# remove temporal dimension by global max pooling
|
||||||
y = GlobalMaxPooling1D()(y)
|
y = GlobalMaxPooling1D()(y)
|
||||||
y = Dropout(dropout)(y)
|
y = Dropout(dropout)(y)
|
||||||
y = Dense(dense_dim, activation='relu', name="dense_client")(y)
|
y = Dense(dense_dim,
|
||||||
|
kernel_regularizer=l2(0.01),
|
||||||
|
activation='relu',
|
||||||
|
name="dense_client")(y)
|
||||||
|
|
||||||
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
out_client = Dense(1, activation='sigmoid', name="client")(y)
|
||||||
|
|
||||||
|
9
run.sh
9
run.sh
@ -5,6 +5,8 @@ RESDIR=$1
|
|||||||
mkdir -p /tmp/rk/${RESDIR}
|
mkdir -p /tmp/rk/${RESDIR}
|
||||||
DATADIR=$2
|
DATADIR=$2
|
||||||
|
|
||||||
|
EPOCHS=100
|
||||||
|
|
||||||
for output in client both
|
for output in client both
|
||||||
do
|
do
|
||||||
for depth in small
|
for depth in small
|
||||||
@ -15,7 +17,7 @@ do
|
|||||||
python main.py --mode train \
|
python main.py --mode train \
|
||||||
--train ${DATADIR}/currentData.csv \
|
--train ${DATADIR}/currentData.csv \
|
||||||
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
--model ${RESDIR}/${output}_${depth}_${mtype} \
|
||||||
--epochs 50 \
|
--epochs $EPOCHS \
|
||||||
--embd 128 \
|
--embd 128 \
|
||||||
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||||
--domain_embd 32 \
|
--domain_embd 32 \
|
||||||
@ -35,7 +37,7 @@ do
|
|||||||
python main.py --mode train \
|
python main.py --mode train \
|
||||||
--train ${DATADIR}/currentData.csv \
|
--train ${DATADIR}/currentData.csv \
|
||||||
--model ${RESDIR}/both_${depth}_staggered \
|
--model ${RESDIR}/both_${depth}_staggered \
|
||||||
--epochs 50 \
|
--epochs $EPOCHS \
|
||||||
--embd 128 \
|
--embd 128 \
|
||||||
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||||
--domain_embd 32 \
|
--domain_embd 32 \
|
||||||
@ -46,3 +48,6 @@ do
|
|||||||
--type staggered \
|
--type staggered \
|
||||||
--depth ${depth}
|
--depth ${depth}
|
||||||
done
|
done
|
||||||
|
|
||||||
|
# python main.py --mode train --epochs 100 --embd 64 --filter_embd 128 --kernel_embd 5 --dense_embd 128 --domain_embd 32 --filter_main 32 --kernel_main 5 --dense_main 512 --batch 256 --balanced_weights --model_output ${output} --type ${mtype} --depth ${depth} --train ${DATADIR}/currentData.csv --model ${RESDIR}/${output}_${depth}_${mtype}
|
||||||
|
# python main.py --mode train --epochs 100 --embd 64 --filter_embd 128 --kernel_embd 5 --dense_embd 128 --domain_embd 32 --filter_main 32 --kernel_main 5 --dense_main 512 --batch 256 --balanced_weights --model_output ${output} --type ${mtype} --depth ${depth} --train ${DATADIR}/currentData.csv --model ${RESDIR}/${output}_${depth}_${mtype}
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from sklearn.decomposition import TruncatedSVD
|
from sklearn.decomposition import TruncatedSVD
|
||||||
from sklearn.metrics import (
|
from sklearn.metrics import (
|
||||||
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
||||||
roc_auc_score, roc_curve, average_precision_score
|
roc_auc_score, roc_curve
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ def plot_precision_recall(y, y_pred, label=""):
|
|||||||
# ax.hold(True)
|
# ax.hold(True)
|
||||||
score = fbeta_score(y, y_pred.round(), 1)
|
score = fbeta_score(y, y_pred.round(), 1)
|
||||||
# prc_ap = average_precision_score(y, y_pred)
|
# prc_ap = average_precision_score(y, y_pred)
|
||||||
plt.plot(recall, precision, '--', label=f"{label} - {score}")
|
plt.plot(recall, precision, '--', label=f"{label} - {score:5.4}")
|
||||||
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
||||||
plt.xlabel('Recall')
|
plt.xlabel('Recall')
|
||||||
plt.ylabel('Precision')
|
plt.ylabel('Precision')
|
||||||
@ -90,7 +90,7 @@ def plot_roc_curve(mask, prediction, label=""):
|
|||||||
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
||||||
roc_auc = auc(fpr, tpr)
|
roc_auc = auc(fpr, tpr)
|
||||||
plt.xscale('log')
|
plt.xscale('log')
|
||||||
plt.plot(fpr, tpr, label=f"{label} - {roc_auc}")
|
plt.plot(fpr, tpr, label=f"{label} - {roc_auc:5.4}")
|
||||||
|
|
||||||
|
|
||||||
def plot_confusion_matrix(y_true, y_pred, path,
|
def plot_confusion_matrix(y_true, y_pred, path,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user