from operator import attrgetter
from typing import List, Optional, Tuple, Type, Union
from anml.data.component import Component
from anml.getter.prior import SplinePriorGetter
from anml.getter.spline import SplineGetter
from anml.prior.main import Prior
from anml.variable.main import Variable
from numpy.typing import NDArray
from pandas import DataFrame
from xspline import XSpline
SplineVariablePrior: Type = Union[Prior, SplinePriorGetter]
"""Allowed prior type for spline variable.
"""
[docs]class SplineVariable(Variable):
"""Variable class that contains information of variable, including name,
priors and spline.
Parameters
----------
component
You can pass in the name of the variable corresponding to the column
name in the data frame. It will be automatically converted into an
instance of :class:`Component` with :class:`NoNans` as the validator.
Alternatively, you can also pass in an instance of :class:`Component`,
with your own set of validators.
spline
Given spline for the variable. You can pass in an instance of
:class:`XSpline` or :class:`SplineGetter`. If input is an instance of
:class:`SplineGetter`, when use attach data it will automatically
envolve into an instance of :class:`XSpline`.
priors
A list of priors corresponding to the variable. The prior in the list
can be either an instance of :class:`Prior` or
:class:`SplinePriorGetter`. When attach data, the instance of
:class:`SplinePriorGetter` will envolve into an instance of
:class:`Prior`.
"""
spline = property(attrgetter("_spline"))
"""Given spline for the variable.
Raises
------
TypeError
Raised if input spline is not an instance of :class:`XSpline` nor
:class:`SplineGetter`.
"""
_prior_types: Tuple[Type, ...] = SplineVariablePrior.__args__
def __init__(self,
component: Union[str, Component],
spline: Union[XSpline, SplineGetter],
priors: Optional[List[SplineVariablePrior]] = None):
super().__init__(component, priors)
self.spline = spline
@spline.setter
def spline(self, spline: Union[XSpline, SplineGetter]):
if not isinstance(spline, (XSpline, SplineGetter)):
raise TypeError("Spline variable input spline must be an instance "
"of XSpline or SplineGetter.")
self._spline = spline
@property
def size(self) -> int:
return self.spline.num_spline_bases
[docs] def attach(self, df: DataFrame):
"""Attach the data to variable. It will attach data to the component.
And create spline and priors if necessary.
Parameters
----------
df
The data frame contains the corresponding data column.
"""
self.component.attach(df)
if isinstance(self.spline, SplineGetter):
self.spline = self.spline.get_spline(self.component.value)
for i in range(len(self.priors)):
if isinstance(self.priors[i], SplinePriorGetter):
self.priors[i] = self.priors[i].get_prior(self.spline)
[docs] def get_design_mat(self, df: DataFrame) -> NDArray:
self.attach(df)
return self.spline.design_mat(self.component.value, l_extra=True, r_extra=True)