refactor visualization, change arguments for model type and its depth

This commit is contained in:
2017-09-01 10:42:26 +02:00
parent 933eaae04a
commit dc9180da10
7 changed files with 156 additions and 150 deletions

View File

@@ -8,6 +8,7 @@ def get_models_by_params(params: dict):
# decomposing param section
# mainly embedding model
network_type = params.get("type")
network_depth = params.get("depth")
embedding_size = params.get("embedding_size")
input_length = params.get("input_length")
filter_embedding = params.get("filter_embedding")
@@ -24,7 +25,12 @@ def get_models_by_params(params: dict):
dense_dim = params.get("dense_main")
model_output = params.get("model_output", "both")
# create models
networks = renes_networks if network_type == "rene" else pauls_networks
if network_depth == "small":
networks = pauls_networks
elif network_depth == "medium":
networks = renes_networks
else:
raise Exception("network not found")
embedding_model = networks.get_embedding(embedding_size, input_length, filter_embedding, kernel_embedding,
hidden_embedding, dropout)