refactor network to use new input format

This commit is contained in:
René Knaebel 2017-07-04 20:42:48 +02:00
parent 5743127b7f
commit 59c1176e85
2 changed files with 21 additions and 30 deletions

3
Makefile Normal file

@ -0,0 +1,3 @@
test:
python3 main.py --epochs 1 --batch 64

@ -1,12 +1,12 @@
import keras
from keras.engine import Input, Model
from keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout, Activation, Reshape
from keras.layers import Embedding, Conv1D, GlobalMaxPooling1D, Dense, Dropout, Activation, TimeDistributed
def get_shared_cnn(vocabSize, embeddingSize, input_length, filters, kernel_size,
def get_shared_cnn(vocab_size, embedding_size, input_length, filters, kernel_size,
hidden_dims, drop_out):
x = y = Input(shape=(input_length,))
y = Embedding(input_dim=vocabSize, output_dim=embeddingSize)(y)
y = Embedding(input_dim=vocab_size, output_dim=embedding_size)(y)
y = Conv1D(filters, kernel_size, activation='relu')(y)
y = GlobalMaxPooling1D()(y)
y = Dense(hidden_dims)(y)
@ -21,33 +21,21 @@ def get_full_model(vocabSize, embeddingSize, maxLen, domainFeatures, flowFeature
def get_top_cnn(cnn, numFeatures, maxLen, windowSize, domainFeatures, filters, kernel_size, cnnHiddenDims, cnnDropout):
inputList = []
encodedList = []
# TODO: ???
for i in range(windowSize):
inputList.append(Input(shape=(maxLen,)))
encodedList.append(cnn(inputList[-1])) # add shared domain model
inputList.append(Input(shape=(numFeatures,)))
# TODO: ???
merge_layer_input = []
for i in range(windowSize):
merge_layer_input.append(encodedList[i])
merge_layer_input.append(inputList[(2 * i) + 1])
# We can then concatenate the two vectors:
merged_vector = keras.layers.concatenate(merge_layer_input, axis=-1)
reshape = Reshape((windowSize, domainFeatures + numFeatures))(merged_vector)
ipt_domains = Input(shape=(windowSize, maxLen), name="ipt_domains")
encoded = TimeDistributed(cnn)(ipt_domains)
ipt_flows = Input(shape=(windowSize, numFeatures), name="ipt_flows")
merged = keras.layers.concatenate([encoded, ipt_flows], -1)
# add second cnn
cnn = Conv1D(filters,
y = Conv1D(filters,
kernel_size,
activation='relu',
input_shape=(windowSize, domainFeatures + numFeatures))(reshape)
input_shape=(windowSize, domainFeatures + numFeatures))(merged)
# TODO: why global pooling? -> 3D to 2D
# we use max pooling:
maxPool = GlobalMaxPooling1D()(cnn)
cnnDropout = Dropout(cnnDropout)(maxPool)
cnnDense = Dense(cnnHiddenDims, activation='relu')(cnnDropout)
cnnOutput1 = Dense(2, activation='softmax', name="client")(cnnDense)
cnnOutput2 = Dense(2, activation='softmax', name="server")(cnnDense)
y = GlobalMaxPooling1D()(y)
y = Dropout(cnnDropout)(y)
y = Dense(cnnHiddenDims, activation='relu')(y)
y1 = Dense(2, activation='softmax', name="client")(y)
y2 = Dense(2, activation='softmax', name="server")(y)
# We define a trainable model linking the
# tweet inputs to the predictions
return Model(inputs=inputList, outputs=(cnnOutput1, cnnOutput2))
return Model(inputs=[ipt_domains, ipt_flows], outputs=(y1, y2))