add some print lines for better following the process structure
This commit is contained in:
parent
27f4d086eb
commit
c19d649bc4
19
main.py
19
main.py
@ -16,7 +16,6 @@ import models
|
||||
# create logger
|
||||
import visualize
|
||||
from arguments import get_model_args
|
||||
from server import test_server_only, train_server_only
|
||||
from utils import exists_or_make_path, get_custom_class_weights, get_custom_sample_weights, load_model
|
||||
|
||||
logger = logging.getLogger('cisco_logger')
|
||||
@ -648,7 +647,8 @@ def main_beta():
|
||||
|
||||
return res, df_server
|
||||
|
||||
res = dataset.load_predictions(path)
|
||||
logger.info(f"load results from {args.model_path}")
|
||||
res = dataset.load_predictions(args.model_path)
|
||||
model_keys = sorted(filter(lambda x: x.startswith("clf"), res.keys()), key=lambda x: int(x[4:-3]))
|
||||
|
||||
client_preds = []
|
||||
@ -659,11 +659,13 @@ def main_beta():
|
||||
server_domain_preds = []
|
||||
server_domain_avg_preds = []
|
||||
for model_name in model_keys:
|
||||
logger.info(f"load model {model_name}")
|
||||
df, df_server = load_df(res[model_name])
|
||||
client_preds.append(df.client_pred.as_matrix())
|
||||
if "server_val" in df.columns:
|
||||
server_preds.append(df.server_pred.as_matrix())
|
||||
if df_server is not None:
|
||||
logger.info(f" group servers")
|
||||
server_flow_preds.append(df_server.server_pred.as_matrix())
|
||||
df_domain = df_server.groupby(df_server.domain).max()
|
||||
server_domain_preds.append(df_domain.server_pred.as_matrix())
|
||||
@ -672,6 +674,7 @@ def main_beta():
|
||||
|
||||
curves[model_prefix][model_name] = confusion_matrix(df.client_val.as_matrix(),
|
||||
df.client_pred.as_matrix().round())
|
||||
logger.info(f" group users")
|
||||
df_user = df.groupby(df.names).max()
|
||||
client_user_preds.append(df_user.client_pred.as_matrix())
|
||||
if "server_val" in df.columns:
|
||||
@ -840,25 +843,15 @@ def main():
|
||||
main_retrain()
|
||||
if "hyperband" == args.mode:
|
||||
main_hyperband(args.data, args.domain_length, args.window, args.model_type, args.hyperband_results,
|
||||
arg.hyper_max_iter)
|
||||
args.hyper_max_iter)
|
||||
if "test" == args.mode:
|
||||
main_test()
|
||||
if "fancy" == args.mode:
|
||||
main_visualization()
|
||||
if "all_fancy" == args.mode:
|
||||
main_visualize_all()
|
||||
if "beta" == args.mode:
|
||||
main_beta()
|
||||
if "all_beta" == args.mode:
|
||||
plot_overall_result()
|
||||
if "server" == args.mode:
|
||||
train_server_only()
|
||||
if "server_test" == args.mode:
|
||||
test_server_only()
|
||||
if "embedding" == args.mode:
|
||||
main_visualize_all_embds()
|
||||
if "stats" == args.mode:
|
||||
main_stats()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user