add parser argument for naming in multi model modes, minor fixes,
This commit is contained in:
parent
1da31cc97c
commit
ed4f478bad
@ -58,6 +58,9 @@ parser.add_argument("--domain_length", action="store", dest="domain_length",
|
||||
parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
|
||||
default=512, type=int)
|
||||
|
||||
parser.add_argument("--out-prefix", action="store", dest="output_prefix",
|
||||
default="", type=str)
|
||||
|
||||
|
||||
# parser.add_argument("--queue", action="store", dest="queue_size",
|
||||
# default=50, type=int)
|
||||
|
@ -75,7 +75,7 @@ def get_new_model(dropout, flow_features, domain_features, window_size, domain_l
|
||||
y = Conv1D(cnn_dims,
|
||||
kernel_size,
|
||||
activation='relu',
|
||||
input_shape=(window_size, domain_features + flow_features))(merged)
|
||||
input_shape=(window_size, domain_features + flow_features))(y)
|
||||
# remove temporal dimension by global max pooling
|
||||
y = GlobalMaxPooling1D()(y)
|
||||
y = Dropout(dropout)(y)
|
||||
|
20
run.sh
20
run.sh
@ -5,7 +5,7 @@ for output in client both
|
||||
do
|
||||
for depth in small medium
|
||||
do
|
||||
for mtype in inter final staggered
|
||||
for mtype in inter final
|
||||
do
|
||||
|
||||
python main.py --mode train \
|
||||
@ -13,7 +13,7 @@ do
|
||||
--model /tmp/rk/results/${output}_${depth}_${mtype} \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_chaar_dims 128 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
@ -24,3 +24,19 @@ do
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
for depth in small medium
|
||||
do
|
||||
python main.py --mode train \
|
||||
--train /tmp/rk/currentData.csv \
|
||||
--model /tmp/rk/results/both_${depth}_inter \
|
||||
--epochs 50 \
|
||||
--embd 64 \
|
||||
--hidden_char_dims 128 \
|
||||
--domain_embd 32 \
|
||||
--batch 256 \
|
||||
--balanced_weights \
|
||||
--model_output both \
|
||||
--type inter \
|
||||
--depth ${depth}
|
||||
done
|
17
test.sh
17
test.sh
@ -3,17 +3,8 @@
|
||||
|
||||
for output in client both
|
||||
do
|
||||
for depth in small medium
|
||||
do
|
||||
for mtype in inter final staggered
|
||||
do
|
||||
|
||||
python main.py --mode test --batch 1024 \
|
||||
--model /tmp/rk/results/${output}_${depth}_${mtype} \
|
||||
--test /tmp/rk/futureData.csv \
|
||||
--model_output ${output} \
|
||||
--type ${mtype} \
|
||||
--depth ${depth}
|
||||
done
|
||||
done
|
||||
python3 main.py --mode test --batch 1024 \
|
||||
--models tm/rk/${output}_* \
|
||||
--test data/futureData.csv \
|
||||
--model_output ${output}
|
||||
done
|
||||
|
10
visualize.py
10
visualize.py
@ -5,7 +5,7 @@ import numpy as np
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.metrics import (
|
||||
auc, classification_report, confusion_matrix, fbeta_score, precision_recall_curve,
|
||||
roc_auc_score, roc_curve
|
||||
roc_auc_score, roc_curve, average_precision_score
|
||||
)
|
||||
|
||||
|
||||
@ -52,7 +52,9 @@ def plot_precision_recall(y, y_pred, label=""):
|
||||
|
||||
# fig, ax = plt.subplots(1, 1)
|
||||
# ax.hold(True)
|
||||
plt.plot(recall, precision, '--', label=label)
|
||||
score = fbeta_score(y, y_pred.round(), 1)
|
||||
# prc_ap = average_precision_score(y, y_pred)
|
||||
plt.plot(recall, precision, '--', label=f"{label} - {score}")
|
||||
# ax.step(recall[::-1], decreasing_max_precision, '-r')
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
@ -87,6 +89,7 @@ def plot_roc_curve(mask, prediction, label=""):
|
||||
y_pred = prediction.flatten()
|
||||
fpr, tpr, thresholds = roc_curve(y, y_pred)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
plt.xscale('log')
|
||||
plt.plot(fpr, tpr, label=f"{label} - {roc_auc}")
|
||||
|
||||
|
||||
@ -154,7 +157,8 @@ def plot_embedding(domain_embedding, labels, path, dpi=600):
|
||||
domain_reduced[:, 1],
|
||||
c=(labels * (1, 2)).sum(1).astype(int),
|
||||
cmap=plt.cm.plasma,
|
||||
s=3)
|
||||
s=3,
|
||||
alpha=0.2)
|
||||
plt.colorbar()
|
||||
plt.savefig(path, dpi=dpi)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user