reorder curve storing

This commit is contained in:
René Knaebel 2017-11-13 14:37:41 +01:00
parent d58dbcb101
commit 349bc92a61

38
main.py
View File

@ -610,17 +610,12 @@ def main_beta():
args.domain_length, args.domain_length,
args.window) args.window)
path, model_prefix = os.path.split(os.path.normpath(args.model_path)) path, model_prefix = os.path.split(os.path.normpath(args.model_path))
print(path, model_prefix) curves = {
try: model_prefix: {"all": {}}
curves = joblib.load(f"{path}/curves.joblib") }
logger.info(f"load file {path}/curves.joblib successfully")
except Exception:
curves = {}
logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
curves[model_prefix] = {"all": {}}
domains = domain_val.value.reshape(-1, 40) # domains = domain_val.value.reshape(-1, 40)
domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains) # domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
def load_df(res): def load_df(res):
df_server = None df_server = None
@ -634,12 +629,12 @@ def main_beta():
data["server_pred"] = server.flatten() data["server_pred"] = server.flatten()
data["server_val"] = val.flatten() data["server_val"] = val.flatten()
if res["server_pred"].flatten().shape == server_val.value.flatten().shape: # if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
df_server = pd.DataFrame(data={ # df_server = pd.DataFrame(data={
"server_pred": res["server_pred"].flatten(), # "server_pred": res["server_pred"].flatten(),
"domain": domains, # "domain": domains,
"server_val": server_val.value.flatten() # "server_val": server_val.value.flatten()
}) # })
res = pd.DataFrame(data=data) res = pd.DataFrame(data=data)
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3) res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
@ -716,8 +711,15 @@ def main_beta():
df_domain_avg.server_val.as_matrix(), df_domain_avg.server_val.as_matrix(),
server_domain_avg_preds) server_domain_avg_preds)
joblib.dump(curves, f"{path}/curves.joblib") joblib.dump(curves, f"{args.model_path}_curves.joblib")
try:
curves_all: dict = joblib.load(f"{path}/curves.joblib")
logger.info(f"load file {path}/curves.joblib successfully")
curves_all[model_prefix] = curves[model_prefix]
except Exception:
curves_all = curves
logger.info(f"currently {len(curves_all)} models in file: {curves_all.keys()}")
joblib.dump(curves_all, f"{path}/curves.joblib")
import matplotlib import matplotlib