ma_cisco_malware/arguments.py

137 lines
5.2 KiB
Python
Raw Normal View History

import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument("--mode", action="store", dest="mode",
default="")
parser.add_argument("--train", action="store", dest="train_data",
default="data/full_dataset.csv.tar.gz")
parser.add_argument("--data", action="store", dest="train_data",
default="data/full_dataset.csv.tar.gz")
parser.add_argument("--test", action="store", dest="test_data",
default="data/full_future_dataset.csv.tar.gz")
parser.add_argument("--hyper_result", action="store", dest="hyperband_results",
default="")
parser.add_argument("--model", action="store", dest="model_path",
default="results/model_x")
2017-09-28 12:23:22 +02:00
parser.add_argument("--model_src", action="store", dest="model_source",
default="results/model_x")
parser.add_argument("--model_dest", action="store", dest="model_destination",
default="results/model_x")
parser.add_argument("--models", action="store", dest="model_paths", nargs="+",
default=[])
parser.add_argument("--type", action="store", dest="model_type",
default="final") # inter, final, staggered
parser.add_argument("--depth", action="store", dest="model_depth",
default="small") # small, medium
parser.add_argument("--model_output", action="store", dest="model_output",
default="both")
parser.add_argument("--batch", action="store", dest="batch_size",
default=64, type=int)
parser.add_argument("--epochs", action="store", dest="epochs",
default=10, type=int)
2017-09-28 12:23:22 +02:00
parser.add_argument("--init_epoch", action="store", dest="initial_epoch",
default=0, type=int)
# parser.add_argument("--samples", action="store", dest="samples",
# default=100000, type=int)
#
# parser.add_argument("--samples_val", action="store", dest="samples_val",
# default=10000, type=int)
#
parser.add_argument("--embd", action="store", dest="embedding",
default=128, type=int)
parser.add_argument("--filter_embd", action="store", dest="filter_embedding",
default=128, type=int)
parser.add_argument("--dense_embd", action="store", dest="dense_embedding",
default=128, type=int)
parser.add_argument("--kernel_embd", action="store", dest="kernel_embedding",
default=3, type=int)
parser.add_argument("--filter_main", action="store", dest="filter_main",
default=128, type=int)
parser.add_argument("--dense_main", action="store", dest="dense_main",
default=128, type=int)
parser.add_argument("--kernel_main", action="store", dest="kernel_main",
default=3, type=int)
parser.add_argument("--window", action="store", dest="window",
default=10, type=int)
parser.add_argument("--domain_length", action="store", dest="domain_length",
default=40, type=int)
parser.add_argument("--domain_embd", action="store", dest="domain_embedding",
default=512, type=int)
parser.add_argument("--out-prefix", action="store", dest="output_prefix",
default="", type=str)
# parser.add_argument("--queue", action="store", dest="queue_size",
# default=50, type=int)
#
# parser.add_argument("--p", action="store", dest="p_train",
# default=0.5, type=float)
#
# parser.add_argument("--p_val", action="store", dest="p_val",
# default=0.01, type=float)
#
# parser.add_argument("--gpu", action="store", dest="gpu",
# default=0, type=int)
#
# parser.add_argument("--tmp", action="store_true", dest="tmp")
#
parser.add_argument("--stop_early", action="store_true", dest="stop_early")
parser.add_argument("--balanced_weights", action="store_true", dest="class_weights")
parser.add_argument("--gpu", action="store_true", dest="gpu")
parser.add_argument("--new_model", action="store_true", dest="new_model")
def get_model_args(args):
return [{
"model_path": model_path,
"model_name": os.path.split(os.path.normpath(model_path))[1],
"embedding_model": os.path.join(model_path, "embd.h5"),
"clf_model": os.path.join(model_path, "clf.h5"),
"train_log": os.path.join(model_path, "train.log.csv"),
"train_h5data": args.train_data,
"test_h5data": args.test_data,
"future_prediction": os.path.join(model_path, f"{os.path.basename(args.test_data)}_pred")
} for model_path in args.model_paths]
2017-09-28 12:23:22 +02:00
def parse():
args = parser.parse_args()
args.result_path = os.path.split(os.path.normpath(args.output_prefix))[1]
2017-09-10 23:40:14 +02:00
args.model_name = os.path.split(os.path.normpath(args.model_path))[1]
args.embedding_model = os.path.join(args.model_path, "embd.h5")
args.clf_model = os.path.join(args.model_path, "clf.h5")
args.train_log = os.path.join(args.model_path, "train.log.csv")
args.train_h5data = args.train_data
args.test_h5data = args.test_data
args.future_prediction = os.path.join(args.model_path, f"{os.path.basename(args.test_data)}_pred")
return args