reorder curve storing
This commit is contained in:
parent
d58dbcb101
commit
349bc92a61
38
main.py
38
main.py
@ -610,17 +610,12 @@ def main_beta():
|
||||
args.domain_length,
|
||||
args.window)
|
||||
path, model_prefix = os.path.split(os.path.normpath(args.model_path))
|
||||
print(path, model_prefix)
|
||||
try:
|
||||
curves = joblib.load(f"{path}/curves.joblib")
|
||||
logger.info(f"load file {path}/curves.joblib successfully")
|
||||
except Exception:
|
||||
curves = {}
|
||||
logger.info(f"currently {len(curves)} models in file: {curves.keys()}")
|
||||
curves[model_prefix] = {"all": {}}
|
||||
curves = {
|
||||
model_prefix: {"all": {}}
|
||||
}
|
||||
|
||||
domains = domain_val.value.reshape(-1, 40)
|
||||
domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
|
||||
# domains = domain_val.value.reshape(-1, 40)
|
||||
# domains = np.apply_along_axis(lambda d: dataset.decode_domain(d), 1, domains)
|
||||
|
||||
def load_df(res):
|
||||
df_server = None
|
||||
@ -634,12 +629,12 @@ def main_beta():
|
||||
data["server_pred"] = server.flatten()
|
||||
data["server_val"] = val.flatten()
|
||||
|
||||
if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
|
||||
df_server = pd.DataFrame(data={
|
||||
"server_pred": res["server_pred"].flatten(),
|
||||
"domain": domains,
|
||||
"server_val": server_val.value.flatten()
|
||||
})
|
||||
# if res["server_pred"].flatten().shape == server_val.value.flatten().shape:
|
||||
# df_server = pd.DataFrame(data={
|
||||
# "server_pred": res["server_pred"].flatten(),
|
||||
# "domain": domains,
|
||||
# "server_val": server_val.value.flatten()
|
||||
# })
|
||||
|
||||
res = pd.DataFrame(data=data)
|
||||
res["client_val"] = np.logical_or(res.hits_vt == 1.0, res.hits_trusted >= 3)
|
||||
@ -716,8 +711,15 @@ def main_beta():
|
||||
df_domain_avg.server_val.as_matrix(),
|
||||
server_domain_avg_preds)
|
||||
|
||||
joblib.dump(curves, f"{path}/curves.joblib")
|
||||
|
||||
joblib.dump(curves, f"{args.model_path}_curves.joblib")
|
||||
try:
|
||||
curves_all: dict = joblib.load(f"{path}/curves.joblib")
|
||||
logger.info(f"load file {path}/curves.joblib successfully")
|
||||
curves_all[model_prefix] = curves[model_prefix]
|
||||
except Exception:
|
||||
curves_all = curves
|
||||
logger.info(f"currently {len(curves_all)} models in file: {curves_all.keys()}")
|
||||
joblib.dump(curves_all, f"{path}/curves.joblib")
|
||||
|
||||
import matplotlib
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user