fix covariance normalization; add run_model script for multi times training
This commit is contained in:
parent
6121eac448
commit
6d8d7b19f3
29
run_model.sh
Normal file
29
run_model.sh
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
|
||||||
|
N=$1
|
||||||
|
OUTPUT=$2
|
||||||
|
DEPTH=$3
|
||||||
|
TYPE=$4
|
||||||
|
RESDIR=$5
|
||||||
|
mkdir -p /tmp/rk/${RESDIR}
|
||||||
|
DATADIR=$6
|
||||||
|
|
||||||
|
EPOCHS=100
|
||||||
|
|
||||||
|
for i in {1..$N}
|
||||||
|
do
|
||||||
|
python main.py --mode train \
|
||||||
|
--train ${DATADIR}/currentData.csv \
|
||||||
|
--model ${RESDIR}/${OUTPUT}_${TYPE}_$i \
|
||||||
|
--epochs $EPOCHS \
|
||||||
|
--embd 128 \
|
||||||
|
--filter_embd 256 --kernel_embd 8 --dense_embd 128 \
|
||||||
|
--domain_embd 32 \
|
||||||
|
--filter_main 32 --kernel_main 8 --dense_main 1024 \
|
||||||
|
--batch 256 \
|
||||||
|
--balanced_weights \
|
||||||
|
--model_output ${OUTPUT} \
|
||||||
|
--type ${TYPE} \
|
||||||
|
--depth ${DEPTH}
|
||||||
|
done
|
14
visualize.py
14
visualize.py
@ -104,21 +104,21 @@ def plot_confusion_matrix(y_true, y_pred, path,
|
|||||||
"""
|
"""
|
||||||
plt.clf()
|
plt.clf()
|
||||||
cm = confusion_matrix(y_true, y_pred)
|
cm = confusion_matrix(y_true, y_pred)
|
||||||
plt.imshow(cm, interpolation='nearest', cmap=cmap)
|
|
||||||
plt.title(title)
|
|
||||||
plt.colorbar()
|
|
||||||
tick_marks = np.arange(len(classes))
|
|
||||||
plt.xticks(tick_marks, classes, rotation=45)
|
|
||||||
plt.yticks(tick_marks, classes)
|
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||||
print("Normalized confusion matrix")
|
print("Normalized confusion matrix")
|
||||||
else:
|
else:
|
||||||
print('Confusion matrix, without normalization')
|
print('Confusion matrix, without normalization')
|
||||||
|
|
||||||
print(cm)
|
print(cm)
|
||||||
|
|
||||||
|
plt.imshow(cm, interpolation='nearest', cmap=cmap)
|
||||||
|
plt.title(title)
|
||||||
|
plt.colorbar()
|
||||||
|
tick_marks = np.arange(len(classes))
|
||||||
|
plt.xticks(tick_marks, classes, rotation=45)
|
||||||
|
plt.yticks(tick_marks, classes)
|
||||||
|
|
||||||
thresh = cm.max() / 2.
|
thresh = cm.max() / 2.
|
||||||
for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
|
for i, j in ((i, j) for i in range(cm.shape[0]) for j in range(cm.shape[1])):
|
||||||
plt.text(j, i, cm[i, j],
|
plt.text(j, i, cm[i, j],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user