explaining_framework/stat_parser.py

85 lines
3.0 KiB
Python
Raw Normal View History

import argparse
import glob
import multiprocessing as mp
import os
from collections import defaultdict
import pandas as pd
from explaining_framework.utils.io import read_json, read_yaml
# Command-line interface of the statistics parser.
#
# --dataset_name, --explainer_name and --metric_name are mandatory: the script
# joins each of them into a filesystem path with os.path.join, so leaving one
# out used to surface as a TypeError on None instead of a usage message.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp_dir",
    help="Parent directory of all explanations",
    default="./explanations",
    dest="ed",
)
parser.add_argument(
    "--explainer_name",
    help="Name of the explaining methods you want to parse from",
    dest="en",
    required=True,
)
parser.add_argument(
    "--dataset_name",
    help="Name of the explained dataset",
    dest="dn",
    required=True,
)
parser.add_argument(
    "--metric_name",
    help="Name of the objective metric",
    dest="mn",
    required=True,
)
parser.add_argument("--outfile", help="Path for output CSV", dest="op")
def suffix_after(path, skip):
    """Return the last component of *path* without its first *skip* characters."""
    return os.path.basename(path)[skip:]


def iter_metric_dirs(explainer_dir, metric):
    """Yield ``(adjust, threshold, metric_type, directory)`` per metric directory.

    Matches ``<explainer_dir>/Attack/*/Adjust/*/*/*/<metric>/*``.  The globs
    are not recursive, so every ``**`` matches exactly one path level.  Each
    label is the matched directory name with a fixed number of leading
    characters removed (9, 15 and 5) -- presumably a constant naming prefix;
    TODO confirm against the code that writes the tree.
    """
    for attack_dir in glob.glob(os.path.join(explainer_dir, "Attack", "**")):
        # NOTE(review): the attack directory name has never been written to the
        # output table; confirm whether an "Attack" column is wanted.
        for adjust_dir in glob.glob(os.path.join(attack_dir, "Adjust", "**")):
            adjust_type = suffix_after(adjust_dir, 9)
            for thres_dir in glob.glob(os.path.join(adjust_dir, "**", "**")):
                thres_type = suffix_after(thres_dir, 15)
                for metric_dir in glob.glob(os.path.join(thres_dir, metric, "**")):
                    yield (
                        adjust_type,
                        thres_type,
                        suffix_after(metric_dir, 5),
                        metric_dir,
                    )


def build_rows(static_frames, records):
    """Return one single-row DataFrame per record.

    Each row is *static_frames* laid side by side, followed by the first value
    of the record flattened under the ``val`` column prefix.
    """
    rows = []
    for record in records:
        value_frame = pd.json_normalize({"val": list(record.values())[0]})
        rows.append(pd.concat((*static_frames, value_frame), axis=1))
    return rows


def collect_frames(pool, parent_path, dataset, explainer, metric):
    """Walk ``<parent_path>/<dataset>/*`` and return all result rows.

    *pool* is a multiprocessing pool used to load the JSON result files of
    each metric directory in parallel.
    """
    frames = []
    for run_dir in glob.glob(os.path.join(parent_path, dataset, "**")):
        config = read_yaml(os.path.join(run_dir, "config.yaml"))
        # NOTE(review): info.json is loaded with read_yaml, as it always was;
        # confirm that read_json was not the intended reader.
        info = read_yaml(os.path.join(run_dir, "info.json"))
        # Entries whose key mentions "path" are left out of the table.
        info = {key: value for key, value in info.items() if "path" not in key}
        for explainer_dir in glob.glob(os.path.join(run_dir, explainer, "**")):
            explainer_cfg = read_yaml(
                os.path.join(explainer_dir, "explainer_cfg.yaml")
            )
            for adjust_type, thres_type, metric_type, metric_dir in iter_metric_dirs(
                explainer_dir, metric
            ):
                json_files = sorted(
                    glob.glob(os.path.join(metric_dir, "**", "*.json"), recursive=True)
                )
                records = pool.map(read_json, json_files)
                static_frames = [
                    pd.json_normalize(info),
                    pd.json_normalize(config),
                    pd.json_normalize(explainer_cfg),
                    pd.DataFrame(
                        {
                            "Adjust": [adjust_type],
                            "Threshold": [thres_type],
                            "Metric": [metric_type],
                        }
                    ),
                ]
                frames.extend(build_rows(static_frames, records))
    return frames


def main():
    """Parse the command line, gather every result row and write the CSV."""
    args = parser.parse_args()
    if args.op is None:
        # DataFrame.to_csv(None) only returns the CSV text, so without a
        # target the whole run would be thrown away; fail before doing any work.
        parser.error("--outfile is required to write the resulting CSV")
    parent_path = os.path.abspath(args.ed)
    # A single pool serves the whole traversal instead of one per directory.
    with mp.Pool(mp.cpu_count()) as pool:
        frames = collect_frames(pool, parent_path, args.dn, args.en, args.mn)
    # Concatenate once at the end; growing a DataFrame row by row copies the
    # accumulated table on every step.
    table = pd.concat(frames) if frames else pd.DataFrame()
    table.to_csv(args.op)


# The guard keeps multiprocessing workers, which may re-import this module,
# from re-running the script.
if __name__ == "__main__":
    main()