bex package
Module contents
- class bex.Benchmark(batch_size=12, num_workers=8, n_samples=100, corr_level=0.95, n_corr=10, seed=0, logger=<class 'bex.loggers.basic.BasicLogger'>, data_path=None, download=True)
Bases: object
Main class for evaluating counterfactual explanation methods on the Bex benchmark. For a list of supported methods, see bex.explainers. The benchmark also supports custom methods; for an example of how to create and evaluate your own method, see ExplainerBase. A construction sketch follows the parameter list below.
- Parameters:
  - batch_size (int, optional) – dataloader batch size (default: 12)
  - num_workers (int, optional) – number of dataloader workers (default: 8)
  - n_samples (int, optional) – number of samples to explain per confidence level (default: 100)
  - corr_level (float, optional) – correlation level of the spuriously correlated attributes \(z_{\text{corr}}\), either 0.50 or 0.95 (default: 0.95)
  - n_corr (int, optional) – number of correlated attributes, either 6 or 10 (default: 10)
  - seed (int, optional) – numpy and torch random seed (default: 0)
  - data_path (str, optional) – path where the datasets and models are downloaded (default: ~/.bex)
  - download (bool, optional) – whether to download the datasets and models (default: True)
  - logger (BasicLogger, optional) – logger used to record results and examples; if None, nothing is logged (default: BasicLogger)
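A minimal construction sketch; the parameter values below are illustrative, not recommended settings:

```python
import bex

# Illustrative configuration: fewer samples, the lower correlation level,
# and logging disabled (logger=None means nothing is logged).
bn = bex.Benchmark(
    batch_size=32,
    n_samples=50,
    corr_level=0.5,
    n_corr=6,
    seed=42,
    logger=None,
)
```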
- run(explainer, output_path=None, device='', **kwargs)
Evaluates an explainer on the Bex benchmark
- Parameters:
  - explainer (string) – explainability method to be evaluated; for supported methods see bex.explainers
  - output_path (string, optional) – directory in which to store results and examples if logger is not None (default: output/datetime.now())
  - device (string, optional) – device on which to run (default: 'cuda' if available else 'cpu')
  - **kwargs – keyword arguments for the explainer
Example

```python
bn = bex.Benchmark()
bn.run("stylex")
```
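Keyword arguments are forwarded to the explainer. In the sketch below, the hyperparameter name lr is an assumption; valid names depend on the method chosen:

```python
bn = bex.Benchmark()

# "lr" is a hypothetical explainer hyperparameter; check the chosen
# method's documentation for the names it actually accepts.
bn.run("dive", device="cpu", output_path="output/dive", lr=0.1)
```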
- runs(exp_list, **kwargs)
Evaluates a list of explainers on the Bex benchmark
- Parameters:
  - exp_list (List[Dict]) – list of dictionaries, each containing an explainer and its parameters
  - **kwargs – keyword arguments for run()
Example

```python
bn = bex.Benchmark()
# run dive and dice
to_run = [{"explainer": "dive", "lr": 0.1}, {"explainer": "dice", "lr": 0.01}]
bn.runs(to_run)
```
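Because runs() forwards its keyword arguments to run(), exp_list can also be built programmatically, e.g. for a small hyperparameter sweep. As above, lr is assumed to be a valid hyperparameter for the chosen explainer:

```python
bn = bex.Benchmark()

# Sweep a (hypothetical) "lr" hyperparameter for the "dive" explainer.
to_run = [{"explainer": "dive", "lr": lr} for lr in (0.01, 0.1)]

# Shared kwargs such as device are forwarded to run() for every entry.
bn.runs(to_run, device="cpu")
```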