Source code for xclim.sdba.base

"""Base classes."""
import warnings
from inspect import signature
from types import FunctionType
from typing import Callable, Mapping, Optional, Sequence, Set, Union

import dask.array as dsk
import jsonpickle
import numpy as np
import xarray as xr
from boltons.funcutils import wraps

from xclim.core.calendar import days_in_year, get_calendar, max_doy, parse_offset
from xclim.core.utils import uses_dask


# ## Base class for the sdba module
class Parametrizable(dict):
    """Helper base class resembling a dictionary.

    This object is _completely_ defined by the content of its internal dictionary, accessible through item access
    (`self['attr']`) or in `self.parameters`. When serializing and restoring this object, only members of that internal
    dict are preserved. All other attributes set directly with `self.attr = value` will not be preserved upon serialization
    and restoration of the object with `[json]pickle`.
    dictionary. Other variables set with `self.var = data` will be lost in the serialization process.
    This class is best serialized and restored with `jsonpickle`.
    """

    _repr_hide_params = []

    def __getstate__(self):
        """For (json)pickle, a Parametrizable should be defined by its internal dict only."""
        return self.parameters

    def __setstate__(self, state):
        """For (json)pickle, a Parametrizable in only defined by its internal dict."""
        self.update(state)

    def __getattr__(self, attr):
        """Get attributes."""
        try:
            return self.__getitem__(attr)
        except KeyError as err:
            # Raise the proper error type for getattr
            raise AttributeError(*err.args)

    @property
    def parameters(self):
        """All parameters as a dictionary. Read-only."""
        return dict(**self)

    def __repr__(self):
        """Return a string representation that allows eval to recreate it."""
        params = ", ".join(
            [
                f"{k}={repr(v)}"
                for k, v in self.items()
                if k not in self._repr_hide_params
            ]
        )
        return f"{self.__class__.__name__}({params})"


class ParametrizableWithDataset(Parametrizable):
    """Parametrizeable class that also has a `ds` attribute storing a dataset."""

    _attribute = "_xclim_parameters"

    @classmethod
    def from_dataset(cls, ds: xr.Dataset):
        """Create an instance from a dataset.

        The dataset must have a global attribute with a name corresponding to `cls._attribute`,
        and that attribute must be the result of `jsonpickle.encode(object)` where object is
        of the same type as this object.
        """
        obj = jsonpickle.decode(ds.attrs[cls._attribute])
        obj.set_dataset(ds)
        return obj

    def set_dataset(self, ds: xr.Dataset):
        """Stores an xarray dataset in the `ds` attribute.

        Useful with custom object initialization or if some external processing was performed.
        """
        self.ds = ds
        self.ds.attrs[self._attribute] = jsonpickle.encode(self)


class Grouper(Parametrizable):
    """Helper object to perform grouping actions on DataArrays and Datasets."""

    _repr_hide_params = ["dim", "prop"]  # For a concise repr
    # Two constants for use of `map_blocks` and `map_groups`.
    # They provide better code readability, nothing more
    PROP = "<PROP>"
    DIM = "<DIM>"

    def __init__(
        self,
        group: str,
        window: int = 1,
        add_dims: Optional[Union[Sequence[str], Set[str]]] = None,
        interp: Union[bool, str] = False,
    ):
        """Create the Grouper object.

        Parameters
        ----------
        group : str
          The usual grouping name as xarray understands it. Ex: "time.month" or "time".
          The dimension name before the dot is the "main dimension" stored in `Grouper.dim` and
          the property name after is stored in `Grouper.prop`.
        window : int
          If larger than 1, a centered rolling window along the main dimension is created when grouping data.
          Units are the sampling frequency of the data along the main dimension.
        add_dims : Optional[Sequence[str]]
          Additional dimensions that should be reduced in grouping operations. This behaviour is also controlled
          by the `main_only` parameter of the `apply` method. If any of these dimensions are absent from the dataarrays,
          they will be omitted.
        interp : Union[bool, str]
          Whether to return an interpolatable index in the `get_index` method. Only effective for `month` grouping.
          Interpolation method names are accepted for convenience, "nearest" is translated to False, all other names
          are translated to True.
          This modifies the default, but `get_index` also accepts an `interp` argument overriding the one defined here..
        """
        if "." in group:
            dim, prop = group.split(".")
        else:
            dim, prop = group, "group"

        if isinstance(interp, str):
            interp = interp != "nearest"

        add_dims = add_dims or []
        if window > 1:
            add_dims.insert(1, "window")
        super().__init__(
            dim=dim,
            add_dims=add_dims,
            prop=prop,
            name=group,
            window=window,
            interp=interp,
        )

    @classmethod
    def from_kwargs(cls, **kwargs):
        kwargs["group"] = cls(
            group=kwargs.pop("group"),
            window=kwargs.pop("window", 1),
            add_dims=kwargs.pop("add_dims", []),
            interp=kwargs.get("interp", False),
        )
        return kwargs

    def get_coordinate(self, ds=None):
        """Return the coordinate as in the output of group.apply.

        Currently only implemented for groupings with prop == month or dayofyear.
        For prop == dayfofyear a ds (dataset or dataarray) can be passed to infer
        the max doy from the available years and calendar.
        """
        if self.prop == "month":
            return xr.DataArray(np.arange(1, 13), dims=("month",), name="month")
        if self.prop == "dayofyear":
            if ds is not None:
                cal = get_calendar(ds, dim=self.dim)
                mdoy = max(
                    days_in_year(yr, cal) for yr in np.unique(ds[self.dim].dt.year)
                )
            else:
                mdoy = 365
            return xr.DataArray(
                np.arange(1, mdoy + 1), dims=("dayofyear"), name="dayofyear"
            )
        if self.prop == "group":
            return xr.DataArray([1], dims=("group",), name="group")
        # TODO woups what happens when there is no group? (prop is None)
        raise NotImplementedError()

    def group(
        self,
        da: Union[xr.DataArray, xr.Dataset] = None,
        main_only=False,
        **das: xr.DataArray,
    ):
        """Return a xr.core.groupby.GroupBy object.

        More than one array can be combined to a dataset before grouping using the `das`  kwargs.
        A new `window` dimension is added if `self.window` is larger than 1.
        If `Grouper.dim` is 'time', but 'prop' is None, the whole array is grouped together.

        When multiple arrays are passed, some of them can be grouped along the same group as self.
        They are boadcasted, merged to the grouping dataset and regrouped in the output.
        """
        if das:
            from .utils import broadcast  # pylint: disable=cyclic-import

            if da is not None:
                das[da.name] = da

            da = xr.Dataset(
                data_vars={
                    name: das.pop(name)
                    for name in list(das.keys())
                    if self.dim in das[name].dims
                }
            )

            # "Ungroup" the grouped arrays
            da = da.assign(
                {
                    name: broadcast(var, da[self.dim], group=self, interp="nearest")
                    for name, var in das.items()
                }
            )

        if not main_only and self.window > 1:
            da = da.rolling(center=True, **{self.dim: self.window}).construct(
                window_dim="window"
            )
            if uses_dask(da):
                # Rechunk. There might be padding chunks.
                da = da.chunk({self.dim: -1})

        if self.prop == "group":
            group = self.get_index(da)
        else:
            group = self.name

        return da.groupby(group)

    def get_index(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        interp: Optional[Union[bool, str]] = None,
    ):
        """Return the group index of each element along the main dimension.

        Parameters
        ----------
        da : Union[xr.DataArray, xr.Dataset]
          The input array/dataset for which the group index is returned.
          It must have Grouper.dim as a coordinate.
        interp : Union[bool, str]
          Argument `interp` defaults to `self.interp`. If True, the returned index can be
          used for interpolation. For month grouping, integer values represent the middle of the month, all other
          days are linearly interpolated in between.

        Returns
        -------
        xr.DataArray
          The index of each element along `Grouper.dim`.
          If `Grouper.dim` is `time` and `Grouper.prop` is None, an uniform array of True is returned.
          If `Grouper.prop` is a time accessor (month, dayofyear, etc), an numerical array is returned,
            with a special case of `month` and `interp=True`.
          If `Grouper.dim` is not `time`, the dim is simply returned.
        """
        if self.prop == "group":
            if self.dim == "time":
                return xr.full_like(da[self.dim], 1, dtype=int).rename("group")
            return da[self.dim].rename("group")

        ind = da.indexes[self.dim]
        i = getattr(ind, self.prop)

        if not np.issubdtype(i.dtype, np.integer):
            raise ValueError(
                f"Index {self.name} is not of type int (rather {i.dtype}), but {self.__class__.__name__} requires integer indexes."
            )

        interp = (
            (interp or self.interp)
            if not isinstance(interp, str)
            else interp != "nearest"
        )
        if interp:
            if self.dim == "time":
                if self.prop == "month":
                    i = ind.month - 0.5 + ind.day / ind.days_in_month
                elif self.prop == "dayofyear":
                    i = ind.dayofyear
                else:
                    raise NotImplementedError
            else:
                raise NotImplementedError

        xi = xr.DataArray(
            i,
            dims=self.dim,
            coords={self.dim: da.coords[self.dim]},
            name=self.dim + " group index",
        )

        # Expand dimensions of index to match the dimensions of da
        # We want vectorized indexing with no broadcasting
        # xi = xi.broadcast_like(da)
        xi.name = self.prop
        return xi

    def apply(
        self,
        func: Union[FunctionType, str],
        da: Union[xr.DataArray, Mapping[str, xr.DataArray], xr.Dataset],
        main_only: bool = False,
        **kwargs,
    ):
        """Apply a function group-wise on DataArrays.

        Parameters
        ----------
        func : Union[FunctionType, str]
          The function to apply to the groups, either a callable or a `xr.core.groupby.GroupBy` method name as a string.
          The function will be called as `func(group, dim=dims, **kwargs)`. See `main_only` for the behaviour of `dims`.
        da : Union[xr.DataArray, Mapping[str, xr.DataArray], xr.Dataset]
          The DataArray on which to apply the function. Multiple arrays can be passed through a dictionary. A dataset will be created before grouping.
        main_only : bool
          Whether to call the function with the main dimension only (if True)
          or with all grouping dims (if False, default) (including the window and dimensions given through `add_dims`).
          The dimensions used are also written in the "group_compute_dims" attribute.
          If all the input arrays are missing one of the 'add_dims', it is silently omitted.
        **kwargs :
          Other keyword arguments to pass to the function.

        Returns
        -------
        DataArray or Dataset
          Attributes "group", "group_window" and "group_compute_dims" are added.
          If the function did not reduce the array:
            - The output is sorted along the main dimension.
            - The output is rechunked to match the chunks on the input
                If multiple inputs with differing chunking were given as inputs, the chunking with the smallest number of chunks is used.
          If the function reduces the array:
            - If there is only one group, the singleton dimension is squeezed out of the output
            - The output is rechunked as to have only 1 chunk along the new dimension.


        Notes
        -----
        For the special case where a Dataset is returned, but only some of its variable where reduced by the grouping, xarray's `GroupBy.map` will
        broadcast everything back to the ungrouped dimensions. To overcome this issue, function may add a "_group_apply_reshape" attribute set to
        True on the variables that should be reduced and these will be re-grouped by calling `da.groupby(self.name).first()`.
        """
        if isinstance(da, (dict, xr.Dataset)):
            grpd = self.group(main_only=main_only, **da)
            dim_chunks = min(  # Get smallest chunking to rechunk if the operation is non-grouping
                [
                    d.chunks[d.get_axis_num(self.dim)]
                    for d in da.values()
                    if uses_dask(d) and self.dim in d.dims
                ]
                or [[]],  # pass [[]] if no dataarrays have chunks so min doesnt fail
                key=len,
            )
        else:
            grpd = self.group(da, main_only=main_only)
            # Get chunking to rechunk is the operation is non-grouping
            # To match the behaviour of the case above, an empty list signifies that dask is not used for the input.
            dim_chunks = (
                [] if not uses_dask(da) else da.chunks[da.get_axis_num(self.dim)]
            )

        dims = self.dim
        if not main_only:
            dims = [dims] + [dim for dim in self.add_dims if dim in grpd.dims]

        if isinstance(func, str):
            out = getattr(grpd, func)(dim=dims, **kwargs)
        else:
            out = grpd.map(func, dim=dims, **kwargs)

        # Case where the function wants to return more than one variables
        # and that some have grouped dims and other have the same dimensions as the input.
        # In that specific case, groupby broadcasts everything back to the input's dim, copying the grouped data.
        if isinstance(out, xr.Dataset):
            for name, outvar in out.data_vars.items():
                if "_group_apply_reshape" in outvar.attrs:
                    out[name] = self.group(outvar, main_only=True).first(
                        skipna=False, keep_attrs=True
                    )
                    del out[name].attrs["_group_apply_reshape"]

        # Save input parameters as attributes of output DataArray.
        out.attrs["group"] = self.name
        out.attrs["group_compute_dims"] = dims
        out.attrs["group_window"] = self.window

        # On non reducing ops, drop the constructed window
        if self.window > 1 and "window" in out.dims:
            out = out.isel(window=self.window // 2, drop=True)

        # If the grouped operation did not reduce the array, the result is sometimes unsorted along dim
        if self.dim in out.dims:
            out = out.sortby(self.dim)
            # The expected behavior for downstream methods would be to conserve chunking along dim
            if uses_dask(out):
                # or -1 in case dim_chunks is [], when no input is chunked (only happens if the operation is chunking the output)
                out = out.chunk({self.dim: dim_chunks or -1})
        if self.prop in out.dims and uses_dask(out):
            # Same as above : downstream methods expect only one chunk along the group
            out = out.chunk({self.prop: -1})

        return out


def parse_group(func: Callable) -> Callable:
    """Parse the "group" argument of a function and return a Grouper object.

    Adds the possiblity to pass a window argument and a list of dimensions in group.
    """
    sig = signature(func)
    if "group" in sig.parameters:
        default_group = sig.parameters["group"].default
    else:
        default_group = None

    @wraps(func)
    def _parse_group(*args, **kwargs):
        if default_group:
            kwargs.setdefault("group", default_group)
        elif "group" not in kwargs:
            raise ValueError("'group' argument not given.")
        if not isinstance(kwargs["group"], Grouper):
            kwargs = Grouper.from_kwargs(**kwargs)
        return func(*args, **kwargs)

    return _parse_group


def duck_empty(dims, sizes, chunks=None):
    """Return an empty DataArray based on a numpy or dask backend, depending on the chunks argument."""
    shape = [sizes[dim] for dim in dims]
    if chunks:
        chnks = [chunks.get(dim, (sizes[dim],)) for dim in dims]
        content = dsk.empty(shape, chunks=chnks)
    else:
        content = np.empty(shape)
    return xr.DataArray(content, dims=dims)


def map_blocks(reduces=None, **outvars):
    """
    Decorator for declaring functions and wrapping them into a map_blocks. It takes care of constructing
    the template dataset.

    If `group` is in the kwargs, it is assumed that `group.dim` is the only dimension reduced or modified,
    and that some other dimensions might be added, but no other existing dimension will be modified.

    Arguments to the decorator are mappings from variable name in the output to its *new* dimensions.
    Dimension order is not preserved.
    The placeholders "<PROP>" and "<DIM>" can be used to signify `group.prop` and `group.dim` respectively.

    The decorated function must always have the signature: func(ds, **kwargs), where ds is a DataArray or a Dataset.
    It must always output a dataset matching the mapping passed to the decorator.
    """

    def merge_dimensions(*seqs):
        """Merge several dimensions lists while preserving order."""
        out = seqs[0].copy()
        for seq in seqs[1:]:
            last_index = 0
            for i, e in enumerate(seq):
                if e in out:
                    indx = out.index(e)
                    if indx < last_index:
                        raise ValueError(
                            "Dimensions order mismatch, lists are not mergeable."
                        )
                    last_index = indx
                else:
                    out.insert(last_index + 1, e)
        return out

    # Ordered list of all added dimensions
    out_dims = merge_dimensions(*outvars.values())
    # List of dimensions reduced by the function.
    red_dims = reduces or []

    def _decorator(func):

        # @wraps(func, hide_wrapped=True)
        @parse_group
        def _map_blocks(ds, **kwargs):
            if isinstance(ds, xr.Dataset):
                ds = ds.unify_chunks()

            # Get group if present
            group = kwargs.get("group")

            # Ensure group is given as it might not be in the signature of the wrapped func
            if {Grouper.PROP, Grouper.DIM}.intersection(
                out_dims + red_dims
            ) and group is None:
                raise ValueError("Missing required `group` argument.")

            if uses_dask(ds):
                # Use dask if any of the input is dask-backed.
                chunks = (
                    dict(ds.chunks)
                    if isinstance(ds, xr.Dataset)
                    else dict(zip(ds.dims, ds.chunks))
                )
                if (
                    group is not None
                    and group.dim in chunks
                    and len(chunks[group.dim]) > 1
                ):
                    raise ValueError(
                        f"The dimension over which we group cannot be chunked ({group.dim} has chunks {chunks[group.dim]})."
                    )
            else:
                chunks = None

            # Make translation dict
            if group is not None:
                placeholders = {Grouper.PROP: group.prop, Grouper.DIM: group.dim}
            else:
                placeholders = {}

            # Get new dimensions (in order), translating placeholders to real names.
            new_dims = [placeholders.get(dim, dim) for dim in out_dims]
            reduced_dims = [placeholders.get(dim, dim) for dim in red_dims]

            for dim in new_dims:
                if dim in ds.dims and dim not in reduced_dims:
                    raise ValueError(
                        f"Dimension {dim} is meant to be added by the computation but it is already on one of the inputs."
                    )

            # Dimensions untouched by the function.
            base_dims = list(set(ds.dims) - set(new_dims) - set(reduced_dims))

            # All dimensions of the output data, new_dims are added at the end on purpose.
            all_dims = base_dims + new_dims
            # The coordinates of the output data.
            coords = {}
            for dim in all_dims:
                if dim == group.prop:
                    coords[group.prop] = group.get_coordinate(ds=ds)
                elif dim == group.dim:
                    coords[group.dim] = ds[group.dim]
                elif dim in ds.coords:
                    coords[dim] = ds[dim]
                elif dim in kwargs:
                    coords[dim] = xr.DataArray(kwargs[dim], dims=(dim,), name=dim)
                else:
                    raise ValueError(
                        f"This function adds the {dim} dimension, its coordinate must be provided as a keyword argument."
                    )
            sizes = {name: crd.size for name, crd in coords.items()}

            # Create the output dataset, but empty
            tmpl = xr.Dataset(coords=coords)
            for var, dims in outvars.items():
                # Out variables must have the base dims + new_dims
                dims = base_dims + [placeholders.get(dim, dim) for dim in dims]
                # duck empty calls dask if chunks is not None
                tmpl[var] = duck_empty(dims, sizes, chunks)

            def _call_and_transpose_on_exit(dsblock, **kwargs):
                """Call the decorated func and transpose to ensure the same dim order as on the templace."""
                out = func(dsblock, **kwargs).transpose(*all_dims)
                for name, crd in dsblock.coords.items():
                    if name not in out.coords and set(crd.dims).issubset(out.dims):
                        out = out.assign_coords({name: dsblock[name]})
                return out

            # Fancy patching for explicit dask task names
            _call_and_transpose_on_exit.__name__ = f"block_{func.__name__}"

            out = ds.map_blocks(
                _call_and_transpose_on_exit, template=tmpl, kwargs=kwargs
            )

            return out

        return _map_blocks

    return _decorator


def map_groups(reduces=[Grouper.DIM], main_only=False, **outvars):
    """
    Decorator for declaring functions acting only on groups and wrapping them into a map_blocks.
    See :py:func:`map_blocks`.

    This is the same as `map_blocks` but adds a call to `group.apply()` in the mapped func.

    It also adds an additional "main_only" argument which is the same as for group.apply.

    Finally, the decorated function must have the signature: func(ds, dim, **kwargs).
    Where ds is a DataAray or Dataset, dim is the group.dim (and add_dims). The `group` argument
    is stripped from the kwargs, but must evidently be provided in the call.
    """

    def _decorator(func):
        decorator = map_blocks(reduces=reduces, **outvars)

        def _apply_on_group(dsblock, **kwargs):
            group = kwargs.pop("group")
            return group.apply(func, dsblock, main_only=main_only, **kwargs)

        # Fancy patching for explicit dask task names
        _apply_on_group.__name__ = f"group_{func.__name__}"

        # wraps(func, injected=['dim'], hide_wrapped=True)(
        return decorator(_apply_on_group)

    return _decorator


def _get_number_of_elements_by_year(time):
    """Get the number of elements in time in a year by inferring its sampling frequency.

    Only calendar with uniform year lengths are supported : 360_day, noleap, all_leap.
    """
    cal = get_calendar(time)

    # Calendar check
    if cal in ["standard", "gregorian", "default", "proleptic_gregorian"]:
        raise ValueError(
            "For moving window computations, the data must have a uniform calendar (360_day, no_leap or all_leap)"
        )

    mult, freq, _ = parse_offset(xr.infer_freq(time))
    days_in_year = max_doy[cal]
    elements_in_year = {"Q": 4, "M": 12, "D": days_in_year, "H": days_in_year * 24}
    N_in_year = elements_in_year.get(freq, 1) / int(mult or 1)
    if N_in_year % 1 != 0:
        raise ValueError(
            f"Sampling frequency of the data must be Q, M, D or H and evenly divide a year (got {mult}{freq})."
        )

    return int(N_in_year)


def construct_moving_yearly_window(
    da: xr.Dataset, window: int = 21, step: int = 1, dim: str = "movingwin"
):
    """Construct a moving window DataArray.

    Stacks windows of `da` in a new 'movingwin' dimension.
    Windows are always made of full years, so calendar with non uniform year lengths are not supported.

    Windows are constructed starting at the beginning of `da`, if number of given years is not
    a multiple of `step`, then the last year(s) will be missing as a supplementary window would be incomplete.

    Parameters
    ----------
    da : xr.DataArray
      A DataArray with a `time` dimension.
    window : int
      The length of the moving window as a number of years.
    step : int
      The step between each window as a number of years.
    dim : str
      The new dimension name. If given, must also be given to `unpack_moving_yearly_window`.

    Return
    ------
    xr.DataArray
      A DataArray with a new `movingwin` dimension and a `time` dimension with a length of 1 window.
      This assumes downstream algorithms do not make use of the _absolute_ year of the data.
      The correct timeseries can be reconstructed with :py:func:`unpack_moving_yearly_window`.
      The coordinates of `movingwin` are the first date of the windows.
    """
    # Get number of samples per year (and perform checks)
    N_in_year = _get_number_of_elements_by_year(da.time)

    # Number of samples in a window
    N = window * N_in_year

    first_slice = da.isel(time=slice(0, N))
    first_slice = first_slice.expand_dims({dim: np.atleast_1d(first_slice.time[0])})
    daw = [first_slice]

    i_start = N_in_year * step
    # This is the first time I use `while` in real python code. What an event.
    while i_start + N <= da.time.size:
        # Cut and add _full_ slices only, partial window are thrown out
        # Use isel so that we don't need to deal with a starting date.
        slc = da.isel(time=slice(i_start, i_start + N))
        slc = slc.expand_dims({dim: np.atleast_1d(slc.time[0])})
        slc["time"] = first_slice.time
        daw.append(slc)
        i_start += N_in_year * step

    daw = xr.concat(daw, dim)
    return daw


def unpack_moving_yearly_window(da: xr.DataArray, dim: str = "movingwin"):
    """Unpack a constructed moving window dataset to a normal timeseries, only keeping the central data.

    Unpack DataArrays created with :py:func:`construct_moving_yearly_window` and recreate a timeseries data.
    Only keeps the central non-overlapping years. The final timeseries will be (window - step) years shorter than
    the initial one.

    The window length and window step are inferred from the coordinates.

    Parameters
    ----------
    da: xr.DataArray
      As constructed by :py:func:`construct_moving_yearly_window`.
    dim : str
      The window dimension name as given to the construction function.
    """
    # Get number of samples by year (and perform checks)
    N_in_year = _get_number_of_elements_by_year(da.time)

    # Might be smaller than the original moving window, doesn't matter
    window = da.time.size / N_in_year

    if window % 1 != 0:
        warnings.warn(
            f"Incomplete data received as number of years covered is not an integer ({window})"
        )

    # Get step in number of years
    days_in_year = max_doy[get_calendar(da)]
    step = np.unique(da[dim].diff(dim).dt.days / days_in_year)
    if len(step) > 1:
        raise ValueError("The spacing between the windows is not equal.")
    step = int(step[0])

    # Which years to keep: length step, in the middle of window
    left = int((window - step) // 2)  # first year to keep

    # Keep only the middle years
    da = da.isel(time=slice(left * N_in_year, (left + step) * N_in_year))

    out = []
    for win_start in da[dim]:
        slc = da.sel({dim: win_start}).drop_vars(dim)
        dt = win_start.values - da[dim][0].values
        slc["time"] = slc.time + dt
        out.append(slc)

    return xr.concat(out, "time")


def stack_variables(ds, rechunk=True, dim="variables"):
    """Stack different variables of a dataset into a single DataArray with a new "variables" dimension.

    Variable attributes are all added as lists of attributes.

    Parameters
    ----------
    ds : xr.Dataset
      Input dataset.
    rechunk : bool
      If True (default), dask arrays are rechunked with `variables : -1`.
    dim : str
      Name of dimension along which variables are indexed.

    Returns
    -------
    xr.DataArray
      Array with variables stacked along `dim` dimension.
    """
    # Store original arrays' attributes
    attrs = {}
    nvar = len(ds.data_vars)
    for i, var in enumerate(ds.data_vars.values()):
        for name, attr in var.attrs.items():
            attrs.setdefault(name, [None] * nvar)[i] = attr

    # Special key used for later `unstacking`
    attrs["is_variables"] = True
    var_crd = xr.DataArray(
        list(ds.data_vars.keys()), dims=(dim,), name=dim, attrs=attrs
    )

    da = xr.concat(ds.data_vars.values(), var_crd, combine_attrs="drop")

    if uses_dask(da) and rechunk:
        da = da.chunk({dim: -1})

    da.attrs.update(ds.attrs)
    return da.rename("multivariate")


def unstack_variables(da, dim=None):
    """Unstack a DataArray created by `stack_variables` to a dataset.

    Parameters
    ----------
    da : xr.DataArray
      Array holding different variables along `dim` dimension.
    dim : str
      Name of dimension along which the variables are stacked. If not specified (default),
      `dim` is inferred from attributes of the coordinate.

    Returns
    -------
    xr.Dataset
      Dataset holding each variable in an individual DataArray.
    """
    if dim is None:
        for dim, crd in da.coords.items():
            if crd.attrs.get("is_variables"):
                break
        else:
            raise ValueError("No variable coordinate found, were attributes removed?")

    ds = xr.Dataset(
        {name.item(): da.sel({dim: name.item()}, drop=True) for name in da[dim]},
        attrs=da.attrs,
    )

    # Reset attributes
    for name, attr_list in da.variables.attrs.items():
        if name == "is_variables":
            continue
        for attr, var in zip(attr_list, da.variables):
            if attr is not None:
                ds[var.item()].attrs[name] = attr

    return ds