bex.explainers package
Module contents
- class bex.explainers.Dice(num_explanations=10, lr=0.1, num_iters=50, proximity_weight=1, diversity_weight=1)
  Bases: ExplainerBase
  DiCE explainer as described in https://arxiv.org/abs/1905.07697
  - Parameters:
    - num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)
    - lr (float, optional) – learning rate (default: 0.1)
    - num_iters (int, optional) – number of gradient descent steps to perform (default: 50)
    - proximity_weight (float, optional) – weight of the reconstruction term \(\lambda_1\) in the loss function (default: 1.0)
    - diversity_weight (float, optional) – weight of the diversity term \(\lambda_2\) in the loss function (default: 1.0)
- explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)
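DiCE optimizes a loss that trades off counterfactual validity against proximity to the original latent and diversity among the counterfactuals. The sketch below is a minimal numpy illustration of those three terms for a single sample; the hinge surrogate for validity and the L1 distances are illustrative choices, not the library's exact implementation.

```python
import numpy as np

def dice_loss(z_perturbed, z, target_logit, proximity_weight=1.0, diversity_weight=1.0):
    """Illustrative DiCE-style loss for one sample.

    z_perturbed:  (n_explanations, Z) candidate counterfactuals
    z:            (Z,) original latent
    target_logit: (n_explanations,) classifier logit of the target class
    """
    # validity: push the target-class logit up (hinge-style surrogate)
    validity = np.maximum(0.0, 1.0 - target_logit).mean()
    # proximity: keep counterfactuals close to the original latent (L1)
    proximity = np.abs(z_perturbed - z).mean()
    # diversity: mean pairwise L1 distance between counterfactuals (maximized)
    n = z_perturbed.shape[0]
    pair_dists = [np.abs(z_perturbed[i] - z_perturbed[j]).mean()
                  for i in range(n) for j in range(i + 1, n)]
    diversity = float(np.mean(pair_dists)) if pair_dists else 0.0
    return validity + proximity_weight * proximity - diversity_weight * diversity
```

The proximity_weight and diversity_weight arguments play the roles of \(\lambda_1\) and \(\lambda_2\) above.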
- class bex.explainers.Dive(num_explanations=10, lr=0.1, num_iters=50, diversity_weight=0.001, lasso_weight=0.1, reconstruction_weight=0.0001, method='fisher_spectral')
  Bases: ExplainerBase
  DiVE explainer as described in https://arxiv.org/abs/2103.10226
  - Parameters:
    - num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)
    - lr (float, optional) – learning rate (default: 0.1)
    - num_iters (int, optional) – number of gradient descent steps to perform (default: 50)
    - diversity_weight (float, optional) – weight of the diversity term in the loss function (default: 0.001)
    - lasso_weight (float, optional) – factor \(\gamma\) that controls the sparsity of the latent space (default: 0.1)
    - reconstruction_weight (float, optional) – weight of the reconstruction term in the loss function (default: 0.0001)
    - method (string, optional) – method used for gradient masking (default: 'fisher_spectral')
- explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)
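DiVE's gradient masking restricts each counterfactual to perturb only a subset of latent coordinates, so the explanations stay disjoint. As a toy illustration of that idea only, the sketch below builds disjoint coordinate masks from a random partition; the actual 'fisher_spectral' method clusters coordinates via the Fisher information, which is not reproduced here.

```python
import numpy as np

def partition_masks(num_explanations, latent_dim, rng=None):
    """Toy gradient masks: assign each latent coordinate to one explanation.

    Returns a (num_explanations, latent_dim) 0/1 array; multiplying each
    counterfactual's gradient by its row confines it to its own coordinates.
    """
    rng = np.random.default_rng() if rng is None else rng
    # random stand-in for the spectral partition of the latent coordinates
    owner = rng.integers(0, num_explanations, size=latent_dim)
    return (owner[None, :] == np.arange(num_explanations)[:, None]).astype(float)
```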
- class bex.explainers.ExplainerBase
  Bases: object
  Base class for all explainer methods.
  If you wish to test your own explainer on our benchmark, use this as a base class and override the explain_batch method.
  Example:

      import random

      import bex
      from bex.explainers import ExplainerBase


      class DummyExplainer(ExplainerBase):
          def __init__(self, num_explanations):
              super().__init__()
              self.num_explanations = num_explanations

          def explain_batch(self, latents, logits, images, classifier, generator):
              b = latents.shape[0]
              # we will produce self.num_explanations counterfactuals per sample
              z = latents[:, None, :].repeat(1, self.num_explanations, 1)
              z_perturbed = z + random.random()  # create counterfactuals z'
              return z_perturbed.view(b, self.num_explanations, -1)


      bn = bex.Benchmark()
      bn.run(DummyExplainer, num_explanations=10)
- abstract explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)
- class bex.explainers.GrowingSpheres(num_explanations=10, n_candidates=50, first_radius=10, decrease_radius=2)
  Bases: ExplainerBase
  Growing Spheres explainer as described in https://arxiv.org/abs/1712.08443
  - Parameters:
    - num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)
    - n_candidates (int, optional) – number of observations \(n\) to generate at each step (default: 50)
    - first_radius (float, optional) – radius \(\eta\) of the first hyperball generated (default: 10.0)
    - decrease_radius (float, optional) – parameter controlling the size of the radius at each step (default: 2.0)
- explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)
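Growing Spheres samples candidate points in a hyperball around the input and shrinks the sampling radius while counterfactuals keep being found, homing in on a close one. Below is a simplified, self-contained numpy sketch of that search loop; the helper name, the stopping rule, and the single-sample interface are assumptions for illustration, not the library's code.

```python
import numpy as np

def growing_spheres(z, is_counterfactual, n_candidates=50, first_radius=10.0,
                    decrease_radius=2.0, min_radius=1e-3, rng=None):
    """Simplified Growing Spheres search around a latent point z.

    Shrinks the sampling radius by `decrease_radius` while counterfactuals
    are still found, returning the closest one seen so far.
    """
    rng = np.random.default_rng() if rng is None else rng
    radius, best = first_radius, None
    while radius > min_radius:
        # sample candidates uniformly in the hyperball of the current radius
        direction = rng.normal(size=(n_candidates, z.shape[0]))
        direction /= np.linalg.norm(direction, axis=1, keepdims=True)
        scale = radius * rng.uniform(size=(n_candidates, 1)) ** (1.0 / z.shape[0])
        candidates = z + direction * scale
        hits = candidates[is_counterfactual(candidates)]
        if len(hits) == 0:
            break  # nothing this close: the previous best is (near) minimal
        best = hits[np.argmin(np.linalg.norm(hits - z, axis=1))]
        radius /= decrease_radius
    return best
```

Here `is_counterfactual` is any predicate over a batch of latents, e.g. a thresholded classifier.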
- class bex.explainers.LCF(num_explanations=10, lr=0.1, num_iters=50, p=0.1, tolerance=0.5)
  Bases: ExplainerBase
  Latent-CF explainer as described in https://arxiv.org/abs/2012.09301
  - Parameters:
    - num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)
    - lr (float, optional) – learning rate (default: 0.1)
    - num_iters (int, optional) – max number of gradient descent steps to perform without convergence (default: 50)
    - p (float, optional) – probability \(p\) of the target counterfactual class (default: 0.1)
    - tolerance (float, optional) – tolerance used to decide convergence (default: 0.5)
- explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)
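Latent-CF descends the gradient in latent space until the classifier's output is within the tolerance of the target probability p, or the iteration budget runs out. The toy sketch below uses a fixed logistic classifier sigmoid(w·z) standing in for the real model; the function name and loss choice are illustrative assumptions.

```python
import numpy as np

def latent_cf(z, w, p=0.9, lr=0.1, num_iters=50, tolerance=0.05):
    """Illustrative Latent-CF loop with a toy logistic classifier sigmoid(w·z).

    Gradient-descends the squared error between the classifier output and the
    target probability p, stopping early once within `tolerance`.
    """
    z = z.copy()
    for _ in range(num_iters):
        prob = 1.0 / (1.0 + np.exp(-w @ z))
        if abs(prob - p) < tolerance:
            break  # converged: output is close enough to the target class probability
        # gradient of (prob - p)^2 w.r.t. z for the logistic classifier
        z -= lr * 2.0 * (prob - p) * prob * (1.0 - prob) * w
    return z
```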
- class bex.explainers.Stylex(num_explanations=10, t=0.3, shift_size=0.8, strategy='independent')
  Bases: ExplainerBase
  StylEx explainer as described in https://arxiv.org/abs/2104.13369
  - Parameters:
    - num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)
    - t (float, optional) – perturbation threshold \(t\) to consider a sample explained (default: 0.3)
    - shift_size (float, optional) – amount of shift applied to each coordinate (default: 0.8)
    - strategy (string, optional) – selection strategy, 'independent' or 'subset' (default: 'independent')
- explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)
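With the 'independent' strategy, StylEx-style selection ranks latent coordinates by how much shifting each one alone moves the classifier output. A toy numpy sketch of that ranking; the function name and the `predict` interface (batch of latents in, target-class probabilities out) are illustrative assumptions, not the library's API.

```python
import numpy as np

def rank_coordinates(z, predict, shift_size=0.8):
    """Rank latent coordinates by the classifier change an independent shift causes.

    `predict` maps a batch of latents (N, Z) to target-class probabilities (N,).
    Returns coordinate indices sorted by decreasing absolute effect.
    """
    base = predict(z[None, :])[0]
    effects = np.empty(z.shape[0])
    for i in range(z.shape[0]):
        shifted = z.copy()
        shifted[i] += shift_size  # perturb one coordinate at a time
        effects[i] = abs(predict(shifted[None, :])[0] - base)
    return np.argsort(-effects)
```

The top-ranked coordinates would then be shifted until the output change exceeds the threshold \(t\).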
- class bex.explainers.Xgem(num_explanations=10, lr=0.1, num_iters=50, reconstruction_weight=0.001)
  Bases: Dive
  xGEM explainer as described in https://arxiv.org/abs/1806.08867
  - Parameters:
    - num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)
    - lr (float, optional) – learning rate (default: 0.1)
    - num_iters (int, optional) – number of gradient descent steps to perform (default: 50)
    - reconstruction_weight (float, optional) – weight of the reconstruction term in the loss function (default: 0.001)
- explain_batch(latents, logits, images, classifier, generator)
  Method to generate a set of counterfactuals for a given batch
  - Parameters:
    - latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of the samples to be perturbed
    - logits (torch.Tensor) – classifier logits given \(\textbf{z}\)
    - images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)
    - classifier (torch.nn.Module) – classifier \(\hat{f}(x)\) to explain
    - generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images
  - Returns:
    the obtained counterfactuals \(\textbf{z'}\) for each batch element
  - Return type:
    torch.Tensor
  - Shape:
    - latents: \((B, Z)\)
    - logits: \((B, 2)\)
    - images: \((B, C, H, W)\)
    - obtained counterfactuals: \((B, n\_explanations, Z)\)