Source code for anml.getter.spline

from operator import attrgetter

import numpy as np
from numpy.typing import NDArray
from xspline import XSpline


class SplineGetter:
    """Spline getter for :class:`XSpline` instance.

    Stores the spline settings and, once data is supplied, resolves the
    knot positions and returns an :class:`XSpline` instance built from them.

    Parameters
    ----------
    knots
        Knots placement of the spline. How it is interpreted depends on
        `knots_type`.
    degree
        Degree of the spline. Default to be 3.
    l_linear
        If `True`, spline will use left linear tail. Default to be `False`.
    r_linear
        If `True`, spline will use right linear tail. Default to be `False`.
    include_first_basis
        If `True`, spline will include the first basis of the spline.
        Default to be `False`.
    knots_type : {'abs', 'rel_domain', 'rel_freq'}
        Type of the spline knots. Can only be chosen from three options,
        `'abs'`, `'rel_domain'` and `'rel_freq'`. When it is `'abs'`, which
        stands for absolute, the knots are used as they are. When it is
        `'rel_domain'`, which stands for relative domain, the knots are
        expected to be between 0 and 1 and are interpreted as proportions
        of the data range. When it is `'rel_freq'`, which stands for
        relative frequency, the knots are expected to be between 0 and 1
        and are interpreted as quantiles of the data.

    Raises
    ------
    ValueError
        Raised when `knots_type` is not one of 'abs', 'rel_domain' or
        'rel_freq'.

    """

    # Closed set of accepted values for ``knots_type``.
    _KNOTS_TYPES = ("abs", "rel_domain", "rel_freq")

    def __init__(self,
                 knots: NDArray,
                 degree: int = 3,
                 l_linear: bool = False,
                 r_linear: bool = False,
                 include_first_basis: bool = False,
                 knots_type: str = "abs"):
        self.knots = knots
        self.degree = degree
        self.l_linear = l_linear
        self.r_linear = r_linear
        self.include_first_basis = include_first_basis
        # Goes through the validating property setter below.
        self.knots_type = knots_type

    @property
    def knots_type(self) -> str:
        """Type of the spline knots.

        Raises
        ------
        ValueError
            Raised on assignment when the input knots type is not one of
            'abs', 'rel_domain' or 'rel_freq'.

        """
        return self._knots_type

    @knots_type.setter
    def knots_type(self, knots_type: str):
        if knots_type not in self._KNOTS_TYPES:
            raise ValueError("Knots type must be one of 'abs', 'rel_domain' "
                             "or 'rel_freq'.")
        self._knots_type = knots_type

    @property
    def num_spline_bases(self) -> int:
        """Number of the spline bases.

        Computed from the stored settings only: the count of inner knots
        (all knots minus one per enabled linear tail), minus 2, plus the
        degree, plus 1 when the first basis is included.

        """
        # Equivalent to len(knots[l_linear:len(knots) - r_linear]); the slice
        # length can never go below zero, hence the clamp.
        num_inner_knots = max(
            len(self.knots) - int(self.l_linear) - int(self.r_linear), 0
        )
        return num_inner_knots - 2 + self.degree + int(self.include_first_basis)

    def get_spline(self, data: NDArray) -> XSpline:
        """Get spline instance given data array.

        Parameters
        ----------
        data
            Given data array to infer the knots placement. Not used when
            `knots_type` is `'abs'`.

        Returns
        -------
        XSpline
            A spline instance.

        """
        if self.knots_type == "abs":
            knots = self.knots
        elif self.knots_type == "rel_domain":
            # Coerce both operands so the scaling is element-wise even when
            # knots or data were supplied as plain sequences.
            data = np.asarray(data)
            lb, ub = data.min(), data.max()
            knots = lb + np.asarray(self.knots) * (ub - lb)
        else:
            # 'rel_freq': knots are quantile levels of the data.
            knots = np.quantile(data, self.knots)

        return XSpline(knots,
                       self.degree,
                       l_linear=self.l_linear,
                       r_linear=self.r_linear,
                       include_first_basis=self.include_first_basis)