85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
|
import argparse
|
||
|
import glob
|
||
|
import multiprocessing as mp
|
||
|
import os
|
||
|
from collections import defaultdict
|
||
|
|
||
|
import pandas as pd
|
||
|
|
||
|
from explaining_framework.utils.io import read_json, read_yaml
|
||
|
|
||
|
# Command-line interface. Every option except --exp_dir is consumed directly by
# os.path.join / DataFrame.to_csv further down, so a missing value would either
# crash with an opaque TypeError (None in a path) or, for --outfile, make
# to_csv(None) return the CSV as a string and silently write nothing. They are
# therefore required.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp_dir",
    help="Parent directory of all explanations",
    default="./explanations",
    dest="ed",
)
parser.add_argument(
    "--explainer_name",
    help="Name of the explaining methods you want to parse from",
    required=True,
    dest="en",
)
parser.add_argument(
    "--dataset_name",
    help="Name of the explained dataset",
    required=True,
    dest="dn",
)
parser.add_argument(
    "--metric_name",
    help="Name of the objective metric",
    required=True,
    dest="mn",
)
parser.add_argument(
    "--outfile",
    help="Path for output CSV",
    required=True,
    dest="op",
)
# Parsed at import time; multiprocessing's spawn start method restores sys.argv
# in worker processes, so re-importing this module there parses the same flags.
args = parser.parse_args()
|
||
|
|
||
|
PARENT_PATH = os.path.abspath(args.ed)
|
||
|
DATA = []
|
||
|
DATA_DICT = defaultdict(list)
|
||
|
pd1 = pd.DataFrame()
|
||
|
|
||
|
for p1 in glob.glob(os.path.join(PARENT_PATH, args.dn, "**")):
|
||
|
CONFIG_path = os.path.join(p1, "config.yaml")
|
||
|
INFO_path = os.path.join(p1, "info.json")
|
||
|
|
||
|
CONFIG_DICT = read_yaml(CONFIG_path)
|
||
|
INFO_DICT = read_yaml(INFO_path)
|
||
|
INFO_DICT_KEYS = [k for k in INFO_DICT.keys() if "path" in k]
|
||
|
|
||
|
for k in INFO_DICT_KEYS:
|
||
|
INFO_DICT.pop(k)
|
||
|
for p2 in glob.glob(os.path.join(p1, args.en, "**")):
|
||
|
EXPLAINER_CONFIG_path = os.path.join(p2, "explainer_cfg.yaml")
|
||
|
EXPLAINER_CONFIG_DICT = read_yaml(EXPLAINER_CONFIG_path)
|
||
|
|
||
|
for p3 in glob.glob(os.path.join(p2, "Attack", "**")):
|
||
|
attack_type = os.path.basename(p3)[5:]
|
||
|
|
||
|
for p4 in glob.glob(os.path.join(p3, "Adjust", "**")):
|
||
|
adjust_type = os.path.basename(p4)[9:]
|
||
|
|
||
|
for p5 in glob.glob(os.path.join(p4, "**", "**")):
|
||
|
thres_type = os.path.basename(p5)[15:]
|
||
|
|
||
|
for p6 in glob.glob(os.path.join(p5, args.mn, "**")):
|
||
|
metric_type = os.path.basename(p6)[5:]
|
||
|
|
||
|
all_file = sorted(
|
||
|
glob.glob(os.path.join(p6, "**", "*.json"), recursive=True)
|
||
|
)
|
||
|
with mp.Pool(mp.cpu_count()) as pool:
|
||
|
data = pool.map(read_json, all_file)
|
||
|
|
||
|
d1 = [
|
||
|
pd.json_normalize(INFO_DICT),
|
||
|
pd.json_normalize(CONFIG_DICT),
|
||
|
pd.json_normalize(EXPLAINER_CONFIG_DICT),
|
||
|
pd.DataFrame(
|
||
|
{
|
||
|
"Adjust": [adjust_type],
|
||
|
"Threshold": [thres_type],
|
||
|
"Metric": [metric_type],
|
||
|
}
|
||
|
),
|
||
|
]
|
||
|
|
||
|
for da in data:
|
||
|
v = pd.json_normalize({"val": list(da.values())[0]})
|
||
|
d_ = pd.concat((*d1, v), axis=1)
|
||
|
pd1 = pd.concat((pd1, d_))
|
||
|
|
||
|
|
||
|
pd1.to_csv(args.op)
|