Source code for anml.prior.utils
from typing import List, Optional, Type
import anml
from anml.prior.main import Prior
[docs]def get_prior_type(prior_type: str) -> Type:
"""Get prior type from the prior type name.
Parameters
----------
prior_type
Name of the prior type (class).
Returns
-------
Type
The class corresponding to the given prior type name.
Examples
--------
>>> prior_type = get_prior_type("GaussianPrior")
>>> prior_type
<class 'anml.prior.main.GaussianPrior'>
"""
return getattr(anml.prior.main, prior_type)
[docs]def filter_priors(priors: List[Prior],
prior_type: str,
with_mat: Optional[bool] = None) -> List[Prior]:
"""Filter priors from a list of priors by their type and do they contain
linear map or not.
Parameters
----------
priors
Given list of priors. Note that it is user's responsibility to check if
all elements in the list are instances of Prior.
prior_type
Given prior type name.
with_mat
If the filtered priors are all contain a linear map. Default to `None`.
If `with_mat=None`, the final list will include priors that both
contain or not contain the linear map.
Returns
-------
List[Prior]
Filtered priors.
"""
prior_type = get_prior_type(prior_type)
def condition(prior, prior_type=prior_type, with_mat=with_mat):
is_prior_instance = isinstance(prior, prior_type)
if with_mat is None:
return is_prior_instance
if with_mat:
return prior.mat is not None and is_prior_instance
return prior.mat is None and is_prior_instance
return list(filter(condition, priors))