Explaining framework

A PyTorch Geometric add-on for running heavy, parallel experiments that explain Graph Neural Network models. Driven by a config.yaml file that you can set up as you wish, the framework handles preprocessing, simultaneous experiments, and postprocessing on its own.

How to

  1. Set up your experiment details (dataset, GNN architecture, explaining method, metrics, GPU workload limit, etc.).
    # ----------------------------------------------------------------------- #
    # Basic options
    # ----------------------------------------------------------------------- #

    # Set print destination: stdout / file / both
    explaining_cfg.print = "both"

    explaining_cfg.out_dir = "./explanations"

    explaining_cfg.cfg_dest = "explaining_config.yaml"

    explaining_cfg.seed = 0

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #

    explaining_cfg.dataset = CN()

    explaining_cfg.dataset.name = "Cora"

    explaining_cfg.dataset.item = []

    # ----------------------------------------------------------------------- #
    # Model options
    # ----------------------------------------------------------------------- #

    explaining_cfg.model = CN()

    # Whether to load the best model for the given dataset, or a checkpoint path
    explaining_cfg.model.ckpt = "best"

    # Path to the models folder
    explaining_cfg.model.path = "path"

    # ----------------------------------------------------------------------- #
    # Explainer options
    # ----------------------------------------------------------------------- #

    explaining_cfg.explainer = CN()

    # Name of the explaining method
    explaining_cfg.explainer.name = "EiXGNN"

    # Whether to use a specific configuration for the explaining method or the default one
    explaining_cfg.explainer.cfg = "default"

    # Whether to recompute explanations if they already exist
    explaining_cfg.explainer.force = False

    # ----------------------------------------------------------------------- #
    # Explaining options
    # ----------------------------------------------------------------------- #

    # ExplanationType: 'model' or 'phenomenon'
    explaining_cfg.explanation_type = "model"

    explaining_cfg.model_config = CN()

    # Do not modify; set from the dataset, assuming one dataset = one learning task
    explaining_cfg.model_config.mode = "regression"

    # Do not modify; set from the dataset, assuming one dataset = one learning task
    explaining_cfg.model_config.task_level = None

    # Do not modify; model outputs are always assumed to be 'raw'
    explaining_cfg.model_config.return_type = "raw"

    # ----------------------------------------------------------------------- #
    # Thresholding options
    # ----------------------------------------------------------------------- #

    explaining_cfg.threshold = CN()

    explaining_cfg.threshold.config = CN()
    explaining_cfg.threshold.config.type = "all"

    explaining_cfg.threshold.value = CN()
    explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(10)]
    explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]

    # Which objective metrics to compute: either all, or one in particular if implemented
    explaining_cfg.metrics = CN()
    explaining_cfg.metrics.sparsity = CN()
    explaining_cfg.metrics.sparsity.name = "all"
    explaining_cfg.metrics.fidelity = CN()
    explaining_cfg.metrics.fidelity.name = "all"
    explaining_cfg.metrics.accuracy = CN()
    explaining_cfg.metrics.accuracy.name = "all"

    # Whether to recompute metrics if they already exist

    explaining_cfg.adjust = CN()
    explaining_cfg.adjust.strategy = "rpns"

    explaining_cfg.attack = CN()
    explaining_cfg.attack.name = "all"

    # Select device: 'cpu', 'cuda', 'auto'
    explaining_cfg.accelerator = "auto"

  2. Provide the generated .yaml file to main.py, or the path to a folder containing a stack of config files to run them in parallel.

  3. Run the experiments.

  4. Check the results, which are already post-processed (statistics, plots).
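
One way to prepare a folder of config files for parallel running is to generate one small YAML file per experiment. The sketch below is only an illustration under stated assumptions: it assumes the YAML keys mirror the `explaining_cfg` tree shown above and that a file containing only overrides is accepted. The helper names (`to_yaml`, `write_config_stack`), the file naming scheme, and the `"CiteSeer"` dataset value are hypothetical and not taken from the framework.

```python
import itertools
import tempfile
from pathlib import Path

# Hypothetical sweep values: only "Cora" and "EiXGNN" appear in this README;
# "CiteSeer" is a stand-in for any other dataset the framework may support.
DATASETS = ["Cora", "CiteSeer"]
EXPLAINERS = ["EiXGNN"]
SEEDS = [0, 1]


def to_yaml(node, indent=0):
    """Serialize a nested dict of scalars into minimal block-style YAML."""
    lines = []
    pad = " " * indent
    for key, value in node.items():
        if isinstance(value, dict):
            lines.append(f"{pad}{key}:")
            lines.append(to_yaml(value, indent + 2))
        else:
            lines.append(f"{pad}{key}: {value}")
    return "\n".join(lines)


def write_config_stack(folder):
    """Write one YAML file per (dataset, explainer, seed) combination."""
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    written = []
    for dataset, explainer, seed in itertools.product(DATASETS, EXPLAINERS, SEEDS):
        # Assumed layout: keys mirror the explaining_cfg tree (overrides only).
        overrides = {
            "seed": seed,
            "dataset": {"name": dataset},
            "explainer": {"name": explainer},
        }
        path = folder / f"{dataset}_{explainer}_seed{seed}.yaml"
        path.write_text(to_yaml(overrides) + "\n")
        written.append(path)
    return written


if __name__ == "__main__":
    # Write the stack to a throwaway directory and list the generated files.
    with tempfile.TemporaryDirectory() as tmp:
        for path in write_config_stack(tmp):
            print(path.name)
```

Pointing `write_config_stack` at a persistent directory instead of a temporary one gives a folder that could serve as the config stack mentioned above. How main.py expects to receive that path is not specified here, so check its argument handling before relying on this layout.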