add error bar vis, comment unused parameters from parser
This commit is contained in:
parent
b24fa770f9
commit
345afbaef5
24
arguments.py
24
arguments.py
@ -6,14 +6,14 @@ parser = argparse.ArgumentParser()
|
|||||||
parser.add_argument("--mode", action="store", dest="mode",
|
parser.add_argument("--mode", action="store", dest="mode",
|
||||||
default="")
|
default="")
|
||||||
|
|
||||||
parser.add_argument("--train", action="store", dest="train_data",
|
# parser.add_argument("--train", action="store", dest="train_data",
|
||||||
|
# default="data/full_dataset.csv.tar.gz")
|
||||||
|
|
||||||
|
parser.add_argument("--data", action="store", dest="data",
|
||||||
default="data/full_dataset.csv.tar.gz")
|
default="data/full_dataset.csv.tar.gz")
|
||||||
|
|
||||||
parser.add_argument("--data", action="store", dest="train_data",
|
# parser.add_argument("--test", action="store", dest="test_data",
|
||||||
default="data/full_dataset.csv.tar.gz")
|
# default="data/full_future_dataset.csv.tar.gz")
|
||||||
|
|
||||||
parser.add_argument("--test", action="store", dest="test_data",
|
|
||||||
default="data/full_future_dataset.csv.tar.gz")
|
|
||||||
|
|
||||||
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
|
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
|
||||||
default="")
|
default="")
|
||||||
@ -117,9 +117,9 @@ def get_model_args(args):
|
|||||||
"embedding_model": os.path.join(model_path, "embd.h5"),
|
"embedding_model": os.path.join(model_path, "embd.h5"),
|
||||||
"clf_model": os.path.join(model_path, "clf.h5"),
|
"clf_model": os.path.join(model_path, "clf.h5"),
|
||||||
"train_log": os.path.join(model_path, "train.log.csv"),
|
"train_log": os.path.join(model_path, "train.log.csv"),
|
||||||
"train_h5data": args.train_data,
|
# "train_h5data": args.train_data,
|
||||||
"test_h5data": args.test_data,
|
# "test_h5data": args.test_data,
|
||||||
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
|
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.data)}_pred")
|
||||||
} for model_path in args.model_paths]
|
} for model_path in args.model_paths]
|
||||||
|
|
||||||
|
|
||||||
@ -130,7 +130,7 @@ def parse():
|
|||||||
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
args.embedding_model = os.path.join(args.model_path, "embd.h5")
|
||||||
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
args.clf_model = os.path.join(args.model_path, "clf.h5")
|
||||||
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
args.train_log = os.path.join(args.model_path, "train.log.csv")
|
||||||
args.train_h5data = args.train_data
|
# args.train_h5data = args.train_data
|
||||||
args.test_h5data = args.test_data
|
# args.test_h5data = args.test_data
|
||||||
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
|
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.data)}_pred")
|
||||||
return args
|
return args
|
||||||
|
24
fancy.sh
24
fancy.sh
@ -3,22 +3,22 @@
|
|||||||
RESDIR=$1
|
RESDIR=$1
|
||||||
DATADIR=$2
|
DATADIR=$2
|
||||||
|
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final --test ${DATADIR}/futureData.csv --model_output both
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_final --data ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter --test ${DATADIR}/futureData.csv --model_output both
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_inter --data ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered --test ${DATADIR}/futureData.csv --model_output both
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_staggered --data ${DATADIR}/futureData.csv --model_output both
|
||||||
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final --test ${DATADIR}/futureData.csv --model_output client
|
python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_final --data ${DATADIR}/futureData.csv --model_output client
|
||||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_inter --test ${DATADIR}/futureData.csv --model_output client
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_inter --data ${DATADIR}/futureData.csv --model_output client
|
||||||
|
|
||||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --test ${DATADIR}/futureData.csv --model_output both
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_final --data ${DATADIR}/futureData.csv --model_output both
|
||||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --test ${DATADIR}/futureData.csv --model_output both
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/both_medium_inter --data ${DATADIR}/futureData.csv --model_output both
|
||||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --test ${DATADIR}/futureData.csv --model_output client
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_final --data ${DATADIR}/futureData.csv --model_output client
|
||||||
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --test ${DATADIR}/futureData.csv --model_output client
|
#python3 main.py --mode fancy --batch 1024 --model ${RESDIR}/client_medium_inter --data ${DATADIR}/futureData.csv --model_output client
|
||||||
|
|
||||||
#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
#python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \
|
||||||
# --models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small
|
# --models ${RESDIR}/*_small_*/ --out-prefix ${RESDIR}/small
|
||||||
|
|
||||||
#python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
#python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \
|
||||||
# --models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium
|
# --models ${RESDIR}/*_medium_*/ --out-prefix ${RESDIR}/medium
|
||||||
|
|
||||||
python3 main.py --mode all_fancy --batch 256 --test ${DATADIR}/futureData.csv \
|
python3 main.py --mode all_fancy --batch 256 --data ${DATADIR}/futureData.csv \
|
||||||
--models ${RESDIR}/*/ --out-prefix ${RESDIR}/all
|
--models ${RESDIR}/*/ --out-prefix ${RESDIR}/all
|
||||||
|
8
main.py
8
main.py
@ -589,6 +589,12 @@ def plot_overall_result():
|
|||||||
visualize.plot_legend()
|
visualize.plot_legend()
|
||||||
visualize.plot_save(f"{path}/{vis}_all.png")
|
visualize.plot_save(f"{path}/{vis}_all.png")
|
||||||
|
|
||||||
|
for cat, models in results.items():
|
||||||
|
visualize.plot_clf()
|
||||||
|
visualize.plot_error_bars(models)
|
||||||
|
visualize.plot_legend()
|
||||||
|
visualize.plot_save(f"{path}/error_bars_{cat}.png")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if "train" == args.mode:
|
if "train" == args.mode:
|
||||||
@ -605,6 +611,8 @@ def main():
|
|||||||
main_visualize_all()
|
main_visualize_all()
|
||||||
if "beta" == args.mode:
|
if "beta" == args.mode:
|
||||||
main_beta()
|
main_beta()
|
||||||
|
if "all_beta" == args.mode:
|
||||||
|
plot_overall_result()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
13
visualize.py
13
visualize.py
@ -2,6 +2,7 @@ import os
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from scipy import interpolate
|
from scipy import interpolate
|
||||||
from sklearn.decomposition import TruncatedSVD
|
from sklearn.decomposition import TruncatedSVD
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
@ -211,6 +212,18 @@ def plot_training_curve(logs, key, path, dpi=600):
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_error_bars(results):
|
||||||
|
rates = []
|
||||||
|
for m, r in results.items():
|
||||||
|
if m == "all": continue
|
||||||
|
rates.append((r / r.sum(axis=0, keepdims=True)).flatten())
|
||||||
|
rates = pd.DataFrame(np.vstack(rates), columns=("TN", "FP", "FN", "TP"))
|
||||||
|
|
||||||
|
ax = rates.mean().plot.bar(yerr=rates.std())
|
||||||
|
for p in ax.patches:
|
||||||
|
ax.annotate(str(np.round(p.get_height(), 4)), (p.get_x(), 0.5))
|
||||||
|
|
||||||
|
|
||||||
def plot_embedding(domain_embedding, labels, path, dpi=600, method="svd"):
|
def plot_embedding(domain_embedding, labels, path, dpi=600, method="svd"):
|
||||||
if method == "svd":
|
if method == "svd":
|
||||||
red = TruncatedSVD(n_components=2)
|
red = TruncatedSVD(n_components=2)
|
||||||
|
Loading…
Reference in New Issue
Block a user