From 787f43b328f69ae3b0436b79e974fdef4f16c203 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Knaebel?= Date: Thu, 3 Aug 2017 07:51:58 +0200 Subject: [PATCH] fix test predictions depending on model output specification --- main.py | 14 +++++++++++--- run.sh | 26 ++++++++++++++------------ scripts/make_csv_dataset.py | 17 +++++++++++++---- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index f723346..9581f27 100644 --- a/main.py +++ b/main.py @@ -183,9 +183,17 @@ def main_test(): domain_val, flow_val, client_val, server_val = load_or_generate_h5data(args.test_h5data, args.test_data, args.domain_length, args.window) clf = load_model(args.clf_model, custom_objects=models.get_metrics()) - c_pred, s_pred = clf.predict([domain_val, flow_val], - batch_size=args.batch_size, - verbose=1) + pred = clf.predict([domain_val, flow_val], + batch_size=args.batch_size, + verbose=1) + if args.model_output == "both": + c_pred, s_pred = pred + elif args.model_output == "client": + c_pred = pred + s_pred = np.array() + else: + c_pred = np.array() + s_pred = pred dataset.save_predictions(args.future_prediction, c_pred, s_pred) diff --git a/run.sh b/run.sh index 4434280..f6a14da 100644 --- a/run.sh +++ b/run.sh @@ -1,49 +1,51 @@ -python3 main.py --mode train \ +#!/usr/bin/env bash + +python main.py --mode train \ --train /tmp/rk/currentData.csv \ --model /tmp/rk/results/simple_both \ --epochs 25 \ - --hidden_char_dims 64 \ + --hidden_char_dims 128 \ --domain_embd 32 \ --batch 256 \ --balanced_weights \ --model_output both -python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_both --test /tmp/rk/futureData.csv +python main.py --mode test --batch 512 --model /tmp/rk/results/simple_both --test /tmp/rk/futureData.csv --model_output both -python3 main.py --mode train \ +python main.py --mode train \ --train /tmp/rk/currentData.csv \ --model /tmp/rk/results/simple_client \ --epochs 25 \ - --hidden_char_dims 64 \ + --hidden_char_dims 128 \ --domain_embd 32 \ --batch 256 \ --balanced_weights \ --model_output client -python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_client --test /tmp/rk/futureData.csv +python main.py --mode test --batch 512 --model /tmp/rk/results/simple_client --test /tmp/rk/futureData.csv --model_output client -python3 main.py --mode train \ +python main.py --mode train \ --train /tmp/rk/currentData.csv \ --model /tmp/rk/results/simple_new_both \ --epochs 25 \ - --hidden_char_dims 64 \ + --hidden_char_dims 128 \ --domain_embd 32 \ --batch 256 \ --balanced_weights \ --model_output both \ --new_model -python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_both --test /tmp/rk/futureData.csv +python main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_both --test /tmp/rk/futureData.csv --model_output both -python3 main.py --mode train \ +python main.py --mode train \ --train /tmp/rk/currentData.csv \ --model /tmp/rk/results/simple_new_client \ --epochs 25 \ - --hidden_char_dims 64 \ + --hidden_char_dims 128 \ --domain_embd 32 \ --batch 256 \ --balanced_weights \ --model_output client \ --new_model -python3 main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_client --test /tmp/rk/futureData.csv \ No newline at end of file +python main.py --mode test --batch 512 --model /tmp/rk/results/simple_new_client --test /tmp/rk/futureData.csv --model_output client \ No newline at end of file diff --git a/scripts/make_csv_dataset.py b/scripts/make_csv_dataset.py index 5473880..2af0279 100644 --- a/scripts/make_csv_dataset.py +++ b/scripts/make_csv_dataset.py @@ -1,17 +1,26 @@ #!/usr/bin/python2 +import sys + import joblib import numpy as np import pandas as pd -df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/currentData.joblib") +fn = sys.argv[1] + +df = joblib.load("/mnt/projekte/pmlcluster/cisco/trainData/multipleTaskLearning/{}.joblib".format(fn)) df = pd.concat(df["data"]) df.reset_index(inplace=True) df.dropna(axis=0, how="any", inplace=True) -df[["duration", "bytes_down", "bytes_up"]] = df[["duration", "bytes_down", "bytes_up"]].astype(np.int) -df[["domain", "server_ip"]] = df[["domain", "server_ip"]].astype(str) + +df.serverLabel = pd.to_numeric(df.serverLabel, errors='coerce') +df.duration = pd.to_numeric(df.duration, errors='coerce') +df.bytes_down = pd.to_numeric(df.bytes_down, errors='coerce') +df.bytes_up = pd.to_numeric(df.bytes_up, errors='coerce') + +df.http_method = df.http_method.astype("category") df.serverLabel = df.serverLabel.astype(np.bool) df.virusTotalHits = df.virusTotalHits.astype(np.int8) df.trustedHits = df.trustedHits.astype(np.int8) -df.to_csv("/tmp/rk/full_future_dataset.csv.gz", compression="gzip") +df.to_csv("/tmp/rk/{}.csv".format(fn), encoding="utf-8")