store domain embeddings while test main
This commit is contained in:
parent
452f9e0456
commit
7f1d13658f
10
Makefile
10
Makefile
@ -1,16 +1,16 @@
|
||||
run:
|
||||
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test --epochs 10 \
|
||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights
|
||||
|
||||
run_new:
|
||||
python3 main.py --modes train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
|
||||
python3 main.py --mode train --train data/rk_mini.csv.gz --model results/test2 --epochs 10 \
|
||||
--hidden_char_dims 32 --domain_embd 16 --batch 64 --balanced_weights --new_model
|
||||
|
||||
test:
|
||||
python3 main.py --modes test --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
python3 main.py --mode test --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
|
||||
fancy:
|
||||
python3 main.py --modes fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
python3 main.py --mode fancy --batch 128 --model results/test --test data/rk_mini.csv.gz
|
||||
|
||||
hyper:
|
||||
python3 main.py --modes hyperband --batch 64 --train data/rk_data.csv.gz
|
||||
python3 main.py --mode hyperband --batch 64 --train data/rk_data.csv.gz
|
||||
|
@ -248,13 +248,15 @@ def load_or_generate_domains(train_data, domain_length):
|
||||
return domain_encs, user_flow_df[["serverLabel", "clientLabel"]].as_matrix()
|
||||
|
||||
|
||||
def save_predictions(path, c_pred, s_pred):
|
||||
def save_predictions(path, c_pred, s_pred, embd, labels):
|
||||
f = h5py.File(path, "w")
|
||||
f.create_dataset("client", data=c_pred)
|
||||
f.create_dataset("server", data=s_pred)
|
||||
f.create_dataset("embedding", data=embd)
|
||||
f.create_dataset("labels", data=labels)
|
||||
f.close()
|
||||
|
||||
|
||||
def load_predictions(path):
|
||||
f = h5py.File(path, "r")
|
||||
return f["client"], f["server"]
|
||||
return f["client"], f["server"], f["embedding"], f["labels"]
|
||||
|
14
main.py
14
main.py
@ -194,7 +194,13 @@ def main_test():
|
||||
else:
|
||||
c_pred = np.zeros(0)
|
||||
s_pred = pred
|
||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred)
|
||||
|
||||
model = load_model(args.embedding_model)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
|
||||
dataset.save_predictions(args.future_prediction, c_pred, s_pred, domain_embedding, labels)
|
||||
|
||||
|
||||
|
||||
def main_visualization():
|
||||
@ -212,7 +218,7 @@ def main_visualization():
|
||||
except Exception as e:
|
||||
logger.warning(f"could not generate training curves: {e}")
|
||||
|
||||
client_pred, server_pred = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred, domain_embedding, labels = dataset.load_predictions(args.future_prediction)
|
||||
client_pred, server_pred = client_pred.value, server_pred.value
|
||||
logger.info("plot pr curve")
|
||||
visualize.plot_precision_recall(client_val, client_pred.flatten(), "{}/client_prc.png".format(args.model_path))
|
||||
@ -229,9 +235,7 @@ def main_visualization():
|
||||
# "{}/server_cov.png".format(args.model_path),
|
||||
# normalize=False, title="Server Confusion Matrix")
|
||||
logger.info("visualize embedding")
|
||||
model = load_model(args.embedding_model)
|
||||
domain_encs, labels = dataset.load_or_generate_domains(args.test_data, args.domain_length)
|
||||
domain_embedding = model.predict(domain_encs, batch_size=args.batch_size, verbose=1)
|
||||
|
||||
visualize.plot_embedding(domain_embedding, labels, path="{}/embd.png".format(args.model_path))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user