New module for parsing explaining stats
This commit is contained in:
parent
3cf4187164
commit
cb8dbf2629
|
@ -0,0 +1,84 @@
|
|||
import argparse
|
||||
import glob
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from explaining_framework.utils.io import read_json, read_yaml
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--exp_dir",
|
||||
help="Parent directory of all explanations",
|
||||
default="./explanations",
|
||||
dest="ed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--explainer_name",
|
||||
help="Name of the explaining methods you want to parse from",
|
||||
dest="en",
|
||||
)
|
||||
parser.add_argument("--dataset_name", help="Name of the explained dataset", dest="dn")
|
||||
parser.add_argument("--metric_name", help="Name of the objective metric", dest="mn")
|
||||
parser.add_argument("--outfile", help="Path for output CSV", dest="op")
|
||||
args = parser.parse_args()
|
||||
|
||||
PARENT_PATH = os.path.abspath(args.ed)
|
||||
DATA = []
|
||||
DATA_DICT = defaultdict(list)
|
||||
pd1 = pd.DataFrame()
|
||||
|
||||
for p1 in glob.glob(os.path.join(PARENT_PATH, args.dn, "**")):
|
||||
CONFIG_path = os.path.join(p1, "config.yaml")
|
||||
INFO_path = os.path.join(p1, "info.json")
|
||||
|
||||
CONFIG_DICT = read_yaml(CONFIG_path)
|
||||
INFO_DICT = read_yaml(INFO_path)
|
||||
INFO_DICT_KEYS = [k for k in INFO_DICT.keys() if "path" in k]
|
||||
|
||||
for k in INFO_DICT_KEYS:
|
||||
INFO_DICT.pop(k)
|
||||
for p2 in glob.glob(os.path.join(p1, args.en, "**")):
|
||||
EXPLAINER_CONFIG_path = os.path.join(p2, "explainer_cfg.yaml")
|
||||
EXPLAINER_CONFIG_DICT = read_yaml(EXPLAINER_CONFIG_path)
|
||||
|
||||
for p3 in glob.glob(os.path.join(p2, "Attack", "**")):
|
||||
attack_type = os.path.basename(p3)[5:]
|
||||
|
||||
for p4 in glob.glob(os.path.join(p3, "Adjust", "**")):
|
||||
adjust_type = os.path.basename(p4)[9:]
|
||||
|
||||
for p5 in glob.glob(os.path.join(p4, "**", "**")):
|
||||
thres_type = os.path.basename(p5)[15:]
|
||||
|
||||
for p6 in glob.glob(os.path.join(p5, args.mn, "**")):
|
||||
metric_type = os.path.basename(p6)[5:]
|
||||
|
||||
all_file = sorted(
|
||||
glob.glob(os.path.join(p6, "**", "*.json"), recursive=True)
|
||||
)
|
||||
with mp.Pool(mp.cpu_count()) as pool:
|
||||
data = pool.map(read_json, all_file)
|
||||
|
||||
d1 = [
|
||||
pd.json_normalize(INFO_DICT),
|
||||
pd.json_normalize(CONFIG_DICT),
|
||||
pd.json_normalize(EXPLAINER_CONFIG_DICT),
|
||||
pd.DataFrame(
|
||||
{
|
||||
"Adjust": [adjust_type],
|
||||
"Threshold": [thres_type],
|
||||
"Metric": [metric_type],
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
for da in data:
|
||||
v = pd.json_normalize({"val": list(da.values())[0]})
|
||||
d_ = pd.concat((*d1, v), axis=1)
|
||||
pd1 = pd.concat((pd1, d_))
|
||||
|
||||
|
||||
pd1.to_csv(args.op)
|
Loading…
Reference in New Issue