Source code for xclim.core.bootstrapping

"""Module comprising the bootstrapping algorithm for indicators."""

from __future__ import annotations

import warnings
from inspect import signature
from typing import Any, Callable

import cftime
import numpy as np
import xarray
from boltons.funcutils import wraps
from xarray.core.dataarray import DataArray

import xclim.core.utils

from .calendar import convert_calendar, parse_offset, percentile_doy

# Name of the temporary dimension along which the bootstrap replicas are stacked;
# it is added by `build_bootstrap_year_da` and averaged out in `bootstrap_func`.
BOOTSTRAP_DIM = "_bootstrap"


def percentile_bootstrap(func):
    """Decorate an index function so that it can be computed with a bootstrap step.

    This feature is experimental.

    The bootstrap removes the discontinuities that appear in percentile-exceedance
    indices between the reference period (over which the percentiles were computed)
    and the "out of reference" periods. See `bootstrap_func` for details.

    The decorated function is called unchanged unless its ``bootstrap`` argument
    (``False`` when the function has no such argument) is something other than
    ``False``; in that case the call is delegated to `bootstrap_func` with all
    bound arguments, defaults included.

    Declaration example:

    .. code-block:: python

        @declare_units(tas="[temperature]", t90="[temperature]")
        @percentile_bootstrap
        def tg90p(
            tas: xarray.DataArray,
            t90: xarray.DataArray,
            freq: str = "YS",
            bootstrap: bool = False,
        ) -> xarray.DataArray:
            pass

    Examples
    --------
    >>> from xclim.core.calendar import percentile_doy
    >>> from xclim.indices import tg90p
    >>> tas = xr.open_dataset(path_to_tas_file).tas
    >>> # To start bootstrap reference period must not fully overlap the studied period.
    >>> tas_ref = tas.sel(time=slice("1990-01-01", "1992-12-31"))
    >>> t90 = percentile_doy(tas_ref, window=5, per=90)
    >>> tg90p(tas=tas, tas_per=t90.sel(percentiles=90), freq="YS", bootstrap=True)
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Bind the call to the signature so that `bootstrap` is found whether it
        # was passed positionally, by keyword, or left to its default.
        bound = signature(func).bind(*args, **kwargs)
        bound.apply_defaults()

        if bound.arguments.get("bootstrap", False) is not False:
            return bootstrap_func(func, **bound.arguments)

        # No bootstrap requested: plain call with the caller's own arguments.
        return func(*args, **kwargs)

    return wrapper
def bootstrap_func(compute_index_func: Callable, **kwargs) -> xarray.DataArray:
    r"""Bootstrap the computation of percentile-based indices.

    Indices measuring exceedance over percentile-based thresholds (such as tx90p) may
    contain artificial discontinuities at the beginning and end of the reference period
    used to calculate percentiles. The bootstrap procedure can reduce those
    discontinuities by iteratively computing the percentile estimate and the index on
    altered reference periods.

    These altered reference periods are themselves built iteratively: When computing
    the index for year `x`, the bootstrapping creates as many altered reference periods
    as the number of years in the reference period. To build one altered reference
    period, the values of year `x` are replaced by the values of another year in the
    reference period, then the index is computed on this altered period. This is
    repeated for each year of the reference period, excluding year `x`. The final
    result of the index for year `x` is then the average of all the index results on
    altered years.

    Parameters
    ----------
    compute_index_func : Callable
        Index function.
    \*\*kwargs
        Arguments to `compute_index_func`. They must include a ``freq`` entry, one
        input DataArray, and one percentile DataArray whose ``history`` attribute
        mentions ``percentile_doy``.

    Returns
    -------
    xr.DataArray
        The result of `compute_index_func` with bootstrapping.

    Raises
    ------
    KeyError
        If no input DataArray or no `percentile_doy` percentile array is found in
        the arguments, or if the studied period overlaps the reference period
        either completely or not at all.

    References
    ----------
    :cite:cts:`zhang_avoiding_2005`

    Notes
    -----
    This function is meant to be used by the `percentile_bootstrap` decorator.
    The parameters of the percentile calculation (percentile, window, reference_period)
    are stored in the attributes of the percentile DataArray.
    The bootstrap algorithm implemented here does the following::

        For each temporal grouping in the calculation of the index
            If the group `g_t` is in the reference period
                For every other group `g_s` in the reference period
                    Replace group `g_t` by `g_s`
                    Compute percentile on resampled time series
                    Compute index function using percentile
                Average output from index function over all resampled time series
            Else compute index function using original percentile
    """
    # Identify the input and the percentile arrays from the bound arguments
    per_key = None
    da_key = None
    for name, val in kwargs.items():
        if isinstance(val, DataArray):
            if "percentile_doy" in val.attrs.get("history", ""):
                per_key = name
            else:
                da_key = name
    if da_key is None:
        # Without this guard the lookup below fails with an UnboundLocalError.
        raise KeyError(
            "`bootstrap` requires an input DataArray on which to compute the index"
        )

    # Extract the DataArray inputs from the arguments
    da: DataArray = kwargs.pop(da_key)
    per_da: DataArray | None = kwargs.pop(per_key, None)
    if per_da is None:
        # per may be empty on non doy percentiles
        raise KeyError(
            "`bootstrap` can only be used with percentiles computed using `percentile_doy`"
        )

    # Boundary years of reference period
    clim = per_da.attrs["climatology_bounds"]
    if xclim.core.utils.uses_dask(da) and len(da.chunks[da.get_axis_num("time")]) > 1:
        warnings.warn(
            "The input data is chunked on time dimension and must be fully re-chunked to"
            " run percentile bootstrapping."
            " Beware, this operation can significantly increase the number of tasks dask"
            " has to handle.",
            stacklevel=2,
        )
        chunking = {d: "auto" for d in da.dims}
        chunking["time"] = -1  # no chunking on time to use map_block
        da = da.chunk(chunking)

    # overlap of studied `da` and reference period used to compute percentile
    overlap_da = da.sel(time=slice(*clim))
    if len(overlap_da.time) == len(da.time):
        raise KeyError(
            "`bootstrap` is unnecessary when all years are overlapping between reference "
            "(percentiles period) and studied (index period) periods"
        )
    # Measure the overlap along `time` explicitly: `len(DataArray)` is the size of
    # the first dimension, which is not necessarily `time`.
    if len(overlap_da.time) == 0:
        raise KeyError(
            "`bootstrap` is unnecessary when no year overlap between reference "
            "(percentiles period) and studied (index period) periods."
        )

    pdoy_args = dict(
        window=per_da.attrs["window"],
        alpha=per_da.attrs["alpha"],
        beta=per_da.attrs["beta"],
        per=per_da.percentiles.data[()],
    )
    bfreq = _get_bootstrap_freq(kwargs["freq"])

    # Group input array in years, with an offset matching freq
    overlap_years_groups = overlap_da.resample(time=bfreq).groups
    da_years_groups = da.resample(time=bfreq).groups
    # Years present in the overlap; invariant over the loop below.
    overlap_years = overlap_da.get_index("time").year
    per_template = per_da.copy(deep=True)
    acc = []

    # Compute bootstrapped index on each year of overlapping years
    for year_key, year_slice in da_years_groups.items():
        kw = {da_key: da.isel(time=year_slice), **kwargs}
        if _get_year_label(year_key) in overlap_years:
            # If the group year is in both reference and studied periods, run the bootstrap
            bda = build_bootstrap_year_da(overlap_da, overlap_years_groups, year_key)
            if BOOTSTRAP_DIM not in per_template.dims:
                per_template = per_template.expand_dims(
                    {BOOTSTRAP_DIM: np.arange(len(bda._bootstrap))}
                )
                if xclim.core.utils.uses_dask(bda):
                    chunking = {
                        d: bda.chunks[bda.get_axis_num(d)]
                        for d in set(bda.dims).intersection(set(per_template.dims))
                    }
                    per_template = per_template.chunk(chunking)
            per = xarray.map_blocks(
                percentile_doy.__wrapped__,  # strip history update from percentile_doy
                obj=bda,
                kwargs={**pdoy_args, "copy": False},
                template=per_template,
            )
            if "percentiles" not in per_da.dims:
                per = per.squeeze("percentiles")
            kw[per_key] = per
            value = compute_index_func(**kw).mean(dim=BOOTSTRAP_DIM, keep_attrs=True)
        else:
            # Otherwise, run the normal computation using the original percentile
            kw[per_key] = per_da
            value = compute_index_func(**kw)
        acc.append(value)

    result = xarray.concat(acc, dim="time")
    result.attrs["units"] = value.attrs["units"]
    return result
def _get_bootstrap_freq(freq):
    """Return the yearly resampling frequency used to group data for the bootstrap.

    The result is ``"YS"`` when `parse_offset` reports a start anchor for `freq`
    and ``"YE"`` otherwise. When the base of `freq` is one of ``A``, ``Y`` or ``Q``
    and an anchor is given, that anchor is appended (``"<bfreq>-<anchor>"``).
    """
    _, base, start_anchor, anchor = parse_offset(freq)  # noqa

    yearly = "Y" + ("S" if start_anchor else "E")
    if anchor is not None and base in ("A", "Y", "Q"):
        return f"{yearly}-{anchor}"
    return yearly
def _get_year_label(year_dt) -> int:
    """Return the calendar year of a group key as an integer.

    Parameters
    ----------
    year_dt : cftime.datetime or numpy.datetime64
        Key of a yearly group. Anything that is not a `cftime.datetime` is
        handled as a numpy datetime (it must support ``astype``).

    Returns
    -------
    int
        The year of `year_dt`.
    """
    if isinstance(year_dt, cftime.datetime):
        year_label = year_dt.year
    else:
        # datetime64[Y] counts years since the 1970 epoch.
        year_label = year_dt.astype("datetime64[Y]").astype(int) + 1970
    return year_label
# TODO: Return a generator instead and assess performance
def build_bootstrap_year_da(
    da: DataArray, groups: dict[Any, slice], label: Any, dim: str = "time"
) -> DataArray:
    """Return an array where a group in the original is replaced by every other groups along a new dimension.

    Parameters
    ----------
    da : DataArray
        Original input array over reference period.
    groups : dict
        Output of grouping functions, such as `DataArrayResample.groups`.
        It is not modified; a copy is used internally.
    label : Any
        Key identifying the group item to replace.
    dim : str
        Dimension recognized as time. Default: `time`.

    Returns
    -------
    DataArray
        Array where one group is replaced by values from every other group along
        the `BOOTSTRAP_DIM` dimension.

    Raises
    ------
    NotImplementedError
        If the replaced group is longer than the source group and its length
        matches none of the handled cases.
    """
    gr = groups.copy()

    # Location along dim that must be replaced
    bloc = da[dim][gr.pop(label)]
    n_bloc = len(bloc)

    # Initialize output array with new bootstrap dimension
    out = da.expand_dims({BOOTSTRAP_DIM: np.arange(len(gr))}).copy(deep=True)

    # With dask, mutating the views of out is not working, thus the accumulator
    out_accumulator = []

    # Replace `bloc` by every other group
    for i, group_slice in enumerate(gr.values()):
        source = da.isel({dim: group_slice})
        n_source = len(source[dim])
        out_view = out.loc[{BOOTSTRAP_DIM: i}]
        if n_source < 360 and n_source < n_bloc:
            # This happens when the sampling frequency is anchored thus
            # source[dim] would be only a few months on the first and last year
            pass
        elif n_source == n_bloc:
            out_view.loc[{dim: bloc}] = source.data
        elif n_bloc == 365:
            out_view.loc[{dim: bloc}] = convert_calendar(source, "365_day").data
        elif n_bloc == 366:
            # `np.nan` is the canonical spelling; upper-case aliases are not
            # portable across NumPy versions.
            out_view.loc[{dim: bloc}] = convert_calendar(
                source, "366_day", missing=np.nan
            ).data
        elif n_bloc < 365:
            # 360 days calendar case or anchored years for both source[dim] and bloc case
            out_view.loc[{dim: bloc}] = source.data[:n_bloc]
        else:
            raise NotImplementedError
        out_accumulator.append(out_view)

    return xarray.concat(out_accumulator, dim=BOOTSTRAP_DIM)