Source code for xclim.sdba.detrending

"""
Detrending objects
------------------
"""
from typing import Union

import xarray as xr

from xclim.core.units import convert_units_to

from .base import Grouper, ParametrizableWithDataset, map_groups, parse_group
from .loess import loess_smoothing
from .utils import ADDITIVE, apply_correction, invert


[docs]class BaseDetrend(ParametrizableWithDataset): """Base class for detrending objects. Defines three methods: fit(da) : Compute trend from da and return a new _fitted_ Detrend object. detrend(da) : Return detrended array. retrend(da) : Puts trend back on da. A fitted Detrend object is unique to the trend coordinate of the object used in `fit`, (usually 'time'). The computed trend is stored in ``Detrend.ds.trend``. Subclasses should implement ``_get_trend_group()`` or ``_get_trend()``. The first will be called in a ``group.apply(..., main_only=True)``, and should return a single DataArray. The second allows the use of functions wrapped in :py:func:`map_groups` and should also return a single DataArray. The subclasses may reimplement ``_detrend`` and ``_retrend``. """ @parse_group def __init__( self, *, group: Union[Grouper, str] = "time", kind: str = "+", **kwargs ): """Initialize Detrending object. Parameters ---------- group : Union[str, Grouper] The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. The fit is performed along the group's main dim. kind : {'*', '+'} The way the trend is removed or added, either additive or multiplicative. """ super().__init__(group=group, kind=kind, **kwargs) @property def fitted(self): return hasattr(self, "ds")
[docs] def fit(self, da: xr.DataArray): """Extract the trend of a DataArray along a specific dimension. Returns a new object that can be used for detrending and retrending. Fitted objects are unique to the fitted coordinate used. """ new = self.__class__(**self.parameters) new.set_dataset(new._get_trend(da).rename("trend").to_dataset()) new.ds.trend.attrs["units"] = da.attrs.get("units", "") return new
def _get_trend(self, da: xr.DataArray): """Computes the trend, along the self.group.dim as found on da. If da is a DataArray (and has a "dtype" attribute), the trend is casted to have the same dtype. This method applies `_get_trend_group` with `self.group`. """ out = self.group.apply( self._get_trend_group, da, ) if hasattr(da, "dtype"): out = out.astype(da.dtype) return out.rename("trend")
[docs] def detrend(self, da: xr.DataArray): """Remove the previously fitted trend from a DataArray.""" if not self.fitted: raise ValueError("You must call fit() before detrending.") trend = self.ds.trend if "units" in da.attrs: trend = convert_units_to(self.ds.trend, da) return self._detrend(da, trend)
[docs] def retrend(self, da: xr.DataArray): """Put the previously fitted trend back on a DataArray.""" if not self.fitted: raise ValueError("You must call fit() before retrending") trend = self.ds.trend if "units" in da.attrs: trend = convert_units_to(self.ds.trend, da) return self._retrend(da, trend)
def _detrend(self, da, trend): # Remove trend from series return apply_correction(da, invert(trend, self.kind), self.kind) def _retrend(self, da, trend): # Add trend to series return apply_correction(da, trend, self.kind) def _get_trend_group(self, grpd, *, dim): raise NotImplementedError def __repr__(self): rep = super().__repr__() if not self.fitted: return f"<{rep} | unfitted>" return rep
[docs]class NoDetrend(BaseDetrend): """Convenience class for polymorphism. Does nothing.""" def _get_trend_group(self, da, *, dim): return da.isel({d: 0 for d in dim}) def _detrend(self, da, trend): return da def _retrend(self, da, trend): return da
[docs]class MeanDetrend(BaseDetrend): """Simple detrending removing only the mean from the data, quite similar to normalizing.""" def _get_trend(self, da): return _meandetrend_get_trend(da, **self).trend
@map_groups(trend=[Grouper.DIM]) def _meandetrend_get_trend(da, *, dim, kind): trend = da.mean(dim).broadcast_like(da) return trend.rename("trend").to_dataset()
[docs]class PolyDetrend(BaseDetrend): """ Detrend time series using a polynomial regression. Parameters ---------- group : Union[str, Grouper] The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. The fit is performed along the group's main dim. kind : {'*', '+'} The way the trend is removed or added, either additive or multiplicative. degree : int The order of the polynomial to fit. preserve_mean : bool Whether to preserve the mean when de/re-trending. If True, the trend has its mean removed before it is used. """ def __init__(self, group="time", kind=ADDITIVE, degree=4, preserve_mean=False): super().__init__( group=group, kind=kind, degree=degree, preserve_mean=preserve_mean ) def _get_trend(self, da): # Estimate trend over da trend = _polydetrend_get_trend(da, **self) return trend.trend
@map_groups(trend=[Grouper.DIM]) def _polydetrend_get_trend(da, *, dim, degree, preserve_mean, kind): """Polydetrend, atomic func on 1 group.""" if len(dim) > 1: da = da.mean(dim[1:]) dim = dim[0] pfc = da.polyfit(dim=dim, deg=degree) trend = xr.polyval(coord=da[dim], coeffs=pfc.polyfit_coefficients) if preserve_mean: trend = apply_correction(trend, invert(trend.mean(dim=dim), kind), kind) return trend.rename("trend").to_dataset()
[docs]class LoessDetrend(BaseDetrend): """ Detrend time series using a LOESS regression. The fit is a piecewise linear regression. For each point, the contribution of all neighbors is weighted by a bell-shaped curve (gaussian) with parameters sigma (std). The x-coordinate of the DataArray is scaled to [0,1] before the regression is computed. Parameters ---------- group : Union[str, Grouper] The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. The fit is performed along the group's main dim. kind : {'*', '+'} The way the trend is removed or added, either additive or multiplicative. d: [0, 1] Order of the local regression. Only 0 and 1 currently implemented. f : float Parameter controlling the span of the weights, between 0 and 1. niter : int Number of robustness iterations to execute. weights : ["tricube", "gaussian"] Shape of the weighting function: "tricube" : a smooth top-hat like curve, f gives the span of non-zero values. "gaussian" : a gaussian curve, f gives the span for 95% of the values. Notes ----- LOESS smoothing is computationally expensive. As it relies on a loop on gridpoints, it can be useful to use smaller than usual chunks. Moreover, it suffers from heavy boundary effects. As a rule of thumb, the outermost N * f/2 points should be considered dubious. (N is the number of points along each group) """ def __init__( self, group="time", kind=ADDITIVE, f=0.2, niter=1, d=0, weights="tricube", equal_spacing=None, ): super().__init__( group=group, kind=kind, f=f, niter=niter, d=0, weights=weights, equal_spacing=equal_spacing, ) def _get_trend(self, da): # Estimate trend over da trend = _loessdetrend_get_trend(da, **self) return trend.trend
@map_groups(trend=[Grouper.DIM]) def _loessdetrend_get_trend(da, *, dim, f, niter, d, weights, equal_spacing, kind): if len(dim) > 1: da = da.mean(dim[1:]) trend = loess_smoothing( da, dim=dim[0], f=f, niter=niter, d=d, weights=weights, equal_spacing=equal_spacing, ) return trend.rename("trend").to_dataset()
[docs]class RollingMeanDetrend(BaseDetrend): """ Detrend time series using a rolling mean. Parameters ---------- group : Union[str, Grouper] The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. The fit is performed along the group's main dim. kind : {'*', '+'} The way the trend is removed or added, either additive or multiplicative. win : int The size of the rolling window. Units are the steps of the grouped data, which means this detrending is best use with either `group='time'` or `group='time.dayofyear'`. Other grouping will have large jumps included within the windows and :py`:class:`LoessDetrend` might offer a better solution. weights : sequence of floats, optional Sequence of length `win`. Defaults to None, which means a flat window. min_periods: int, optional Minimum number of observations in window required to have a value, otherwise the result is NaN. See :py:meth:`xarray.DataArray.rolling`. Defaults to None, which sets it equal to `win`. Setting both `weights` and this is not implemented yet. Notes ----- As for the :py:class:`LoessDetrend` detrending, important boundary effects are to be expected. """ def __init__( self, group="time", kind=ADDITIVE, win=30, weights=None, min_periods=None ): if weights is not None: weights = xr.DataArray(weights, dims=("window",)) weights = weights / weights.sum() if min_periods is not None: raise NotImplementedError( "Setting both `min_periods` and `weights` is not implemented yet." ) super().__init__( group=group, kind=kind, win=win, weights=weights, min_periods=min_periods ) def _get_trend(self, da): # Estimate trend over da trend = _rollingmean_get_trend(da, **self) return trend.trend
@map_groups(trend=[Grouper.DIM]) def _rollingmean_get_trend(da, *, dim, kind, win, weights, min_periods): if len(dim) > 1: da = da.mean(dim[1:]) roll = da.rolling(center=True, min_periods=min_periods, **{dim[0]: win}) if weights is not None: trend = roll.construct("window").dot(weights) else: trend = roll.mean() return trend.rename("trend").to_dataset()