network predicts 2 by 2 classes, refactored threshold to main
This commit is contained in:
parent
8334e9a84f
commit
c972963a19
@ -86,7 +86,7 @@ def get_cisco_features(curDataLine, urlSIPDict):
|
|||||||
return np.zeros([numCiscoFeatures, ]).ravel()
|
return np.zeros([numCiscoFeatures, ]).ravel()
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_flows(user_flow_df, char_dict, maxLen, threshold=3, windowSize=10, use_cisco_features=False):
|
def create_dataset_from_flows(user_flow_df, char_dict, maxLen, windowSize=10, use_cisco_features=False):
|
||||||
domainLists = []
|
domainLists = []
|
||||||
dfLists = []
|
dfLists = []
|
||||||
print("get chunks from user data frames")
|
print("get chunks from user data frames")
|
||||||
@ -102,12 +102,12 @@ def create_dataset_from_flows(user_flow_df, char_dict, maxLen, threshold=3, wind
|
|||||||
print("create training dataset")
|
print("create training dataset")
|
||||||
return create_dataset_from_lists(
|
return create_dataset_from_lists(
|
||||||
domains=domainLists, dfs=dfLists, vocab=char_dict,
|
domains=domainLists, dfs=dfLists, vocab=char_dict,
|
||||||
maxLen=maxLen, threshold=threshold,
|
maxLen=maxLen,
|
||||||
use_cisco_features=use_cisco_features, urlSIPDIct=dict(),
|
use_cisco_features=use_cisco_features, urlSIPDIct=dict(),
|
||||||
window_size=windowSize)
|
window_size=windowSize)
|
||||||
|
|
||||||
|
|
||||||
def create_dataset_from_lists(domains, dfs, vocab, maxLen, threshold=3,
|
def create_dataset_from_lists(domains, dfs, vocab, maxLen,
|
||||||
use_cisco_features=False, urlSIPDIct=dict(),
|
use_cisco_features=False, urlSIPDIct=dict(),
|
||||||
window_size=10):
|
window_size=10):
|
||||||
# TODO: check for hits vs vth consistency
|
# TODO: check for hits vs vth consistency
|
||||||
|
28
main.py
28
main.py
@ -37,23 +37,24 @@ def main():
|
|||||||
user_flow_df = dataset.get_user_flow_data()
|
user_flow_df = dataset.get_user_flow_data()
|
||||||
|
|
||||||
print("create training dataset")
|
print("create training dataset")
|
||||||
(X_tr, y_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
|
(X_tr, hits_tr, names_tr, server_tr, trusted_hits_tr) = dataset.create_dataset_from_flows(
|
||||||
user_flow_df, char_dict,
|
user_flow_df, char_dict,
|
||||||
maxLen=maxLen, threshold=threshold, windowSize=windowSize)
|
maxLen=maxLen, windowSize=windowSize)
|
||||||
|
# make client labels discrete with 4 different values
|
||||||
pos_idx = np.where(y_tr == 1.0)[0]
|
# TODO: use trusted_hits_tr for client classification too
|
||||||
neg_idx = np.where(y_tr == 0.0)[0]
|
client_labels = np.apply_along_axis(lambda x: dataset.discretize_label(x, 3), 0, np.atleast_2d(hits_tr))
|
||||||
|
# select only 1.0 and 0.0 from training data
|
||||||
|
pos_idx = np.where(client_labels == 1.0)[0]
|
||||||
|
neg_idx = np.where(client_labels == 0.0)[0]
|
||||||
idx = np.concatenate((pos_idx, neg_idx))
|
idx = np.concatenate((pos_idx, neg_idx))
|
||||||
|
# select labels for prediction
|
||||||
|
client_labels = client_labels[idx]
|
||||||
|
server_labels = server_tr[idx]
|
||||||
|
|
||||||
y_tr = y_tr[idx]
|
# TODO: remove when features are flattened
|
||||||
hits_tr = hits_tr[idx]
|
|
||||||
names_tr = names_tr[idx]
|
|
||||||
server_tr = server_tr[idx]
|
|
||||||
trusted_hits_tr = trusted_hits_tr[idx]
|
|
||||||
for i in range(len(X_tr)):
|
for i in range(len(X_tr)):
|
||||||
X_tr[i] = X_tr[i][idx]
|
X_tr[i] = X_tr[i][idx]
|
||||||
|
|
||||||
# TODO: WTF? I don't get it...
|
|
||||||
shared_cnn = models.get_shared_cnn(len(char_dict) + 1, embeddingSize, maxLen,
|
shared_cnn = models.get_shared_cnn(len(char_dict) + 1, embeddingSize, maxLen,
|
||||||
domainFeatures, kernel_size, domainFeatures, 0.5)
|
domainFeatures, kernel_size, domainFeatures, 0.5)
|
||||||
|
|
||||||
@ -65,8 +66,9 @@ def main():
|
|||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
epochNumber = 0
|
epochNumber = 0
|
||||||
y_tr = np_utils.to_categorical(y_tr, 2)
|
client_labels = np_utils.to_categorical(client_labels, 2)
|
||||||
model.fit(x=X_tr, y=y_tr, batch_size=128,
|
server_labels = np_utils.to_categorical(server_labels, 2)
|
||||||
|
model.fit(X_tr, [client_labels, server_labels], batch_size=128,
|
||||||
epochs=epochNumber + 1, shuffle=True, initial_epoch=epochNumber) # ,
|
epochs=epochNumber + 1, shuffle=True, initial_epoch=epochNumber) # ,
|
||||||
# validation_data=(testData,testLabel))
|
# validation_data=(testData,testLabel))
|
||||||
|
|
||||||
|
@ -45,9 +45,9 @@ def get_top_cnn(cnn, numFeatures, maxLen, windowSize, domainFeatures, filters, k
|
|||||||
maxPool = GlobalMaxPooling1D()(cnn)
|
maxPool = GlobalMaxPooling1D()(cnn)
|
||||||
cnnDropout = Dropout(cnnDropout)(maxPool)
|
cnnDropout = Dropout(cnnDropout)(maxPool)
|
||||||
cnnDense = Dense(cnnHiddenDims, activation='relu')(cnnDropout)
|
cnnDense = Dense(cnnHiddenDims, activation='relu')(cnnDropout)
|
||||||
cnnOutput = Dense(2, activation='softmax')(cnnDense)
|
cnnOutput1 = Dense(2, activation='softmax')(cnnDense)
|
||||||
|
cnnOutput2 = Dense(2, activation='softmax')(cnnDense)
|
||||||
|
|
||||||
# We define a trainable model linking the
|
# We define a trainable model linking the
|
||||||
# tweet inputs to the predictions
|
# tweet inputs to the predictions
|
||||||
model = Model(inputs=inputList, outputs=cnnOutput)
|
return Model(inputs=inputList, outputs=(cnnOutput1, cnnOutput2))
|
||||||
return model
|
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
import joblib
|
import joblib
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
datafile = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib")
|
df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib")
|
||||||
user_flows = datafile["data"]
|
df = df["data"]
|
||||||
df = pd.concat(user_flows)
|
df = pd.concat(df)
|
||||||
df.reset_index(inplace=True)
|
df.reset_index(inplace=True)
|
||||||
df.to_csv("/tmp/rk/full_dataset.csv.gz", compression="gzip")
|
df.to_csv("/tmp/rk/full_dataset.csv.gz", compression="gzip")
|
||||||
|
Loading…
Reference in New Issue
Block a user