Datasets API

data_dict

Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 LastEditTime: 2025-07-13 15:40:07 LastEditors: Wenyu Ouyang Description: A dict used for data source and data loader FilePath: \torchhydro\torchhydro\datasets\data_dict.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
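
The module is described as a dict used for the data source and data loader; elsewhere on this page (see BaseDataset.data_source) such a dict is looked up by name, e.g. data_sources_dict[source_name](source_path, **other_settings). Below is a minimal sketch of the registry pattern this implies; the entries and the import path are illustrative assumptions, not the actual contents of data_dict.py.

# illustrative registry sketch only -- not the real contents of data_dict.py
from torchhydro.datasets.data_sets import BaseDataset, AugmentedFloodEventDataset  # assumed import path

datasets_dict = {
    "BaseDataset": BaseDataset,
    "AugmentedFloodEventDataset": AugmentedFloodEventDataset,
}

def build_dataset(name: str, cfgs: dict, is_tra_val_te: str):
    # look up a dataset class by its configured name and instantiate it
    return datasets_dict[name](cfgs, is_tra_val_te)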

data_scalers

Author: Wenyu Ouyang Date: 2024-04-08 18:17:44 LastEditTime: 2025-10-29 08:53:29 LastEditors: Wenyu Ouyang Description: normalize the data FilePath: \torchhydro\torchhydro\datasets\data_scalers.py Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.

DapengScaler

Source code in torchhydro/datasets/data_scalers.py
class DapengScaler(object):
    def __init__(
        self,
        vars_data,
        data_cfgs: dict,
        is_tra_val_te: str,
        other_vars: Optional[dict] = None,
        prcp_norm_cols=None,
        gamma_norm_cols=None,
        pbm_norm=False,
        data_source: object = None,
    ):
        """
        The normalization and denormalization methods from Dapeng's 1st WRR paper.
        Some use StandardScaler, and some use special norm methods

        Parameters
        ----------
        vars_data: dict
            data for all variables used
        data_cfgs
            data parameter config in data source
        is_tra_val_te
            train/valid/test
        other_vars
            if more inputs are needed, list them in other_vars
        prcp_norm_cols
            data items which use _prcp_norm method to normalize
        gamma_norm_cols
            data items which use log(\sqrt(x)+.1) method to normalize
        pbm_norm
            if true, use pbm_norm method to normalize; the output of pbms is not normalized data, so its inverse is different.
        """
        if prcp_norm_cols is None:
            prcp_norm_cols = [
                "streamflow",
            ]
        if gamma_norm_cols is None:
            gamma_norm_cols = [
                "gpm_tp",
                "sta_tp",
                "total_precipitation_hourly",
                "temperature_2m",
                "dewpoint_temperature_2m",
                "surface_net_solar_radiation",
                "sm_surface",
                "sm_rootzone",
            ]
        self.data_target = vars_data["target_cols"]
        self.data_cfgs = data_cfgs
        self.t_s_dict = wrap_t_s_dict(data_cfgs, is_tra_val_te)
        self.data_other = other_vars
        self.prcp_norm_cols = prcp_norm_cols
        self.gamma_norm_cols = gamma_norm_cols
        # both prcp_norm_cols and gamma_norm_cols use log(\sqrt(x)+.1) method to normalize
        self.log_norm_cols = gamma_norm_cols + prcp_norm_cols
        self.pbm_norm = pbm_norm
        self.data_source = data_source
        # save stat_dict of training period in case_dir for valid/test
        stat_file = os.path.join(data_cfgs["case_dir"], "dapengscaler_stat.json")
        # for some test settings, such as PUB cases, we need the stat_dict_file from a trained dataset
        if is_tra_val_te == "train" and data_cfgs["stat_dict_file"] is None:
            self.stat_dict = self.cal_stat_all(vars_data)
            with open(stat_file, "w") as fp:
                json.dump(self.stat_dict, fp)
        else:
            # for valid/test, we need to load stat_dict from train
            if data_cfgs["stat_dict_file"] is not None:
                # we use an assigned stat file, typically for PUB experiments
                # shutil.copy(data_cfgs["stat_dict_file"], stat_file)
                try:
                    shutil.copy(data_cfgs["stat_dict_file"], stat_file)
                except SameFileError:
                    print(
                        f"The source file and the target file are the same: {data_cfgs['stat_dict_file']}, skipping the copy operation."
                    )
                except Exception as e:
                    print(f"Error: {e}")
            assert os.path.isfile(stat_file)
            with open(stat_file, "r") as fp:
                self.stat_dict = json.load(fp)

    @property
    def mean_prcp(self):
        """This property is used to be divided by streamflow to normalize streamflow,
        hence, its unit is same as streamflow

        Returns
        -------
        np.ndarray
            mean_prcp with the same unit as streamflow
        """
        # Get the first target variable (usually flow variable) instead of hardcoding "streamflow"
        flow_var_name = self.data_cfgs["target_cols"][0]
        final_unit = self.data_target.attrs["units"][flow_var_name]
        mean_prcp = self.data_source.read_mean_prcp(
            self.t_s_dict["sites_id"], unit=final_unit
        )
        return mean_prcp.to_array().transpose("basin", "variable").to_numpy()

    def inverse_transform(self, target_values):
        """
        Denormalization for output variables

        Parameters
        ----------
        target_values
            output variables

        Returns
        -------
        xr.Dataset
            denormalized predictions
        """
        stat_dict = self.stat_dict
        target_vars = self.data_cfgs["target_cols"]
        if self.pbm_norm:
            # for (differentiable models) pbm's output, its unit is mm/day, so we don't need to recover its unit
            pred = target_values
        else:
            pred = _trans_norm(
                target_values,
                target_vars,
                stat_dict,
                log_norm_cols=self.log_norm_cols,
                to_norm=False,
            )
            for i in range(len(self.data_cfgs["target_cols"])):
                var = self.data_cfgs["target_cols"][i]
                if var in self.prcp_norm_cols:
                    pred.loc[dict(variable=var)] = _prcp_norm(
                        pred.sel(variable=var).to_numpy(),
                        self.mean_prcp,
                        to_norm=False,
                    )
                else:
                    pred.loc[dict(variable=var)] = pred.sel(variable=var)
        # add attrs for units
        pred.attrs.update(self.data_target.attrs)
        return pred.to_dataset(dim="variable")

    def cal_stat_all(self, vars_data):
        """
        Calculate statistics of outputs(streamflow etc), and inputs(forcing and attributes)
        Parameters
        ----------
        vars_data: dict
            data for all variables used

        Returns
        -------
        dict
            a dict with statistic values
        """
        stat_dict = {}
        for k, v in vars_data.items():
            if v is None:
                continue
            for i in range(len(v.coords["variable"].values)):
                var_name = v.coords["variable"].values[i]
                if var_name in self.prcp_norm_cols:
                    stat_dict[var_name] = cal_stat_prcp_norm(
                        v.sel(variable=var_name).to_numpy(),
                        self.mean_prcp,
                    )
                elif var_name in self.gamma_norm_cols:
                    stat_dict[var_name] = cal_stat_gamma(
                        v.sel(variable=var_name).to_numpy()
                    )
                else:
                    stat_dict[var_name] = cal_stat(v.sel(variable=var_name).to_numpy())

        return stat_dict

    def get_data_norm(self, data, to_norm: bool = True) -> np.ndarray:
        """
        Get normalized values

        Parameters
        ----------
        data
            origin data
        to_norm
            if true, perform normalization
            if false, perform denormalization

        Returns
        -------
        np.array
            the output value for modeling
        """
        stat_dict = self.stat_dict
        out = xr.full_like(data, np.nan)
        # if we don't set a copy() here, the attrs of data will be changed, which is not our wish
        out.attrs = copy.deepcopy(data.attrs)
        _vars = data.coords["variable"].values
        if "units" not in out.attrs:
            Warning("The attrs of output data does not contain units")
            out.attrs["units"] = {}
        for i in range(len(_vars)):
            var = _vars[i]
            if var in self.prcp_norm_cols:
                out.loc[dict(variable=var)] = _prcp_norm(
                    data.sel(variable=var).to_numpy(),
                    self.mean_prcp,
                    to_norm=True,
                )
            else:
                out.loc[dict(variable=var)] = data.sel(variable=var).to_numpy()
            out.attrs["units"][var] = "dimensionless"
        out = _trans_norm(
            out,
            _vars,
            stat_dict,
            log_norm_cols=self.log_norm_cols,
            to_norm=to_norm,
        )
        return out

    def load_norm_data(self, vars_data):
        """
        Read data and perform normalization for DL models
        Parameters
        ----------
        vars_data: dict
            data for all variables used

        Returns
        -------
        dict
            normalized data for each variable group, e.g.
            x (forcing): 3-d  gages_num*time_num*var_num
            y (target): 3-d  gages_num*time_num*1
            c (attribute): 2-d  gages_num*var_num
        """
        if vars_data is None:
            return None
        return {
            k: self.get_data_norm(v) if v is not None else None
            for k, v in vars_data.items()
        }
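
As the comments above note, gamma_norm_cols and prcp_norm_cols both end up using the log(sqrt(x)+.1) transform, and prcp_norm_cols are first scaled by the basin mean precipitation. The helpers _prcp_norm and _trans_norm are not shown on this page, so the snippet below is only a minimal numpy sketch of those two transforms and their inverses under that reading (log base 10 is an assumption).

import numpy as np

def log_norm(x):
    # the log(sqrt(x) + 0.1) transform used for gamma_norm_cols and prcp_norm_cols
    return np.log10(np.sqrt(x) + 0.1)

def log_norm_inv(y):
    # inverse of log_norm: x = (10**y - 0.1) ** 2
    return (np.power(10.0, y) - 0.1) ** 2

def prcp_norm(flow, mean_prcp, to_norm=True):
    # sketch of _prcp_norm: scale streamflow by the basin mean precipitation
    return flow / mean_prcp if to_norm else flow * mean_prcp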

mean_prcp property readonly

Basin mean precipitation; streamflow is divided by it during normalization, hence its unit is the same as streamflow

Returns

np.ndarray
    mean_prcp with the same unit as streamflow

__init__(self, vars_data, data_cfgs, is_tra_val_te, other_vars=None, prcp_norm_cols=None, gamma_norm_cols=None, pbm_norm=False, data_source=None) special

The normalization and denormalization methods from Dapeng's 1st WRR paper. Some use StandardScaler, and some use special norm methods

Parameters

vars_data : dict
    data for all variables used
data_cfgs
    data parameter config in data source
is_tra_val_te
    train/valid/test
other_vars
    if more inputs are needed, list them in other_vars
prcp_norm_cols
    data items which use the _prcp_norm method to normalize
gamma_norm_cols
    data items which use the log(sqrt(x)+.1) method to normalize
pbm_norm
    if true, use the pbm_norm method; the output of pbms is not normalized data, so its inverse is different

Source code in torchhydro/datasets/data_scalers.py
def __init__(
    self,
    vars_data,
    data_cfgs: dict,
    is_tra_val_te: str,
    other_vars: Optional[dict] = None,
    prcp_norm_cols=None,
    gamma_norm_cols=None,
    pbm_norm=False,
    data_source: object = None,
):
    """
    The normalization and denormalization methods from Dapeng's 1st WRR paper.
    Some use StandardScaler, and some use special norm methods

    Parameters
    ----------
    vars_data: dict
        data for all variables used
    data_cfgs
        data parameter config in data source
    is_tra_val_te
        train/valid/test
    other_vars
        if more inputs are needed, list them in other_vars
    prcp_norm_cols
        data items which use _prcp_norm method to normalize
    gamma_norm_cols
        data items which use log(\sqrt(x)+.1) method to normalize
    pbm_norm
        if true, use pbm_norm method to normalize; the output of pbms is not normalized data, so its inverse is different.
    """
    if prcp_norm_cols is None:
        prcp_norm_cols = [
            "streamflow",
        ]
    if gamma_norm_cols is None:
        gamma_norm_cols = [
            "gpm_tp",
            "sta_tp",
            "total_precipitation_hourly",
            "temperature_2m",
            "dewpoint_temperature_2m",
            "surface_net_solar_radiation",
            "sm_surface",
            "sm_rootzone",
        ]
    self.data_target = vars_data["target_cols"]
    self.data_cfgs = data_cfgs
    self.t_s_dict = wrap_t_s_dict(data_cfgs, is_tra_val_te)
    self.data_other = other_vars
    self.prcp_norm_cols = prcp_norm_cols
    self.gamma_norm_cols = gamma_norm_cols
    # both prcp_norm_cols and gamma_norm_cols use log(\sqrt(x)+.1) method to normalize
    self.log_norm_cols = gamma_norm_cols + prcp_norm_cols
    self.pbm_norm = pbm_norm
    self.data_source = data_source
    # save stat_dict of training period in case_dir for valid/test
    stat_file = os.path.join(data_cfgs["case_dir"], "dapengscaler_stat.json")
    # for some test settings, such as PUB cases, we need the stat_dict_file from a trained dataset
    if is_tra_val_te == "train" and data_cfgs["stat_dict_file"] is None:
        self.stat_dict = self.cal_stat_all(vars_data)
        with open(stat_file, "w") as fp:
            json.dump(self.stat_dict, fp)
    else:
        # for valid/test, we need to load stat_dict from train
        if data_cfgs["stat_dict_file"] is not None:
            # we use an assigned stat file, typically for PUB experiments
            # shutil.copy(data_cfgs["stat_dict_file"], stat_file)
            try:
                shutil.copy(data_cfgs["stat_dict_file"], stat_file)
            except SameFileError:
                print(
                    f"The source file and the target file are the same: {data_cfgs['stat_dict_file']}, skipping the copy operation."
                )
            except Exception as e:
                print(f"Error: {e}")
        assert os.path.isfile(stat_file)
        with open(stat_file, "r") as fp:
            self.stat_dict = json.load(fp)
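
The statistics computed during training are written to case_dir as dapengscaler_stat.json and read back for validation and test runs (or copied from data_cfgs["stat_dict_file"] for PUB experiments). A minimal sketch of inspecting that file; the case_dir path here is hypothetical:

import json
import os

case_dir = "/path/to/your/case_dir"  # hypothetical; in practice data_cfgs["case_dir"]
stat_file = os.path.join(case_dir, "dapengscaler_stat.json")
with open(stat_file, "r") as fp:
    stat_dict = json.load(fp)

# one entry per variable name, e.g. "streamflow", "total_precipitation_hourly", ...
print(list(stat_dict.keys()))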

cal_stat_all(self, vars_data)

Calculate statistics of outputs (streamflow etc.) and inputs (forcing and attributes)

Parameters

vars_data : dict
    data for all variables used

Returns

dict
    a dict with statistic values

Source code in torchhydro/datasets/data_scalers.py
def cal_stat_all(self, vars_data):
    """
    Calculate statistics of outputs(streamflow etc), and inputs(forcing and attributes)
    Parameters
    ----------
    vars_data: dict
        data for all variables used

    Returns
    -------
    dict
        a dict with statistic values
    """
    stat_dict = {}
    for k, v in vars_data.items():
        if v is None:
            continue
        for i in range(len(v.coords["variable"].values)):
            var_name = v.coords["variable"].values[i]
            if var_name in self.prcp_norm_cols:
                stat_dict[var_name] = cal_stat_prcp_norm(
                    v.sel(variable=var_name).to_numpy(),
                    self.mean_prcp,
                )
            elif var_name in self.gamma_norm_cols:
                stat_dict[var_name] = cal_stat_gamma(
                    v.sel(variable=var_name).to_numpy()
                )
            else:
                stat_dict[var_name] = cal_stat(v.sel(variable=var_name).to_numpy())

    return stat_dict

get_data_norm(self, data, to_norm=True)

Get normalized values

Parameters

data
    origin data
to_norm
    if true, perform normalization; if false, perform denormalization

Returns

np.array
    the output value for modeling

Source code in torchhydro/datasets/data_scalers.py
def get_data_norm(self, data, to_norm: bool = True) -> np.ndarray:
    """
    Get normalized values

    Parameters
    ----------
    data
        origin data
    to_norm
        if true, perform normalization
        if false, perform denormalization

    Returns
    -------
    np.array
        the output value for modeling
    """
    stat_dict = self.stat_dict
    out = xr.full_like(data, np.nan)
    # if we don't set a copy() here, the attrs of data will be changed, which is not our wish
    out.attrs = copy.deepcopy(data.attrs)
    _vars = data.coords["variable"].values
    if "units" not in out.attrs:
        Warning("The attrs of output data does not contain units")
        out.attrs["units"] = {}
    for i in range(len(_vars)):
        var = _vars[i]
        if var in self.prcp_norm_cols:
            out.loc[dict(variable=var)] = _prcp_norm(
                data.sel(variable=var).to_numpy(),
                self.mean_prcp,
                to_norm=True,
            )
        else:
            out.loc[dict(variable=var)] = data.sel(variable=var).to_numpy()
        out.attrs["units"][var] = "dimensionless"
    out = _trans_norm(
        out,
        _vars,
        stat_dict,
        log_norm_cols=self.log_norm_cols,
        to_norm=to_norm,
    )
    return out

inverse_transform(self, target_values)

Denormalization for output variables

Parameters

target_values
    output variables

Returns

xr.Dataset
    denormalized predictions

Source code in torchhydro/datasets/data_scalers.py
def inverse_transform(self, target_values):
    """
    Denormalization for output variables

    Parameters
    ----------
    target_values
        output variables

    Returns
    -------
    xr.Dataset
        denormalized predictions
    """
    stat_dict = self.stat_dict
    target_vars = self.data_cfgs["target_cols"]
    if self.pbm_norm:
        # for (differentiable models) pbm's output, its unit is mm/day, so we don't need to recover its unit
        pred = target_values
    else:
        pred = _trans_norm(
            target_values,
            target_vars,
            stat_dict,
            log_norm_cols=self.log_norm_cols,
            to_norm=False,
        )
        for i in range(len(self.data_cfgs["target_cols"])):
            var = self.data_cfgs["target_cols"][i]
            if var in self.prcp_norm_cols:
                pred.loc[dict(variable=var)] = _prcp_norm(
                    pred.sel(variable=var).to_numpy(),
                    self.mean_prcp,
                    to_norm=False,
                )
            else:
                pred.loc[dict(variable=var)] = pred.sel(variable=var)
    # add attrs for units
    pred.attrs.update(self.data_target.attrs)
    return pred.to_dataset(dim="variable")

load_norm_data(self, vars_data)

Read data and perform normalization for DL models

Parameters

vars_data : dict
    data for all variables used

Returns

dict
    normalized data for each variable group, e.g.
    x (forcing): 3-d  gages_num*time_num*var_num
    y (target): 3-d  gages_num*time_num*1
    c (attribute): 2-d  gages_num*var_num

Source code in torchhydro/datasets/data_scalers.py
def load_norm_data(self, vars_data):
    """
    Read data and perform normalization for DL models
    Parameters
    ----------
    vars_data: dict
        data for all variables used

    Returns
    -------
    dict
        normalized data for each variable group, e.g.
        x (forcing): 3-d  gages_num*time_num*var_num
        y (target): 3-d  gages_num*time_num*1
        c (attribute): 2-d  gages_num*var_num
    """
    if vars_data is None:
        return None
    return {
        k: self.get_data_norm(v) if v is not None else None
        for k, v in vars_data.items()
    }
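
Putting the pieces together, a typical round trip through this scaler looks like the sketch below. vars_data, data_cfgs, data_source and model_predictions are placeholders for objects your experiment sets up (in practice ScalerHub, documented next, constructs this scaler from data_cfgs); only the method calls mirror the API shown above.

# vars_data: dict of xr.DataArray, e.g. {"relevant_cols": ..., "target_cols": ..., "constant_cols": ...}
scaler = DapengScaler(vars_data, data_cfgs, "train", data_source=data_source)

# normalize every variable group for model training
norm_data = scaler.load_norm_data(vars_data)
x = norm_data["relevant_cols"]  # forcings
y = norm_data["target_cols"]    # targets

# after inference, map model outputs (an xr.DataArray over basin/time/variable) back to physical units
pred_ds = scaler.inverse_transform(model_predictions)  # returns an xr.Dataset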

ScalerHub

A class for Scaler

Source code in torchhydro/datasets/data_scalers.py
class ScalerHub(object):
    """
    A class for Scaler
    """

    def __init__(
        self,
        vars_data,
        data_cfgs=None,
        is_tra_val_te=None,
        data_source=None,
        **kwargs,
    ):
        """
        Perform normalization

        Parameters
        ----------
        vars_data
            data for all variables used.
            the dim must be (basin, time, lead_step, var) for 4-d array;
            the dim must be (basin, time, var) for 3-d array;
            the dim must be (basin, time) for 2-d array;
        data_cfgs
            configs for reading data
        is_tra_val_te
            train, valid or test
        data_source
            data source to get original data info
        kwargs
            other optional parameters for ScalerHub
        """
        self.data_cfgs = data_cfgs
        scaler_type = data_cfgs["scaler"]
        pbm_norm = data_cfgs["scaler_params"]["pbm_norm"]
        if scaler_type == "DapengScaler":
            gamma_norm_cols = data_cfgs["scaler_params"]["gamma_norm_cols"]
            prcp_norm_cols = data_cfgs["scaler_params"]["prcp_norm_cols"]
            scaler = DapengScaler(
                vars_data,
                data_cfgs,
                is_tra_val_te,
                prcp_norm_cols=prcp_norm_cols,
                gamma_norm_cols=gamma_norm_cols,
                pbm_norm=pbm_norm,
                data_source=data_source,
            )
        elif scaler_type in SCALER_DICT.keys():
            scaler = SklearnScaler(
                vars_data,
                data_cfgs,
                is_tra_val_te,
                pbm_norm=pbm_norm,
            )
        else:
            raise NotImplementedError(
                "We don't provide this Scaler now!!! Please choose another one: DapengScaler or key in SCALER_DICT"
            )
        self.norm_data = scaler.load_norm_data(vars_data)
        # we will use target_scaler during denormalization
        self.target_scaler = scaler
        print("Finish Normalization\n")

__init__(self, vars_data, data_cfgs=None, is_tra_val_te=None, data_source=None, **kwargs) special

Perform normalization

Parameters

vars_data
    data for all variables used.
    the dim must be (basin, time, lead_step, var) for a 4-d array;
    the dim must be (basin, time, var) for a 3-d array;
    the dim must be (basin, time) for a 2-d array
data_cfgs
    configs for reading data
is_tra_val_te
    train, valid or test
data_source
    data source to get original data info
kwargs
    other optional parameters for ScalerHub

Source code in torchhydro/datasets/data_scalers.py
def __init__(
    self,
    vars_data,
    data_cfgs=None,
    is_tra_val_te=None,
    data_source=None,
    **kwargs,
):
    """
    Perform normalization

    Parameters
    ----------
    vars_data
        data for all variables used.
        the dim must be (basin, time, lead_step, var) for 4-d array;
        the dim must be (basin, time, var) for 3-d array;
        the dim must be (basin, time) for 2-d array;
    data_cfgs
        configs for reading data
    is_tra_val_te
        train, valid or test
    data_source
        data source to get original data info
    kwargs
        other optional parameters for ScalerHub
    """
    self.data_cfgs = data_cfgs
    scaler_type = data_cfgs["scaler"]
    pbm_norm = data_cfgs["scaler_params"]["pbm_norm"]
    if scaler_type == "DapengScaler":
        gamma_norm_cols = data_cfgs["scaler_params"]["gamma_norm_cols"]
        prcp_norm_cols = data_cfgs["scaler_params"]["prcp_norm_cols"]
        scaler = DapengScaler(
            vars_data,
            data_cfgs,
            is_tra_val_te,
            prcp_norm_cols=prcp_norm_cols,
            gamma_norm_cols=gamma_norm_cols,
            pbm_norm=pbm_norm,
            data_source=data_source,
        )
    elif scaler_type in SCALER_DICT.keys():
        scaler = SklearnScaler(
            vars_data,
            data_cfgs,
            is_tra_val_te,
            pbm_norm=pbm_norm,
        )
    else:
        raise NotImplementedError(
            "We don't provide this Scaler now!!! Please choose another one: DapengScaler or key in SCALER_DICT"
        )
    self.norm_data = scaler.load_norm_data(vars_data)
    # we will use target_scaler during denormalization
    self.target_scaler = scaler
    print("Finish Normalization\n")

SklearnScaler

Source code in torchhydro/datasets/data_scalers.py
class SklearnScaler(object):
    def __init__(
        self,
        vars_data,
        data_cfgs,
        is_tra_val_te,
        pbm_norm=False,
    ):
        """_summary_

        Parameters
        ----------
        vars_data : dict
            vars data map
        data_cfgs : _type_
            _description_
        is_tra_val_te : bool
            _description_
        pbm_norm : bool, optional
            _description_, by default False
        """
        # we will use data_target and target_scaler for denormalization
        self.data_target = vars_data["target_cols"]
        self.target_scaler = None
        self.data_cfgs = data_cfgs
        self.is_tra_val_te = is_tra_val_te
        self.pbm_norm = pbm_norm

    def load_norm_data(self, vars_data):
        # TODO: not fully tested for differentiable models
        norm_dict = {}
        scaler_type = self.data_cfgs["scaler"]
        for k, v in vars_data.items():
            scaler = SCALER_DICT[scaler_type]()
            if v.ndim == 3:
                # for forcings and outputs
                num_instances, num_time_steps, num_features = v.shape
                v_np = v.to_numpy().reshape(-1, num_features)
                scaler, data_norm = self._sklearn_scale(
                    self.data_cfgs, self.is_tra_val_te, scaler, k, v_np
                )
                data_norm = data_norm.reshape(
                    num_instances, num_time_steps, num_features
                )
                norm_xrarray = xr.DataArray(
                    data_norm,
                    coords={
                        "basin": v.coords["basin"],
                        "time": v.coords["time"],
                        "variable": v.coords["variable"],
                    },
                    dims=["basin", "time", "variable"],
                )
            elif v.ndim == 2:
                num_instances, num_features = v.shape
                v_np = v.to_numpy().reshape(-1, num_features)
                scaler, data_norm = self._sklearn_scale(
                    self.data_cfgs, self.is_tra_val_te, scaler, k, v_np
                )
                # don't need to reshape data_norm again as it is 2-d
                norm_xrarray = xr.DataArray(
                    data_norm,
                    coords={
                        "basin": v.coords["basin"],
                        "variable": v.coords["variable"],
                    },
                    dims=["basin", "variable"],
                )
            elif v.ndim == 4:
                # for forecast data
                num_instances, num_time_steps, num_lead_steps, num_features = v.shape
                v_np = v.to_numpy().reshape(-1, num_features)
                scaler, data_norm = self._sklearn_scale(
                    self.data_cfgs, self.is_tra_val_te, scaler, k, v_np
                )
                data_norm = data_norm.reshape(
                    num_instances, num_time_steps, num_lead_steps, num_features
                )
                norm_xrarray = xr.DataArray(
                    data_norm,
                    coords={
                        "basin": v.coords["basin"],
                        "time": v.coords["time"],
                        "lead_step": v.coords["lead_step"],
                        "variable": v.coords["variable"],
                    },
                    dims=["basin", "time", "lead_step", "variable"],
                )
            else:
                raise NotImplementedError(
                    "Please check your data, the dim of data must be 2, 3 or 4"
                )

            norm_dict[k] = norm_xrarray
            if k == "target_cols":
                # we need target cols scaler for denormalization
                self.target_scaler = scaler
        return norm_dict

    def _sklearn_scale(self, data_cfgs, is_tra_val_te, scaler, norm_key, data):
        save_file = os.path.join(data_cfgs["case_dir"], f"{norm_key}_scaler.pkl")
        if is_tra_val_te == "train" and data_cfgs["stat_dict_file"] is None:
            data_norm = scaler.fit_transform(data)
            # Save scaler in case_dir for valid/test
            with open(save_file, "wb") as outfile:
                pkl.dump(scaler, outfile)
        else:
            if data_cfgs["stat_dict_file"] is not None:
                # NOTE: you need to set data_cfgs["stat_dict_file"] as a str with ";" as its separator
                # the sequence of the stat_dict_file must be same as the sequence of norm_keys
                # for example: "stat_dict_file": "target_stat_dict_file;relevant_stat_dict_file;constant_stat_dict_file"
                shutil.copy(data_cfgs["stat_dict_file"][norm_key], save_file)
            if not os.path.isfile(save_file):
                raise FileNotFoundError("Please generate the xx_scaler.pkl file")
            with open(save_file, "rb") as infile:
                scaler = pkl.load(infile)
                data_norm = scaler.transform(data)
        return scaler, data_norm

    def inverse_transform(self, target_values):
        """
        Denormalization for output variables

        Parameters
        ----------
        target_values
            output variables (xr.DataArray or np.ndarray)

        Returns
        -------
        xr.Dataset
            denormalized predictions or observations
        """
        coords = self.data_target.coords
        attrs = self.data_target.attrs
        # input must be xr.DataArray
        if not isinstance(target_values, xr.DataArray):
            # the shape of target_values must be (basin, time, variable)
            target_values = xr.DataArray(
                target_values,
                coords={
                    "basin": coords["basin"],
                    "time": coords["time"],
                    "variable": coords["variable"],
                },
                dims=["basin", "time", "variable"],
            )
        # transform to numpy array for sklearn inverse_transform
        shape = target_values.shape
        arr = target_values.to_numpy().reshape(-1, shape[-1])
        # sklearn inverse_transform
        arr_inv = self.target_scaler.inverse_transform(arr)
        # reshape to original shape
        arr_inv = arr_inv.reshape(shape)
        result = xr.DataArray(
            arr_inv,
            coords=target_values.coords,
            dims=target_values.dims,
            attrs=attrs,
        )
        # add attrs for units
        result.attrs.update(self.data_target.attrs)
        return result.to_dataset(dim="variable")
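
Each variable group gets its own fitted scaler pickled into case_dir by _sklearn_scale (e.g. target_cols_scaler.pkl). A minimal sketch of reloading the target scaler outside torchhydro, assuming the file was produced by a training run; the case_dir path and norm_predictions_2d are placeholders:

import os
import pickle as pkl

case_dir = "/path/to/your/case_dir"  # hypothetical; in practice data_cfgs["case_dir"]
with open(os.path.join(case_dir, "target_cols_scaler.pkl"), "rb") as infile:
    target_scaler = pkl.load(infile)

# the loaded object is a fitted sklearn scaler, so it expects 2-d (samples, variables) arrays
denorm = target_scaler.inverse_transform(norm_predictions_2d)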

__init__(self, vars_data, data_cfgs, is_tra_val_te, pbm_norm=False) special

A scaler based on the scikit-learn scalers registered in SCALER_DICT

Parameters

vars_data : dict
    vars data map
data_cfgs : dict
    data parameter config in data source
is_tra_val_te : str
    train/valid/test
pbm_norm : bool, optional
    whether the outputs come from a physics-based model, by default False

Source code in torchhydro/datasets/data_scalers.py
def __init__(
    self,
    vars_data,
    data_cfgs,
    is_tra_val_te,
    pbm_norm=False,
):
    """_summary_

    Parameters
    ----------
    vars_data : dict
        vars data map
    data_cfgs : _type_
        _description_
    is_tra_val_te : bool
        _description_
    pbm_norm : bool, optional
        _description_, by default False
    """
    # we will use data_target and target_scaler for denormalization
    self.data_target = vars_data["target_cols"]
    self.target_scaler = None
    self.data_cfgs = data_cfgs
    self.is_tra_val_te = is_tra_val_te
    self.pbm_norm = pbm_norm

inverse_transform(self, target_values)

Denormalization for output variables

Parameters

target_values
    output variables (xr.DataArray or np.ndarray)

Returns

xr.Dataset
    denormalized predictions or observations

Source code in torchhydro/datasets/data_scalers.py
def inverse_transform(self, target_values):
    """
    Denormalization for output variables

    Parameters
    ----------
    target_values
        output variables (xr.DataArray or np.ndarray)

    Returns
    -------
    xr.Dataset
        denormalized predictions or observations
    """
    coords = self.data_target.coords
    attrs = self.data_target.attrs
    # input must be xr.DataArray
    if not isinstance(target_values, xr.DataArray):
        # the shape of target_values must be (basin, time, variable)
        target_values = xr.DataArray(
            target_values,
            coords={
                "basin": coords["basin"],
                "time": coords["time"],
                "variable": coords["variable"],
            },
            dims=["basin", "time", "variable"],
        )
    # transform to numpy array for sklearn inverse_transform
    shape = target_values.shape
    arr = target_values.to_numpy().reshape(-1, shape[-1])
    # sklearn inverse_transform
    arr_inv = self.target_scaler.inverse_transform(arr)
    # reshape to original shape
    arr_inv = arr_inv.reshape(shape)
    result = xr.DataArray(
        arr_inv,
        coords=target_values.coords,
        dims=target_values.dims,
        attrs=attrs,
    )
    # add attrs for units
    result.attrs.update(self.data_target.attrs)
    return result.to_dataset(dim="variable")
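
The core trick in SklearnScaler is that scikit-learn scalers only accept 2-d arrays, so 3-d and 4-d hydrological arrays are flattened to (samples, variables), scaled, and reshaped back. A self-contained sketch of that pattern with StandardScaler, used here only as an example of what SCALER_DICT might provide:

import numpy as np
from sklearn.preprocessing import StandardScaler

data = np.random.rand(10, 365, 3)  # fake (basin, time, variable) forcing array
n_basin, n_time, n_var = data.shape

scaler = StandardScaler()
flat = data.reshape(-1, n_var)  # (basin*time, variable)
norm = scaler.fit_transform(flat).reshape(n_basin, n_time, n_var)

# round trip back to the original values
restored = scaler.inverse_transform(norm.reshape(-1, n_var)).reshape(n_basin, n_time, n_var)
assert np.allclose(restored, data)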

data_sets

Author: Wenyu Ouyang Date: 2024-04-08 18:16:53 LastEditTime: 2025-11-07 09:39:57 LastEditors: Wenyu Ouyang Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology FilePath: \torchhydro\torchhydro\datasets\data_sets.py Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.

AugmentedFloodEventDataset (FloodEventDataset)

Dataset class for augmented flood event data with discontinuous time ranges.

This dataset is designed to handle flood event data that includes augmented (generated) future data alongside historical data, where time ranges may be discontinuous (e.g., historical data 1990-2010, then augmented data 2026+).

It connects to hydrodatasource.reader.floodevent.FloodEventDatasource and uses the read_ts_xrdataset_augmented method to read augmented data.

Source code in torchhydro/datasets/data_sets.py
class AugmentedFloodEventDataset(FloodEventDataset):
    """Dataset class for augmented flood event data with discontinuous time ranges.

    This dataset is designed to handle flood event data that includes augmented
    (generated) future data alongside historical data, where time ranges may be
    discontinuous (e.g., historical data 1990-2010, then augmented data 2026+).

    It connects to hydrodatasource.reader.floodevent.FloodEventDatasource
    and uses the read_ts_xrdataset_augmented method to read augmented data.
    """

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        """Initialize AugmentedFloodEventDataset

        Parameters
        ----------
        cfgs : dict
            Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
        is_tra_val_te : str
            One of 'train', 'valid', or 'test'
        """
        super(AugmentedFloodEventDataset, self).__init__(cfgs, is_tra_val_te)

        if not hasattr(self.data_source, "read_ts_xrdataset_augmented"):
            raise ValueError(
                "Data source must support read_ts_xrdataset_augmented method"
            )

    def _read_xyc_period(self, start_date, end_date):
        """Override template method to read augmented flood event data for a specific period

        This method leverages the parent class's multi-period handling while using
        augmented data reading methods for generated future data.

        Parameters
        ----------
        start_date : str
            start time
        end_date : str
            end time

        Returns
        -------
        dict
            Dictionary containing relevant_cols, target_cols, and constant_cols data
        """
        return self._read_xyc_specified_time_augmented(start_date, end_date)

    def _read_xyc_specified_time_augmented(self, start_date, end_date):
        """Read x, y, c data from both historical and augmented data sources

        This method reads both historical observed data (using read_ts_xrdataset)
        and augmented future data (using read_ts_xrdataset_augmented), then
        concatenates them along the time dimension to provide a complete dataset.

        Parameters
        ----------
        start_date : str
            start time
        end_date : str
            end time

        Returns
        -------
        dict
            Dictionary containing relevant_cols, target_cols, and constant_cols data
        """

        # Read historical observed data using standard method
        relevant_cols = self.data_cfgs.get("relevant_cols", ["rain"])
        target_cols = self.data_cfgs.get("target_cols", ["inflow", "flood_event"])

        try:
            data_forcing_hist_ = self.data_source.read_ts_xrdataset(
                self.t_s_dict["sites_id"],
                [start_date, end_date],
                relevant_cols,
            )
            data_output_hist_ = self.data_source.read_ts_xrdataset(
                self.t_s_dict["sites_id"],
                [start_date, end_date],
                target_cols,
            )

            # Process historical data
            data_forcing_hist_ = self._rm_timeunit_key(data_forcing_hist_)
            data_output_hist_ = self._rm_timeunit_key(data_output_hist_)

        except Exception as e:
            LOGGER.info(f"无法读取历史数据,可能是时间范围不在历史数据中: {e}")
            data_forcing_hist_ = None
            data_output_hist_ = None

        # Read augmented data using augmented method
        try:
            data_forcing_aug_ = self.data_source.read_ts_xrdataset_augmented(
                self.t_s_dict["sites_id"],
                [start_date, end_date],
                relevant_cols,
            )
            data_output_aug_ = self.data_source.read_ts_xrdataset_augmented(
                self.t_s_dict["sites_id"],
                [start_date, end_date],
                target_cols,
            )

            # Process augmented data
            data_forcing_aug_ = self._rm_timeunit_key(data_forcing_aug_)
            data_output_aug_ = self._rm_timeunit_key(data_output_aug_)

        except Exception as e:
            LOGGER.info(f"无法读取增强数据,可能是时间范围不在增强数据中: {e}")
            data_forcing_aug_ = None
            data_output_aug_ = None

        # Combine historical and augmented data
        data_forcing_ds = self._combine_historical_and_augmented_data(
            data_forcing_hist_, data_forcing_aug_, "forcing"
        )
        data_output_ds = self._combine_historical_and_augmented_data(
            data_output_hist_, data_output_aug_, "target"
        )

        # Check and process combined data
        data_forcing_ds, data_output_ds = self._check_ts_xrds_unit(
            data_forcing_ds, data_output_ds
        )

        # Read constant/attribute data (same as parent class)
        data_attr_ds = self.data_source.read_attr_xrdataset(
            self.t_s_dict["sites_id"],
            self.data_cfgs["constant_cols"],
            all_number=True,
        )

        # Convert to DataArray with units
        x_origin, y_origin, c_origin = self._to_dataarray_with_unit(
            data_forcing_ds, data_output_ds, data_attr_ds
        )

        return {
            "relevant_cols": x_origin.transpose("basin", "time", "variable"),
            "target_cols": y_origin.transpose("basin", "time", "variable"),
            "constant_cols": (
                c_origin.transpose("basin", "variable")
                if c_origin is not None
                else None
            ),
        }

    def _combine_historical_and_augmented_data(self, hist_data, aug_data, data_type):
        """Combine historical observed data and augmented generated data

        This method concatenates historical and augmented data along the time dimension,
        handling cases where data may be discontinuous or overlapping.

        Parameters
        ----------
        hist_data : xr.Dataset or None
            Historical observed data
        aug_data : xr.Dataset or None
            Augmented generated data
        data_type : str
            Type of data ("forcing" or "target") for logging purposes

        Returns
        -------
        xr.Dataset
            Combined dataset with historical and augmented data concatenated
        """
        import xarray as xr

        # Handle cases where one or both data sources are None
        if hist_data is None and aug_data is None:
            raise ValueError(f"Both historical and augmented {data_type} data are None")
        elif hist_data is None:
            LOGGER.info(
                f"No historical {data_type} data found, using only augmented data"
            )
            return aug_data
        elif aug_data is None:
            LOGGER.info(
                f"No augmented {data_type} data found, using only historical data"
            )
            return hist_data

        # Both datasets exist - need to combine them
        try:
            # Check if there's time overlap between datasets
            hist_times = hist_data.time.values if "time" in hist_data.dims else []
            aug_times = aug_data.time.values if "time" in aug_data.dims else []

            if len(hist_times) == 0:
                LOGGER.info(
                    f"Historical {data_type} data has no time dimension, using only augmented data"
                )
                return aug_data
            elif len(aug_times) == 0:
                LOGGER.info(
                    f"Augmented {data_type} data has no time dimension, using only historical data"
                )
                return hist_data

            # Find overlap period
            hist_start, hist_end = hist_times[0], hist_times[-1]
            aug_start, aug_end = aug_times[0], aug_times[-1]

            # Check for overlap
            if hist_end < aug_start:
                # No overlap - historical data ends before augmented data starts
                LOGGER.info(
                    f"No temporal overlap for {data_type} data, concatenating sequentially"
                )
                combined_data = xr.concat([hist_data, aug_data], dim="time")
            elif aug_end < hist_start:
                # No overlap - augmented data ends before historical data starts
                LOGGER.info(
                    f"Augmented {data_type} data precedes historical data, concatenating"
                )
                combined_data = xr.concat([aug_data, hist_data], dim="time")
            else:
                # There is overlap - need to handle carefully
                LOGGER.info(
                    f"Temporal overlap detected for {data_type} data, "
                    f"merging with priority to historical data"
                )

                # Create time index for the full range
                all_times = sorted(set(list(hist_times) + list(aug_times)))

                # Reindex both datasets to the full time range
                hist_reindexed = hist_data.reindex(time=all_times, method=None)
                aug_reindexed = aug_data.reindex(time=all_times, method=None)

                # Combine: use historical data where available, fill with augmented data
                combined_data = hist_reindexed.where(
                    ~hist_reindexed.isnull(), aug_reindexed
                )

            # Sort by time to ensure proper ordering
            combined_data = combined_data.sortby("time")

            LOGGER.info(
                f"Successfully combined {data_type} data: "
                f"historical shape {hist_data.dims if hasattr(hist_data, 'dims') else 'N/A'}, "
                f"augmented shape {aug_data.dims if hasattr(aug_data, 'dims') else 'N/A'}, "
                f"combined shape {combined_data.dims}"
            )

            return combined_data

        except Exception as e:
            LOGGER.error(f"Failed to combine {data_type} data: {e}")
            # Fallback: prefer historical data if combination fails
            LOGGER.warning(f"Falling back to historical {data_type} data only")
            return hist_data

    def _handle_discontinuous_time_ranges(self, data_dict, start_date, end_date):
        """Handle discontinuous time ranges by filling gaps with NaN values

        This method creates a continuous time index and fills missing periods
        with NaN values, handling cases such as training data covers 1990-2010,
        augmented data starts from 2026+, and test data covers 2011-2025.

        Parameters
        ----------
        data_dict : dict
            Dictionary containing xarray data with keys 'relevant_cols',
            'target_cols', 'constant_cols'
        start_date : str
            Overall start date for the continuous timeline
        end_date : str
            Overall end date for the continuous timeline

        Returns
        -------
        dict
            Dictionary with continuous time index and NaN-filled gaps
        """
        # Create continuous daily time index from start_date to end_date
        try:
            continuous_time = pd.date_range(start=start_date, end=end_date, freq="D")
        except Exception as e:
            LOGGER.warning(f"Failed to create continuous time index: {e}")
            return data_dict

        # Process each data type (relevant_cols, target_cols)
        processed_dict = {}

        for data_key in ["relevant_cols", "target_cols"]:
            if data_key in data_dict and data_dict[data_key] is not None:
                original_data = data_dict[data_key]

                try:
                    # Check if data has time dimension
                    if "time" not in original_data.dims:
                        LOGGER.warning(
                            f"{data_key} has no time dimension, skipping time alignment"
                        )
                        processed_dict[data_key] = original_data
                        continue

                    # Reindex to continuous time, filling gaps with NaN
                    aligned_data = original_data.reindex(
                        time=continuous_time,
                        method=None,  # No interpolation, fill with NaN
                        fill_value=float("nan"),
                    )

                    processed_dict[data_key] = aligned_data

                    # Log information about the alignment
                    original_time_points = len(original_data.time)
                    aligned_time_points = len(aligned_data.time)
                    nan_points = aligned_data.isnull().sum().sum().values

                    LOGGER.info(
                        f"{data_key}: aligned from {original_time_points} to "
                        f"{aligned_time_points} time points, with {nan_points} "
                        f"NaN values for discontinuous periods"
                    )

                except Exception as e:
                    LOGGER.error(
                        f"Failed to align {data_key} to continuous timeline: {e}"
                    )
                    processed_dict[data_key] = original_data
            else:
                processed_dict[data_key] = data_dict.get(data_key)

        # Constant cols don't need time alignment
        processed_dict["constant_cols"] = data_dict.get("constant_cols")

        return processed_dict
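
The merge rule in _combine_historical_and_augmented_data can be reproduced in isolation: when the two periods overlap, both datasets are reindexed to the union of their timestamps and historical values win wherever they exist. A small self-contained xarray sketch of that rule with made-up numbers:

import pandas as pd
import xarray as xr

hist = xr.Dataset(
    {"inflow": ("time", [1.0, 2.0, 3.0])},
    coords={"time": pd.date_range("2009-12-30", periods=3, freq="D")},
)
aug = xr.Dataset(
    {"inflow": ("time", [30.0, 40.0, 50.0])},
    coords={"time": pd.date_range("2010-01-01", periods=3, freq="D")},
)

all_times = sorted(set(hist.time.values) | set(aug.time.values))
hist_r = hist.reindex(time=all_times)
aug_r = aug.reindex(time=all_times)

# keep historical values where present, fall back to augmented ones elsewhere
combined = hist_r.where(~hist_r.isnull(), aug_r).sortby("time")
print(combined.inflow.values)  # [ 1.  2.  3. 40. 50.]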

__init__(self, cfgs, is_tra_val_te) special

Initialize AugmentedFloodEventDataset

Parameters

cfgs : dict
    Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
is_tra_val_te : str
    One of 'train', 'valid', or 'test'

Source code in torchhydro/datasets/data_sets.py
def __init__(self, cfgs: dict, is_tra_val_te: str):
    """Initialize AugmentedFloodEventDataset

    Parameters
    ----------
    cfgs : dict
        Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
    is_tra_val_te : str
        One of 'train', 'valid', or 'test'
    """
    super(AugmentedFloodEventDataset, self).__init__(cfgs, is_tra_val_te)

    if not hasattr(self.data_source, "read_ts_xrdataset_augmented"):
        raise ValueError(
            "Data source must support read_ts_xrdataset_augmented method"
        )

BaseDataset (Dataset)

Base data set class to load and preprocess data (batch-first) using PyTorch's Dataset
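
Because it is a standard PyTorch Dataset, it plugs straight into a DataLoader once the configuration dict is assembled. A minimal sketch, assuming a cfgs dict with data_cfgs, training_cfgs and evaluation_cfgs keys has already been produced by your experiment configuration:

from torch.utils.data import DataLoader
from torchhydro.datasets.data_sets import BaseDataset  # as documented on this page

train_ds = BaseDataset(cfgs, is_tra_val_te="train")
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)

for xc, y in train_loader:
    # both tensors are batch-first: xc is (batch, time, input features [+ attributes]),
    # y is (batch, time, target variables)
    ...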

Source code in torchhydro/datasets/data_sets.py
class BaseDataset(Dataset):
    """Base data set class to load and preprocess data (batch-first) using PyTorch's Dataset"""

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        """
        Parameters
        ----------
        cfgs
            configs, including data and training + evaluation settings
            which will be used for organizing batch data
        is_tra_val_te
            train, valid or test
        """
        super(BaseDataset, self).__init__()
        self.data_cfgs = cfgs["data_cfgs"]
        self.training_cfgs = cfgs["training_cfgs"]
        self.evaluation_cfgs = cfgs["evaluation_cfgs"]
        self._pre_load_data(is_tra_val_te)
        # load and preprocess data
        self._load_data()

    def _pre_load_data(self, is_tra_val_te):
        """
        some preprocessing before loading data, such as
        setting the way to organize batch data

        Parameters
        ----------
        is_tra_val_te: str
            train, valid or test

        Raises
        ------
        ValueError
            if is_tra_val_te is not one of 'train', 'valid' or 'test'
        """
        if is_tra_val_te in {"train", "valid", "test"}:
            self.is_tra_val_te = is_tra_val_te
        else:
            raise ValueError(
                "'is_tra_val_te' must be one of 'train', 'valid' or 'test' "
            )
        self.train_mode = self.is_tra_val_te == "train"
        self.t_s_dict = wrap_t_s_dict(self.data_cfgs, self.is_tra_val_te)
        self.rho = self.training_cfgs["hindcast_length"]
        self.warmup_length = self.training_cfgs["warmup_length"]
        self.horizon = self.training_cfgs["forecast_length"]
        valid_batch_mode = self.training_cfgs["valid_batch_mode"]
        # when is_tra_val_te is "valid" and valid_batch_mode is "train", we use the same batch organization as in training
        self.is_new_batch_way = (
            is_tra_val_te != "valid" or valid_batch_mode != "train"
        ) and is_tra_val_te != "train"
        rolling = self.evaluation_cfgs.get("rolling", 0)
        if self.evaluation_cfgs["hrwin"] is None:
            hrwin = self.rho
        else:
            hrwin = self.evaluation_cfgs["hrwin"]
        if self.evaluation_cfgs["frwin"] is None:
            frwin = self.horizon
        else:
            frwin = self.evaluation_cfgs["frwin"]
        if rolling == 0:
            hrwin = 0 if hrwin is None else hrwin
            frwin = self.nt - hrwin - self.warmup_length
        if self.is_new_batch_way:
            # we will set the batch data for valid and test
            self.rolling = rolling
            self.rho = hrwin
            self.horizon = frwin

    @property
    def data_source(self):
        source_name = self.data_cfgs["source_cfgs"]["source_name"]
        source_path = self.data_cfgs["source_cfgs"]["source_path"]

        # pass all parameters other than source_name and source_path to the data source class

        # get all the other settings first
        other_settings = self.data_cfgs["source_cfgs"].get("other_settings", {})

        # drop source_name and source_path if they are present
        other_settings.pop("source_name", None)
        other_settings.pop("source_path", None)

        return data_sources_dict[source_name](source_path, **other_settings)

    @property
    def streamflow_name(self):
        return self.data_cfgs["target_cols"][0]

    @property
    def precipitation_name(self):
        return self.data_cfgs["relevant_cols"][0]

    @property
    def ngrid(self):
        """How many basins/grids in the dataset

        Returns
        -------
        int
            number of basins/grids
        """
        return len(self.basins)

    @property
    def noutputvar(self):
        """How many output variables in the dataset
        Used in evaluation.

        Returns
        -------
        int
            number of variables
        """
        return len(self.data_cfgs["target_cols"])

    @property
    def nt(self):
        """length of longest time series in all basins

        Returns
        -------
        int
            number of longest time steps
        """
        if isinstance(self.t_s_dict["t_final_range"][0], tuple):
            trange_type_num = len(self.t_s_dict["t_final_range"])
            if trange_type_num not in [self.ngrid, 1]:
                raise ValueError(
                    "The number of time ranges should be equal to the number of basins "
                    "if you choose different time ranges for different basins"
                )
            earliest_date = None
            latest_date = None
            for start_date_str, end_date_str in self.t_s_dict["t_final_range"]:
                date_format = detect_date_format(start_date_str)

                start_date = datetime.strptime(start_date_str, date_format)
                end_date = datetime.strptime(end_date_str, date_format)

                if earliest_date is None or start_date < earliest_date:
                    earliest_date = start_date
                if latest_date is None or end_date > latest_date:
                    latest_date = end_date
            earliest_date = earliest_date.strftime(date_format)
            latest_date = latest_date.strftime(date_format)
        else:
            trange_type_num = 1
            earliest_date = self.t_s_dict["t_final_range"][0]
            latest_date = self.t_s_dict["t_final_range"][1]
        min_time_unit = self.data_cfgs["min_time_unit"]
        min_time_interval = self.data_cfgs["min_time_interval"]

        # compute the length of one time step in hours
        unit_to_hours = {
            "h": 1,
            "H": 1,
            "d": 24,
            "D": 24,
            "m": 1 / 60,
            "M": 1 / 60,
            "s": 1 / 3600,
            "S": 1 / 3600,
        }
        hours_per_step = min_time_interval * unit_to_hours.get(min_time_unit, 1)

        # parse the time strings
        date_format = detect_date_format(
            earliest_date[0]
            if isinstance(earliest_date, (list, tuple))
            else earliest_date
        )

        # get the start and end times
        if isinstance(earliest_date, (list, tuple)):
            s_date = datetime.strptime(
                earliest_date[0], date_format
            )  # use the first element as the start time
        else:
            s_date = datetime.strptime(earliest_date, date_format)

        if isinstance(latest_date, (list, tuple)):
            e_date = datetime.strptime(
                latest_date[-1], date_format
            )  # use the last element as the end time
        else:
            e_date = datetime.strptime(latest_date, date_format)

        # compute the total number of hours
        total_hours = (e_date - s_date).total_seconds() / 3600

        # compute the number of time steps
        return int(total_hours / hours_per_step) + 1

    @property
    def basins(self):
        """Return the basins of the dataset"""
        return self.t_s_dict["sites_id"]

    @property
    def times(self):
        """Return the times of all basins

        TODO: Although we support get different time ranges for different basins,
        we didn't implement the reading function for this case in _read_xyc method.
        Hence, it's better to choose unified time range for all basins
        """
        min_time_unit = self.data_cfgs["min_time_unit"]
        min_time_interval = self.data_cfgs["min_time_interval"]
        time_step = f"{min_time_interval}{min_time_unit}"
        if isinstance(self.t_s_dict["t_final_range"][0], tuple):
            times_ = []
            trange_type_num = len(self.t_s_dict["t_final_range"])
            if trange_type_num not in [self.ngrid, 1]:
                raise ValueError(
                    "The number of time ranges should be equal to the number of basins "
                    "if you choose different time ranges for different basins"
                )
            detect_date_format(self.t_s_dict["t_final_range"][0][0])
            for start_date_str, end_date_str in self.t_s_dict["t_final_range"]:
                s_date = pd.to_datetime(start_date_str)
                e_date = pd.to_datetime(end_date_str)
                time_series = pd.date_range(start=s_date, end=e_date, freq=time_step)
                times_.append(time_series)
        else:
            detect_date_format(self.t_s_dict["t_final_range"][0])
            trange_type_num = 1
            s_date = pd.to_datetime(self.t_s_dict["t_final_range"][0])
            e_date = pd.to_datetime(self.t_s_dict["t_final_range"][1])
            times_ = pd.date_range(start=s_date, end=e_date, freq=time_step)
        return times_

    def __len__(self):
        return self.num_samples

    def __getitem__(self, item: int):
        """Get one sample from the dataset with a unified return format.

        Args:
            item: The index of the sample to retrieve.

        Returns:
            A tuple of (input_data, output_data), where input_data is a tensor
            of input features and output_data is a tensor of target values.
        """
        basin, idx, actual_length = self.lookup_table[item]
        warmup_length = self.warmup_length
        x = self.x[basin, idx - warmup_length : idx + actual_length, :]
        y = self.y[basin, idx : idx + actual_length, :]
        if self.c is None or self.c.shape[-1] == 0:
            return torch.from_numpy(x).float(), torch.from_numpy(y).float()
        c = self.c[basin, :]
        c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
        xc = np.concatenate((x, c), axis=1)
        return torch.from_numpy(xc).float(), torch.from_numpy(y).float()

    def _load_data(self):
        origin_data = self._read_xyc()
        # normalization
        norm_data = self._normalize(origin_data)
        # enable NaN handling to keep the data clean
        origin_data_wonan, norm_data_wonan = self._kill_nan(origin_data, norm_data)
        # origin_data_wonan, norm_data_wonan = origin_data, norm_data  # 备用:跳过 NaN 处理
        self._trans2nparr(origin_data_wonan, norm_data_wonan)
        self._create_lookup_table()

    def _trans2nparr(self, origin_data, norm_data):
        """To make __getitem__ more efficient,
        we transform x, y, c to numpy array with shape (nsample, nt, nvar)
        """
        for key in origin_data.keys():
            _origin = origin_data[key]
            _norm = norm_data[key]
            if _origin is None or _norm is None:
                norm_arr = None
                origin_arr = None
            else:
                norm_arr = _norm.to_numpy()
                origin_arr = _origin.to_numpy()
            if key == "relevant_cols":
                self.x_origin = origin_arr
                self.x = norm_arr
            elif key == "target_cols":
                self.y_origin = origin_arr
                self.y = norm_arr
            elif key == "constant_cols":
                self.c_origin = origin_arr
                self.c = norm_arr
            elif key == "forecast_cols":
                self.f_origin = origin_arr
                self.f = norm_arr
            elif key == "global_cols":
                self.g_origin = origin_arr
                self.g = norm_arr
            elif key == "station_cols":
                # station data specific to GNN models
                self.station_cols_origin = origin_arr
                self.station_cols = norm_arr
            else:
                raise ValueError(
                    f"Unknown data type {key} in origin_data, "
                    "it should be one of relevant_cols, target_cols, constant_cols, forecast_cols, global_cols, station_cols"
                )

    def _normalize(
        self,
        origin_data,
    ):
        """_summary_

        Parameters
        ----------
        origin_data : dict
            data with key as data type

        Returns
        -------
        _type_
            _description_
        """
        scaler_hub = ScalerHub(
            origin_data,
            data_cfgs=self.data_cfgs,
            is_tra_val_te=self.is_tra_val_te,
            data_source=self.data_source,
        )
        self.target_scaler = scaler_hub.target_scaler
        return scaler_hub.norm_data

    def _selected_time_points_for_denorm(self):
        """get the time points for denormalization

        Returns
        -------
            a list of time points
        """
        return self.target_scaler.data_target.coords["time"][self.warmup_length :]

    def denormalize(self, norm_data, pace_idx=None):
        """Denormalize the norm_data

        Parameters
        ----------
        norm_data : np.ndarray
            batch-first data
        pace_idx : int, optional
            which forecast pace to show, by default None;
            sometimes we have multiple results (paces) for one time period and flatten them,
            so we need a temporary time axis to replace the real one

        Returns
        -------
        xr.Dataset
            denormalized data
        """
        target_scaler = self.target_scaler
        target_data = target_scaler.data_target
        # the units are dimensionless for pure DL models
        units = {k: "dimensionless" for k in target_data.attrs["units"].keys()}
        # mainly to get information about the time points of norm_data
        selected_time_points = self._selected_time_points_for_denorm()
        selected_data = target_data.sel(time=selected_time_points)

        # handle 3-D data (basin, time, variable)
        if norm_data.ndim == 3:
            coords = {
                "basin": selected_data.coords["basin"],
                "time": selected_data.coords["time"],
                "variable": selected_data.coords["variable"],
            }
            dims = ["basin", "time", "variable"]
            # map the selected time points back to integer positions on the time axis
            if isinstance(selected_time_points, xr.DataArray):
                # the full time axis of target_data
                time_coords = target_data.coords["time"].values
                # integer indices corresponding to selected_time_points
                selected_indices = np.where(np.isin(time_coords, selected_time_points))[
                    0
                ]
            else:
                # selected_time_points are already integer indices, use them directly
                selected_indices = selected_time_points

            # make sure the indices stay within bounds
            max_idx = norm_data.shape[1] - 1
            selected_indices = np.clip(selected_indices, 0, max_idx)
            if norm_data.shape[1] != len(selected_data.coords["time"]):
                norm_data_3d = norm_data[:, selected_indices, :]
            else:
                norm_data_3d = norm_data

        # handle 4-D data
        elif norm_data.ndim == 4:
            # Check if the data is organized by basins
            if self.evaluation_cfgs["evaluator"]["recover_mode"] == "bybasins":
                # Shape: (basin_num, i_e_time_length, forecast_length, nf)
                basin_num, i_e_time_length, forecast_length, nf = norm_data.shape

                # If pace_idx is specified, select the specific forecast step
                if (
                    pace_idx is not None
                    and not np.isnan(pace_idx)
                    and 0 <= pace_idx < forecast_length
                ):
                    norm_data_3d = norm_data[:, :, pace_idx, :]
                    # build new coordinates
                    # make sure the basin coordinate length matches the data dimension
                    if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                        # only one basin in the data, take the first basin coordinate
                        basin_coord = selected_data.coords["basin"].values[:1]
                    else:
                        basin_coord = selected_data.coords["basin"].values[:basin_num]

                    coords = {
                        "basin": basin_coord,
                        "time": selected_data.coords["time"][:i_e_time_length],
                        "variable": selected_data.coords["variable"],
                    }
                else:
                    # If pace_idx is not specified, create a new dimension 'horizon'
                    norm_data_3d = norm_data.reshape(
                        basin_num, i_e_time_length * forecast_length, nf
                    )
                    # build new time coordinates by repeating the first i_e_time_length time points
                    new_times = []
                    for i in range(forecast_length):
                        if i < len(selected_data.coords["time"]):
                            new_times.extend(
                                selected_data.coords["time"][:i_e_time_length]
                            )

                    # make sure the time coordinate length matches the data
                    if len(new_times) > i_e_time_length * forecast_length:
                        new_times = new_times[: i_e_time_length * forecast_length]
                    elif len(new_times) < i_e_time_length * forecast_length:
                        # if there are not enough time points, pad with the last one
                        last_time = (
                            new_times[-1]
                            if new_times
                            else selected_data.coords["time"][0]
                        )
                        while len(new_times) < i_e_time_length * forecast_length:
                            new_times.append(last_time)

                    # make sure the basin coordinate length matches the data dimension
                    if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                        basin_coord = selected_data.coords["basin"].values[:1]
                    else:
                        basin_coord = selected_data.coords["basin"].values[:basin_num]

                    coords = {
                        "basin": basin_coord,
                        "time": new_times,
                        "variable": selected_data.coords["variable"],
                    }
            else:  # "byforecast" mode
                # Shape: (forecast_length, basin_num, i_e_time_length, nf)
                forecast_length, basin_num, i_e_time_length, nf = norm_data.shape

                # If pace_idx is specified, select the specific forecast step
                if (
                    pace_idx is not None
                    and not np.isnan(pace_idx)
                    and 0 <= pace_idx < forecast_length
                ):
                    norm_data_3d = norm_data[pace_idx]
                    # make sure the basin coordinate length matches the data dimension
                    if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                        basin_coord = selected_data.coords["basin"].values[:1]
                    else:
                        basin_coord = selected_data.coords["basin"].values[:basin_num]

                    coords = {
                        "basin": basin_coord,
                        "time": selected_data.coords["time"][:i_e_time_length],
                        "variable": selected_data.coords["variable"],
                    }
                else:
                    # If pace_idx is not specified, create a new dimension 'horizon'
                    # Reshape (forecast_length, basin_num, i_e_time_length, nf) -> (basin_num, forecast_length * i_e_time_length, nf)
                    norm_data_3d = np.transpose(norm_data, (1, 0, 2, 3)).reshape(
                        basin_num, forecast_length * i_e_time_length, nf
                    )

                    # build new time coordinates
                    new_times = []
                    for i in range(forecast_length):
                        if i < len(selected_data.coords["time"]):
                            new_times.extend(
                                selected_data.coords["time"][:i_e_time_length]
                            )

                    # make sure the time coordinate length matches the data
                    if len(new_times) > forecast_length * i_e_time_length:
                        new_times = new_times[: forecast_length * i_e_time_length]
                    elif len(new_times) < forecast_length * i_e_time_length:
                        # if there are not enough time points, pad with the last one
                        last_time = (
                            new_times[-1]
                            if new_times
                            else selected_data.coords["time"][0]
                        )
                        while len(new_times) < forecast_length * i_e_time_length:
                            new_times.append(last_time)

                    # make sure the basin coordinate length matches the data dimension
                    if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                        basin_coord = selected_data.coords["basin"].values[:1]
                    else:
                        basin_coord = selected_data.coords["basin"].values[:basin_num]

                    coords = {
                        "basin": basin_coord,
                        "time": new_times,
                        "variable": selected_data.coords["variable"],
                    }
            dims = ["basin", "time", "variable"]
        else:
            coords = selected_data.coords
            dims = selected_data.dims
            norm_data_3d = norm_data

        # create DataArray and inverse transform
        denorm_xr_ds = target_scaler.inverse_transform(
            xr.DataArray(
                norm_data_3d,
                dims=dims,
                coords=coords,
                attrs={"units": units},
            )
        )
        return set_unit_to_var(denorm_xr_ds)

    def _to_dataarray_with_unit(self, *args):
        """Convert xarray datasets to xarray data arrays and set units for each variable.

        Parameters
        ----------
        *args : xr.Dataset
            Any number of xarray dataset inputs.

        Returns
        -------
        tuple
            A tuple of converted data arrays, with the same number as the input parameters.
        """
        results = []
        for ds in args:
            if ds is not None:
                # First convert some string-type data to floating-point type
                results.append(self._trans2da_and_setunits(ds))
            else:
                results.append(None)
        return tuple(results)

    def _check_ts_xrds_unit(self, data_forcing_ds, data_output_ds):
        """Check timeseries xarray dataset unit and convert if necessary

        Parameters
        ----------
        data_forcing_ds : xr.Dataset
            the forcing data
        data_output_ds : xr.Dataset
            outputs including streamflow data
        """

        def standardize_unit(unit):
            unit = unit.lower()  # convert to lower case
            unit = re.sub(r"day", "d", unit)
            unit = re.sub(r"hour", "h", unit)
            return unit

        streamflow_unit = data_output_ds[self.streamflow_name].attrs["units"]
        prcp_unit = data_forcing_ds[self.precipitation_name].attrs["units"]

        standardized_streamflow_unit = standardize_unit(streamflow_unit)
        standardized_prcp_unit = standardize_unit(prcp_unit)
        if standardized_streamflow_unit != standardized_prcp_unit:
            streamflow_dataset = data_output_ds[[self.streamflow_name]]
            converted_streamflow_dataset = streamflow_unit_conv(
                streamflow_dataset,
                self.data_source.read_area(self.t_s_dict["sites_id"]),
                target_unit=prcp_unit,
                source_unit=streamflow_unit,
            )
            data_output_ds[self.streamflow_name] = converted_streamflow_dataset[
                self.streamflow_name
            ]
        return data_forcing_ds, data_output_ds

    def _read_xyc(self):
        """Read x, y, c data from data source

        Returns
        -------
        dict
            data with key as data type
            the dim must be (basin, time, lead_step, variable) for 4-d xr array;
            the dim must be (basin, time, variable) for 3-d xr array;
            the dim must be (basin, variable) for 2-d xr array;
        """
        # Check if we have multiple time periods (for multi-period training)
        t_range = self.t_s_dict["t_final_range"]

        # Check if first element is a list/tuple (indicating multiple periods)
        if isinstance(t_range[0], (list, tuple)):
            # Validate multi-period format
            self._validate_multi_period_format(t_range)

            # Multiple periods case - can be any number of periods
            all_data = None
            for start_date, end_date in t_range:
                period_data = self._read_xyc_period(start_date, end_date)
                if all_data is None:
                    all_data = period_data
                else:
                    # Concatenate along time dimension
                    for key in period_data:
                        # make sure the time coordinates of both datasets are string-typed
                        if all_data[key] is not None and period_data[key] is not None:
                            if not isinstance(all_data[key].time.values[0], str):
                                all_data[key]["time"] = all_data[key].time.astype(str)
                            if not isinstance(period_data[key].time.values[0], str):
                                period_data[key]["time"] = period_data[key].time.astype(
                                    str
                                )

                            all_data[key] = xr.concat(
                                [all_data[key], period_data[key]], dim="time"
                            )
            return all_data
        else:
            # Single period case (existing behavior)
            start_date = t_range[0]
            end_date = t_range[1]
            return self._read_xyc_period(start_date, end_date)

    def _read_xyc_period(self, start_date, end_date):
        """Template method for reading x, y, c data for a specific time period

        This method can be overridden by subclasses to customize how data is read
        for each time period while keeping the multi-period handling logic in the parent class.

        Parameters
        ----------
        start_date : str
            start time
        end_date : str
            end time

        Returns
        -------
        dict
            Dictionary containing relevant_cols, target_cols, and constant_cols data
        """
        # Default implementation: delegate to the original method
        return self._read_xyc_specified_time(start_date, end_date)

    def _validate_multi_period_format(self, t_range):
        """Validate format of multi-period time ranges

        Parameters
        ----------
        t_range : list
            List of time periods, where each period should be [start_date, end_date]

        Raises
        ------
        ValueError
            If any period doesn't have exactly 2 elements (start_date, end_date)
        """
        for i, period in enumerate(t_range):
            if not isinstance(period, (list, tuple)) or len(period) != 2:
                raise ValueError(
                    f"Period {i} must be a list/tuple with exactly 2 elements (start_date, end_date), got: {period}"
                )

    def _rm_timeunit_key(self, ds_):
        """this means the data source return a dict with key as time_unit
            in this BaseDataset, we only support unified time range for all basins, so we chose the first key
            TODO: maybe this could be refactored better

        Parameters
        ----------
        ds_ : dict
            the xarray data with time_unit as key

        Returns
        ----------
        ds_ : xr.Dataset
            the output data without time_unit
        """
        if isinstance(ds_, dict):
            ds_ = ds_[list(ds_.keys())[0]]
        return ds_

    def _read_xyc_specified_time(self, start_date, end_date):
        """Read x, y, c data from data source with specified time range
        We keep this as a separate function because sometimes we need to adjust the time range for a specific dataset,
        such as the seq2seq dataset (it needs one more period at the end of the time range)

        Parameters
        ----------
        start_date : str
            start time
        end_date : str
            end time
        """
        data_forcing_ds_ = self.data_source.read_ts_xrdataset(
            self.t_s_dict["sites_id"],
            [start_date, end_date],
            self.data_cfgs["relevant_cols"],
        )
        # y
        data_output_ds_ = self.data_source.read_ts_xrdataset(
            self.t_s_dict["sites_id"],
            [start_date, end_date],
            self.data_cfgs["target_cols"],
        )
        data_forcing_ds_ = self._rm_timeunit_key(data_forcing_ds_)
        data_output_ds_ = self._rm_timeunit_key(data_output_ds_)
        data_forcing_ds, data_output_ds = self._check_ts_xrds_unit(
            data_forcing_ds_, data_output_ds_
        )
        # c
        data_attr_ds = self.data_source.read_attr_xrdataset(
            self.t_s_dict["sites_id"],
            self.data_cfgs["constant_cols"],
            all_number=True,
        )
        x_origin, y_origin, c_origin = self._to_dataarray_with_unit(
            data_forcing_ds, data_output_ds, data_attr_ds
        )
        return {
            "relevant_cols": x_origin.transpose("basin", "time", "variable"),
            "target_cols": y_origin.transpose("basin", "time", "variable"),
            "constant_cols": (
                c_origin.transpose("basin", "variable")
                if c_origin is not None
                else None
            ),
        }

    def _trans2da_and_setunits(self, ds):
        """Set units for dataarray transfromed from dataset"""
        result = ds.to_array(dim="variable")
        units_dict = {
            var: ds[var].attrs["units"]
            for var in ds.variables
            if "units" in ds[var].attrs
        }
        result.attrs["units"] = units_dict
        return result

    def _kill_nan(self, origin_data, norm_data):
        """This function is used to remove NaN values in the original data and its normalized data.

        Parameters
        ----------
        origin_data : dict
            the original data
        norm_data : dict
            the normalized data

        Returns
        -------
        dict, dict
            the original data and normalized data after removing NaN values
        """
        data_cfgs = self.data_cfgs
        origins_wonan = {}
        norms_wonan = {}
        for key in origin_data.keys():
            _origin = origin_data[key]
            _norm = norm_data[key]
            if _origin is None or _norm is None:
                origins_wonan[key] = None
                norms_wonan[key] = None
                continue
            kill_way = "interpolate"
            if key == "relevant_cols":
                rm_nan = data_cfgs["relevant_rm_nan"]
            elif key == "target_cols":
                rm_nan = data_cfgs["target_rm_nan"]
            elif key == "constant_cols":
                rm_nan = data_cfgs["constant_rm_nan"]
                kill_way = "mean"
            elif key == "forecast_cols":
                rm_nan = data_cfgs["forecast_rm_nan"]
                kill_way = "lead_step"
            elif key == "global_cols":
                rm_nan = data_cfgs["global_rm_nan"]
            elif key == "station_cols":
                rm_nan = data_cfgs.get("station_rm_nan")
            else:
                raise ValueError(
                    f"Unknown data type {key} in origin_data, "
                    "it should be one of relevant_cols, target_cols, constant_cols, forecast_cols, global_cols and station_cols"
                )

            if rm_nan:
                norm = self._kill_1type_nan(
                    _norm,
                    kill_way,
                    "original data",
                    "nan_filled data",
                )
                origin = self._kill_1type_nan(
                    _origin,
                    kill_way,
                    "original data",
                    "nan_filled data",
                )
            else:
                norm = _norm
                origin = _origin
            if key == "target_cols" or not rm_nan:
                warn_if_nan(origin, nan_mode="all", data_name="nan_filled target data")
                warn_if_nan(norm, nan_mode="all", data_name="nan_filled target data")
            else:
                warn_if_nan(origin, nan_mode="any", data_name="nan_filled input data")
                warn_if_nan(norm, nan_mode="any", data_name="nan_filled input data")
            origins_wonan[key] = origin
            norms_wonan[key] = norm
        return origins_wonan, norms_wonan

    def _kill_1type_nan(self, the_data, fill_nan, data_name_before, data_name_after):
        is_any_nan = warn_if_nan(the_data, data_name=data_name_before)
        if not is_any_nan:
            return the_data
        # As input, we cannot have NaN values
        the_filled_data = _fill_gaps_da(the_data, fill_nan=fill_nan)
        warn_if_nan(the_filled_data, data_name=data_name_after)
        return the_filled_data

    def _create_lookup_table(self):
        lookup = []
        # number of basins in the dataset
        basin_coordinates = len(self.t_s_dict["sites_id"])
        rho = self.rho
        warmup_length = self.warmup_length
        horizon = self.horizon
        # NOTE: we set seq_len to rho + horizon instead of warmup_length + rho + horizon
        seq_len = rho + horizon
        max_time_length = self.nt
        variable_length_cfgs = self.training_cfgs.get("variable_length_cfgs", {})
        use_variable_length = variable_length_cfgs.get("use_variable_length", False)
        variable_length_type = variable_length_cfgs.get(
            "variable_length_type", "dynamic"
        )  # only used for case when use_variable_length is True
        fixed_lengths = variable_length_cfgs.get("fixed_lengths", [365, 1095, 1825])
        # Use fixed type variable length if enabled and type is fixed
        is_fixed_length_train = use_variable_length and variable_length_type == "fixed"
        for basin in tqdm(range(basin_coordinates), file=sys.stdout, disable=False):
            if not self.train_mode:
                # we don't need to ignore those with full nan in target vars for prediction without loss calculation
                # all samples should be included so that we can recover results to specified basins easily
                lookup.extend(
                    (basin, f, seq_len)
                    for f in range(warmup_length, max_time_length - rho - horizon + 1)
                )
            else:
                # some dataloader load data with warmup period, so leave some periods for it
                # [warmup_len] -> time_start -> [rho] -> [horizon]
                #                       window: \-----------------/ meaning rho + horizon
                nan_array = np.isnan(self.y[basin, :, :])
                if is_fixed_length_train:
                    for window in fixed_lengths:
                        lookup.extend(
                            (basin, f, window)
                            for f in range(
                                warmup_length,
                                max_time_length - window + 1,
                            )
                            # if all nan in window, we skip this sample
                            if not np.all(nan_array[f : f + window])
                        )
                else:
                    lookup.extend(
                        (basin, f, seq_len)
                        for f in range(
                            warmup_length, max_time_length - rho - horizon + 1
                        )
                        # if all nan in rho + horizon window, we skip this sample
                        if not np.all(nan_array[f : f + rho + horizon])
                    )
        self.lookup_table = dict(enumerate(lookup))
        self.num_samples = len(self.lookup_table)

    def _create_multi_len_lookup_table(self):
        """
        Create a lookup table for multi-length training
        TODO: not fully tested
        """
        lookup = []
        # number of basins in the dataset
        basin_coordinates = len(self.t_s_dict["sites_id"])
        rho = self.rho
        warmup_length = self.warmup_length
        horizon = self.horizon
        seq_len = warmup_length + rho + horizon
        max_time_length = self.nt
        variable_length_cfgs = self.training_cfgs.get("variable_length_cfgs", {})
        use_variable_length = variable_length_cfgs.get("use_variable_length", False)
        variable_length_type = variable_length_cfgs.get(
            "variable_length_type", "dynamic"
        )
        fixed_lengths = variable_length_cfgs.get("fixed_lengths", [365, 1095, 1825])
        # Use fixed type variable length if enabled and type is fixed
        is_fixed_length_train = use_variable_length and variable_length_type == "fixed"

        # initialize a lookup table for each window length
        self.lookup_tables_by_length = {length: [] for length in fixed_lengths}

        # New: Global lookup table to map a single index to (window_length, index_within_that_window_length_table)
        self.global_lookup_table_indices = []

        for basin in tqdm(range(basin_coordinates), file=sys.stdout, disable=False):
            if not self.train_mode:
                # For prediction, we still use the original rho for simplicity if multi_length_training is enabled
                # or we can extend this logic to support multi-length prediction if needed.
                # For now, let's assume prediction uses a fixed rho or is handled differently.
                # If multi_length_training is active, we might need to decide which window_len to use for prediction.
                # For now, let's stick to the original logic for train_mode=False
                lookup.extend(
                    (basin, f, seq_len)
                    for f in range(warmup_length, max_time_length - rho - horizon + 1)
                )
            else:
                # some dataloader load data with warmup period, so leave some periods for it
                # [warmup_len] -> time_start -> [rho] -> [horizon]
                nan_array = np.isnan(self.y[basin, :, :])
                if is_fixed_length_train:
                    for window in fixed_lengths:
                        for f in range(
                            warmup_length, max_time_length - window - horizon + 1
                        ):
                            # check whether the target window is all NaN
                            if not np.all(nan_array[f + window : f + window + horizon]):
                                # record (basin, start index) in the lookup table for this window length
                                self.lookup_tables_by_length[window].append((basin, f))
                                # record (window length, index within that window-length table) in the global index table
                                self.global_lookup_table_indices.append(
                                    (
                                        window,
                                        len(self.lookup_tables_by_length[window]) - 1,
                                    )
                                )
                else:
                    lookup.extend(
                        (basin, f, seq_len)
                        for f in range(
                            warmup_length, max_time_length - rho - horizon + 1
                        )
                        if not np.all(nan_array[f + rho : f + rho + horizon])
                    )

        if is_fixed_length_train and self.train_mode:
            # If fixed-length training is enabled and in train mode, use the global lookup table
            self.lookup_table = dict(enumerate(self.global_lookup_table_indices))
            self.num_samples = len(self.global_lookup_table_indices)
        else:
            # Otherwise, use the original lookup table (for fixed length training or prediction)
            self.lookup_table = dict(enumerate(lookup))
            self.num_samples = len(self.lookup_table)
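
To make the windowing logic above concrete, here is a minimal, self-contained sketch (not torchhydro code; all sizes and values are made up) of how a lookup table of (basin, start, length) entries, built the same way as in _create_lookup_table, is turned into a target slice the way __getitem__ does:

    import numpy as np

    # illustrative sizes only
    n_basin, n_time, n_var = 2, 20, 1
    rho, horizon, warmup = 5, 3, 2
    y = np.random.rand(n_basin, n_time, n_var)

    # one (basin, start, seq_len) entry per valid window, as in _create_lookup_table
    seq_len = rho + horizon
    lookup = [
        (basin, f, seq_len)
        for basin in range(n_basin)
        for f in range(warmup, n_time - rho - horizon + 1)
    ]
    lookup_table = dict(enumerate(lookup))

    # slicing as in __getitem__: inputs include the warmup period, targets do not
    basin, idx, actual_length = lookup_table[0]
    y_sample = y[basin, idx : idx + actual_length, :]
    print(len(lookup_table), y_sample.shape)  # 22 samples, target window of length 8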

basins property readonly

Return the basins of the dataset

ngrid property readonly

How many basins/grids in the dataset

Returns

int
    number of basins/grids

noutputvar property readonly

How many output variables are in the dataset. Used in evaluation.

Returns

int
    number of variables

nt property readonly

length of longest time series in all basins

Returns

int
    number of time steps in the longest series

times property readonly

Return the times of all basins

TODO: Although we support getting different time ranges for different basins, the reading function for this case is not implemented in the _read_xyc method. Hence, it is better to choose a unified time range for all basins.
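
As a quick illustration of how the times property assembles its time axis, the sketch below builds the same pandas frequency string from min_time_interval and min_time_unit; the 3-hour step and the dates are assumptions for the example, not values from any real config:

    import pandas as pd

    min_time_interval, min_time_unit = 3, "h"          # assumed config values
    time_step = f"{min_time_interval}{min_time_unit}"  # -> "3h"

    s_date = pd.to_datetime("2020-01-01 00:00:00")
    e_date = pd.to_datetime("2020-01-02 00:00:00")
    times_ = pd.date_range(start=s_date, end=e_date, freq=time_step)
    print(len(times_))  # 9 time points at a 3-hour step, both endpoints included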

__getitem__(self, item) special

Get one sample from the dataset with a unified return format.

Parameters:

item : int, required
    The index of the sample to retrieve.

Returns:

A tuple of (input_data, output_data), where input_data is a tensor of input features and output_data is a tensor of target values.

Source code in torchhydro/datasets/data_sets.py
def __getitem__(self, item: int):
    """Get one sample from the dataset with a unified return format.

    Args:
        item: The index of the sample to retrieve.

    Returns:
        A tuple of (input_data, output_data), where input_data is a tensor
        of input features and output_data is a tensor of target values.
    """
    basin, idx, actual_length = self.lookup_table[item]
    warmup_length = self.warmup_length
    x = self.x[basin, idx - warmup_length : idx + actual_length, :]
    y = self.y[basin, idx : idx + actual_length, :]
    if self.c is None or self.c.shape[-1] == 0:
        return torch.from_numpy(x).float(), torch.from_numpy(y).float()
    c = self.c[basin, :]
    c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
    xc = np.concatenate((x, c), axis=1)
    return torch.from_numpy(xc).float(), torch.from_numpy(y).float()
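
Because __getitem__ returns a plain (input, output) tensor pair, the dataset can be passed directly to a standard torch DataLoader. A minimal usage sketch, assuming cfgs is a fully populated torchhydro config dict (its contents are not reproduced here):

    from torch.utils.data import DataLoader
    from torchhydro.datasets.data_sets import BaseDataset

    # cfgs is assumed to be a complete config with data/training/evaluation sections
    train_ds = BaseDataset(cfgs, is_tra_val_te="train")
    loader = DataLoader(train_ds, batch_size=64, shuffle=True)

    for xc, y in loader:
        # xc: (batch, warmup_length + rho + horizon, n_forcing + n_attr)
        # y:  (batch, rho + horizon, n_target)
        break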

__init__(self, cfgs, is_tra_val_te) special

Parameters

cfgs
    configs, including data and training + evaluation settings which will be used for organizing batch data
is_tra_val_te
    train, valid or test

Source code in torchhydro/datasets/data_sets.py
def __init__(self, cfgs: dict, is_tra_val_te: str):
    """
    Parameters
    ----------
    cfgs
        configs, including data and training + evaluation settings
        which will be used for organizing batch data
    is_tra_val_te
        train, valid or test
    """
    super(BaseDataset, self).__init__()
    self.data_cfgs = cfgs["data_cfgs"]
    self.training_cfgs = cfgs["training_cfgs"]
    self.evaluation_cfgs = cfgs["evaluation_cfgs"]
    self._pre_load_data(is_tra_val_te)
    # load and preprocess data
    self._load_data()
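
The constructor only assumes that cfgs carries three sub-dicts. A schematic sketch of that top-level layout; the nested keys shown are examples drawn from the code on this page, not an exhaustive or authoritative list:

    # schematic only: real torchhydro configs contain many more keys
    cfgs = {
        "data_cfgs": {
            "relevant_cols": [],   # forcing variables
            "target_cols": [],     # output variables, e.g. streamflow
            "constant_cols": [],   # basin attributes
            # plus rm_nan flags, scaler settings, source_cfgs, ...
        },
        "training_cfgs": {
            "warmup_length": 0,
            # plus sequence-length and variable_length_cfgs settings, ...
        },
        "evaluation_cfgs": {
            "evaluator": {"recover_mode": "bybasins"},
        },
    }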

denormalize(self, norm_data, pace_idx=None)

Denormalize the norm_data

Parameters

norm_data : np.ndarray
    batch-first data
pace_idx : int, optional
    which forecast pace to show, by default None; sometimes we have multiple results (paces) for one time period and flatten them, so we need a temporary time axis to replace the real one

Returns

xr.Dataset
    denormalized data

Source code in torchhydro/datasets/data_sets.py
def denormalize(self, norm_data, pace_idx=None):
    """Denormalize the norm_data

    Parameters
    ----------
    norm_data : np.ndarray
        batch-first data
    pace_idx : int, optional
        which forecast pace to show, by default None;
        sometimes we have multiple results (paces) for one time period and flatten them,
        so we need a temporary time axis to replace the real one

    Returns
    -------
    xr.Dataset
        denormalized data
    """
    target_scaler = self.target_scaler
    target_data = target_scaler.data_target
    # the units are dimensionless for pure DL models
    units = {k: "dimensionless" for k in target_data.attrs["units"].keys()}
    # mainly to get information about the time points of norm_data
    selected_time_points = self._selected_time_points_for_denorm()
    selected_data = target_data.sel(time=selected_time_points)

    # handle 3-D data (basin, time, variable)
    if norm_data.ndim == 3:
        coords = {
            "basin": selected_data.coords["basin"],
            "time": selected_data.coords["time"],
            "variable": selected_data.coords["variable"],
        }
        dims = ["basin", "time", "variable"]
        # map the selected time points back to integer positions on the time axis
        if isinstance(selected_time_points, xr.DataArray):
            # the full time axis of target_data
            time_coords = target_data.coords["time"].values
            # integer indices corresponding to selected_time_points
            selected_indices = np.where(np.isin(time_coords, selected_time_points))[
                0
            ]
        else:
            # selected_time_points are already integer indices, use them directly
            selected_indices = selected_time_points

        # make sure the indices stay within bounds
        max_idx = norm_data.shape[1] - 1
        selected_indices = np.clip(selected_indices, 0, max_idx)
        if norm_data.shape[1] != len(selected_data.coords["time"]):
            norm_data_3d = norm_data[:, selected_indices, :]
        else:
            norm_data_3d = norm_data

    # handle 4-D data
    elif norm_data.ndim == 4:
        # Check if the data is organized by basins
        if self.evaluation_cfgs["evaluator"]["recover_mode"] == "bybasins":
            # Shape: (basin_num, i_e_time_length, forecast_length, nf)
            basin_num, i_e_time_length, forecast_length, nf = norm_data.shape

            # If pace_idx is specified, select the specific forecast step
            if (
                pace_idx is not None
                and not np.isnan(pace_idx)
                and 0 <= pace_idx < forecast_length
            ):
                norm_data_3d = norm_data[:, :, pace_idx, :]
                # build new coordinates
                # make sure the basin coordinate length matches the data dimension
                if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                    # only one basin in the data, take the first basin coordinate
                    basin_coord = selected_data.coords["basin"].values[:1]
                else:
                    basin_coord = selected_data.coords["basin"].values[:basin_num]

                coords = {
                    "basin": basin_coord,
                    "time": selected_data.coords["time"][:i_e_time_length],
                    "variable": selected_data.coords["variable"],
                }
            else:
                # If pace_idx is not specified, create a new dimension 'horizon'
                norm_data_3d = norm_data.reshape(
                    basin_num, i_e_time_length * forecast_length, nf
                )
                # build new time coordinates by repeating the first i_e_time_length time points
                new_times = []
                for i in range(forecast_length):
                    if i < len(selected_data.coords["time"]):
                        new_times.extend(
                            selected_data.coords["time"][:i_e_time_length]
                        )

                # make sure the time coordinate length matches the data
                if len(new_times) > i_e_time_length * forecast_length:
                    new_times = new_times[: i_e_time_length * forecast_length]
                elif len(new_times) < i_e_time_length * forecast_length:
                    # if there are not enough time points, pad with the last one
                    last_time = (
                        new_times[-1]
                        if new_times
                        else selected_data.coords["time"][0]
                    )
                    while len(new_times) < i_e_time_length * forecast_length:
                        new_times.append(last_time)

                # make sure the basin coordinate length matches the data dimension
                if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                    basin_coord = selected_data.coords["basin"].values[:1]
                else:
                    basin_coord = selected_data.coords["basin"].values[:basin_num]

                coords = {
                    "basin": basin_coord,
                    "time": new_times,
                    "variable": selected_data.coords["variable"],
                }
        else:  # "byforecast" mode
            # Shape: (forecast_length, basin_num, i_e_time_length, nf)
            forecast_length, basin_num, i_e_time_length, nf = norm_data.shape

            # If pace_idx is specified, select the specific forecast step
            if (
                pace_idx is not None
                and not np.isnan(pace_idx)
                and 0 <= pace_idx < forecast_length
            ):
                norm_data_3d = norm_data[pace_idx]
                # make sure the basin coordinate length matches the data dimension
                if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                    basin_coord = selected_data.coords["basin"].values[:1]
                else:
                    basin_coord = selected_data.coords["basin"].values[:basin_num]

                coords = {
                    "basin": basin_coord,
                    "time": selected_data.coords["time"][:i_e_time_length],
                    "variable": selected_data.coords["variable"],
                }
            else:
                # If pace_idx is not specified, create a new dimension 'horizon'
                # Reshape (forecast_length, basin_num, i_e_time_length, nf) -> (basin_num, forecast_length * i_e_time_length, nf)
                norm_data_3d = np.transpose(norm_data, (1, 0, 2, 3)).reshape(
                    basin_num, forecast_length * i_e_time_length, nf
                )

                # build new time coordinates
                new_times = []
                for i in range(forecast_length):
                    if i < len(selected_data.coords["time"]):
                        new_times.extend(
                            selected_data.coords["time"][:i_e_time_length]
                        )

                # make sure the time coordinate length matches the data
                if len(new_times) > forecast_length * i_e_time_length:
                    new_times = new_times[: forecast_length * i_e_time_length]
                elif len(new_times) < forecast_length * i_e_time_length:
                    # if there are not enough time points, pad with the last one
                    last_time = (
                        new_times[-1]
                        if new_times
                        else selected_data.coords["time"][0]
                    )
                    while len(new_times) < forecast_length * i_e_time_length:
                        new_times.append(last_time)

                # make sure the basin coordinate length matches the data dimension
                if basin_num == 1 and len(selected_data.coords["basin"]) > 1:
                    basin_coord = selected_data.coords["basin"].values[:1]
                else:
                    basin_coord = selected_data.coords["basin"].values[:basin_num]

                coords = {
                    "basin": basin_coord,
                    "time": new_times,
                    "variable": selected_data.coords["variable"],
                }
        dims = ["basin", "time", "variable"]
    else:
        coords = selected_data.coords
        dims = selected_data.dims
        norm_data_3d = norm_data

    # create DataArray and inverse transform
    denorm_xr_ds = target_scaler.inverse_transform(
        xr.DataArray(
            norm_data_3d,
            dims=dims,
            coords=coords,
            attrs={"units": units},
        )
    )
    return set_unit_to_var(denorm_xr_ds)
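
A hedged usage sketch for denormalize: normalized model outputs shaped (basin, time, variable) are wrapped back into an xr.Dataset with the original units. The prediction array here is fabricated, and ds is assumed to be an already-constructed test-period dataset, so this fragment is not runnable on its own:

    import numpy as np

    # ds: a BaseDataset built for the test period (construction omitted)
    # preds: fabricated normalized outputs, batch-first (basin, time, variable)
    preds = np.zeros((len(ds.basins), ds.nt - ds.warmup_length, ds.noutputvar))
    pred_xrds = ds.denormalize(preds)  # xr.Dataset with per-variable units restored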

BasinSingleFlowDataset (BaseDataset)

one time length output for each grid in a batch

Source code in torchhydro/datasets/data_sets.py
class BasinSingleFlowDataset(BaseDataset):
    """one time length output for each grid in a batch"""

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        super(BasinSingleFlowDataset, self).__init__(cfgs, is_tra_val_te)

    def __getitem__(self, index):
        xc, ys = super(BasinSingleFlowDataset, self).__getitem__(index)
        y = ys[-1, :]
        return xc, y
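
The only difference from BaseDataset is that the target is reduced to its final time step. A standalone shape sketch of that reduction, with made-up sizes:

    import torch

    seq_len, n_target = 8, 1                 # made-up sizes
    ys = torch.zeros(seq_len, n_target)      # what BaseDataset.__getitem__ returns as y
    y = ys[-1, :]                            # what BasinSingleFlowDataset returns
    print(ys.shape, y.shape)                 # torch.Size([8, 1]) torch.Size([1])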

DplDataset (BaseDataset)

pytorch dataset for Differential parameter learning

Source code in torchhydro/datasets/data_sets.py
class DplDataset(BaseDataset):
    """pytorch dataset for Differential parameter learning"""

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        """
        Parameters
        ----------
        cfgs
            all configs
        is_tra_val_te
            train, valid or test
        """
        super(DplDataset, self).__init__(cfgs, is_tra_val_te)
        # we don't use y_un_norm as its name because in the main function we will use "y"
        # For physical hydrological models, we need warmup, hence the target values should exclude data in warmup period
        self.warmup_length = self.training_cfgs["warmup_length"]
        self.target_as_input = self.data_cfgs["target_as_input"]
        self.constant_only = self.data_cfgs["constant_only"]
        if self.target_as_input and (not self.train_mode):
            # if the target is used as input and train_mode is False,
            # we need to get the target data in training period to generate pbm params
            self.train_dataset = DplDataset(cfgs, is_tra_val_te="train")

    def __getitem__(self, item):
        """
        Get one mini-batch for dPL (differential parameter learning) model

        TODO: not check target_as_input and constant_only cases yet

        Parameters
        ----------
        item
            index

        Returns
        -------
        tuple
            a mini-batch data;
            x_train (not normalized forcing), z_train (normalized data for DL model), y_train (not normalized output)
        """
        warmup = self.warmup_length
        rho = self.rho
        horizon = self.horizon
        xc_norm, _ = super(DplDataset, self).__getitem__(item)
        basin, time, _ = self.lookup_table[item]
        if self.target_as_input:
            # y_norm and xc_norm are concatenated and used for the DL model
            y_norm = torch.from_numpy(
                self.y[basin, time - warmup : time + rho + horizon, :]
            ).float()
            # the order of xc_norm and y_norm matters, please be careful!
            z_train = torch.cat((xc_norm, y_norm), -1)
        elif self.constant_only:
            # only use attributes data for DL model
            z_train = torch.from_numpy(self.c[basin, :]).float()
        else:
            z_train = xc_norm.float()
        x_train = self.x_origin[basin, time - warmup : time + rho + horizon, :]
        y_train = self.y_origin[basin, time : time + rho + horizon, :]
        return (
            torch.from_numpy(x_train).float(),
            z_train,
        ), torch.from_numpy(y_train).float()
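
Because each DplDataset sample is a nested tuple, a training loop unpacks it differently from BaseDataset. A minimal sketch; dpl_loader is assumed to be a DataLoader wrapping a DplDataset, so this fragment is not runnable on its own:

    # each sample is ((x_train, z_train), y_train)
    for (x_raw, z_norm), y_raw in dpl_loader:
        # x_raw: un-normalized forcing for the physics-based (PBM) part
        # z_norm: normalized inputs for the DL parameter-learning part
        # y_raw: un-normalized targets used in the loss
        break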

__getitem__(self, item) special

Get one mini-batch for dPL (differential parameter learning) model

TODO: not check target_as_input and constant_only cases yet

Parameters

item
    index

Returns

tuple
    a mini-batch data; x_train (not normalized forcing), z_train (normalized data for DL model), y_train (not normalized output)

Source code in torchhydro/datasets/data_sets.py
def __getitem__(self, item):
    """
    Get one mini-batch for dPL (differential parameter learning) model

    TODO: not check target_as_input and constant_only cases yet

    Parameters
    ----------
    item
        index

    Returns
    -------
    tuple
        a mini-batch data;
        x_train (not normalized forcing), z_train (normalized data for DL model), y_train (not normalized output)
    """
    warmup = self.warmup_length
    rho = self.rho
    horizon = self.horizon
    xc_norm, _ = super(DplDataset, self).__getitem__(item)
    basin, time, _ = self.lookup_table[item]
    if self.target_as_input:
        # y_norm and xc_norm are concatenated and used for the DL model
        y_norm = torch.from_numpy(
            self.y[basin, time - warmup : time + rho + horizon, :]
        ).float()
        # the order of xc_norm and y_norm matters, please be careful!
        z_train = torch.cat((xc_norm, y_norm), -1)
    elif self.constant_only:
        # only use attributes data for DL model
        z_train = torch.from_numpy(self.c[basin, :]).float()
    else:
        z_train = xc_norm.float()
    x_train = self.x_origin[basin, time - warmup : time + rho + horizon, :]
    y_train = self.y_origin[basin, time : time + rho + horizon, :]
    return (
        torch.from_numpy(x_train).float(),
        z_train,
    ), torch.from_numpy(y_train).float()

__init__(self, cfgs, is_tra_val_te) special

Parameters

cfgs
    all configs
is_tra_val_te
    train, valid or test

Source code in torchhydro/datasets/data_sets.py
def __init__(self, cfgs: dict, is_tra_val_te: str):
    """
    Parameters
    ----------
    cfgs
        all configs
    is_tra_val_te
        train, valid or test
    """
    super(DplDataset, self).__init__(cfgs, is_tra_val_te)
    # we don't use y_un_norm as its name because in the main function we will use "y"
    # For physical hydrological models, we need warmup, hence the target values should exclude data in warmup period
    self.warmup_length = self.training_cfgs["warmup_length"]
    self.target_as_input = self.data_cfgs["target_as_input"]
    self.constant_only = self.data_cfgs["constant_only"]
    if self.target_as_input and (not self.train_mode):
        # if the target is used as input and train_mode is False,
        # we need to get the target data in training period to generate pbm params
        self.train_dataset = DplDataset(cfgs, is_tra_val_te="train")

FlexibleDataset (BaseDataset)

A dataset whose variables are read from multiple data sources according to the configuration

Source code in torchhydro/datasets/data_sets.py
class FlexibleDataset(BaseDataset):
    """A dataset whose datasources are from multiple sources according to the configuration"""

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        super(FlexibleDataset, self).__init__(cfgs, is_tra_val_te)

    @property
    def data_source(self):
        source_cfgs = self.data_cfgs["source_cfgs"]
        return {
            name: data_sources_dict[name](path)
            for name, path in zip(
                source_cfgs["source_names"], source_cfgs["source_paths"]
            )
        }

    def _read_xyc(self):
        var_to_source_map = self.data_cfgs["var_to_source_map"]
        x_datasets, y_datasets, c_datasets = [], [], []
        gage_ids = self.t_s_dict["sites_id"]
        t_range = self.t_s_dict["t_final_range"]

        for var_name in var_to_source_map:
            source_name = var_to_source_map[var_name]
            data_source_ = self.data_source[source_name]
            if var_name in self.data_cfgs["relevant_cols"]:
                x_datasets.append(
                    data_source_.read_ts_xrdataset(gage_ids, t_range, [var_name])
                )
            elif var_name in self.data_cfgs["target_cols"]:
                y_datasets.append(
                    data_source_.read_ts_xrdataset(gage_ids, t_range, [var_name])
                )
            elif var_name in self.data_cfgs["constant_cols"]:
                c_datasets.append(
                    data_source_.read_attr_xrdataset(gage_ids, [var_name])
                )

        # merge all x, y, c datasets
        x = xr.merge(x_datasets) if x_datasets else xr.Dataset()
        y = xr.merge(y_datasets) if y_datasets else xr.Dataset()
        c = xr.merge(c_datasets) if c_datasets else xr.Dataset()
        # Check if any flow variable exists in y dataset instead of hardcoding "streamflow"
        flow_var_name = (
            self.streamflow_name
            if hasattr(self, "streamflow_name") and self.streamflow_name in y
            else None
        )
        if flow_var_name is None:
            # fallback: check if any target variable is in y
            for target_var in self.data_cfgs["target_cols"]:
                if target_var in y:
                    flow_var_name = target_var
                    break
        if flow_var_name and flow_var_name in y:
            area = data_source_.camels.read_area(self.t_s_dict["sites_id"])
            y.update(streamflow_unit_conv(y[[flow_var_name]], area))
        x_origin, y_origin, c_origin = self._to_dataarray_with_unit(x, y, c)
        return x_origin, y_origin, c_origin

    def _normalize(self):
        var_to_source_map = self.data_cfgs["var_to_source_map"]
        for var_name in var_to_source_map:
            source_name = var_to_source_map[var_name]
            data_source_ = self.data_source[source_name]
            break
        # TODO: only support CAMELS for now
        scaler_hub = ScalerHub(
            self.y_origin,
            self.x_origin,
            self.c_origin,
            data_cfgs=self.data_cfgs,
            is_tra_val_te=self.is_tra_val_te,
            data_source=data_source_.camels,
        )
        self.target_scaler = scaler_hub.target_scaler
        return scaler_hub.x, scaler_hub.y, scaler_hub.c
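
FlexibleDataset routes each variable to a named data source, so its data_cfgs needs two extra pieces: source_cfgs (source names and paths) and var_to_source_map. A schematic sketch; the source names and paths are hypothetical placeholders, and valid names must be keys of data_sources_dict:

    data_cfgs_extra = {
        "source_cfgs": {
            "source_names": ["camels_source", "era5_source"],      # hypothetical names
            "source_paths": ["/path/to/camels", "/path/to/era5"],  # hypothetical paths
        },
        "var_to_source_map": {
            "streamflow": "camels_source",                # a target_cols variable
            "total_precipitation_hourly": "era5_source",  # a relevant_cols variable
        },
    }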

FloodEventDataset (BaseDataset)

Dataset class for flood event detection and prediction tasks.

This dataset is specifically designed to handle flood event data where flood_event column contains binary indicators (0 for normal, non-zero for flood). It automatically creates a flood_mask from the flood_event data for special loss computation purposes.

The dataset reads data using SelfMadeHydroDataset from hydrodatasource, expecting CSV files with columns like: time, rain, inflow, flood_event.
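
A few illustrative (fabricated) rows of such a per-station CSV, following the column set described above:

    time,rain,inflow,flood_event
    2021-07-01 08:00:00,0.0,12.3,0
    2021-07-01 09:00:00,5.2,15.8,0
    2021-07-01 10:00:00,18.4,40.1,1
    2021-07-01 11:00:00,9.7,55.6,1
    2021-07-01 12:00:00,0.3,30.2,0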

Source code in torchhydro/datasets/data_sets.py
class FloodEventDataset(BaseDataset):
    """Dataset class for flood event detection and prediction tasks.

    This dataset is specifically designed to handle flood event data where
    flood_event column contains binary indicators (0 for normal, non-zero for flood).
    It automatically creates a flood_mask from the flood_event data for special
    loss computation purposes.

    The dataset reads data using SelfMadeHydroDataset from hydrodatasource,
    expecting CSV files with columns like: time, rain, inflow, flood_event.
    """

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        """Initialize FloodEventDataset

        Parameters
        ----------
        cfgs : dict
            Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
        is_tra_val_te : str
            One of 'train', 'valid', or 'test'
        """
        # Find flood_event column index for later processing
        target_cols = cfgs["data_cfgs"]["target_cols"]
        self.flood_event_idx = None
        for i, col in enumerate(target_cols):
            if "flood_event" in col.lower():
                self.flood_event_idx = i
                break

        if self.flood_event_idx is None:
            raise ValueError(
                "flood_event column not found in target_cols. Please ensure flood_event is included in the target columns."
            )
        super(FloodEventDataset, self).__init__(cfgs, is_tra_val_te)

    @property
    def noutputvar(self):
        """How many output variables in the dataset
        Used in evaluation.
        For flood datasets, the number of output variables is 2.
        But we don't need flood_mask in evaluation.

        Returns
        -------
        int
            number of variables
        """
        return len(self.data_cfgs["target_cols"]) - 1

    def _create_flood_mask(self, y):
        """Create flood mask from flood_event column

        Parameters
        ----------
        y : np.ndarray
            Target data with shape [seq_len, n_targets] containing flood_event column

        Returns
        -------
        np.ndarray
            Flood mask with shape [seq_len, 1] where 1 indicates flood event, 0 indicates normal
        """
        if self.flood_event_idx >= y.shape[1]:
            raise ValueError(
                f"flood_event_idx {self.flood_event_idx} exceeds target dimensions {y.shape[1]}"
            )

        # Extract flood_event column
        flood_events = y[:, self.flood_event_idx]

        # Create binary mask: values equal to the minimum are treated as 'no flood' (0), everything else as flood (1)
        no_flood_data = min(flood_events)
        flood_mask = (flood_events != no_flood_data).astype(np.float32)

        # Reshape to maintain dimension consistency
        flood_mask = flood_mask.reshape(-1, 1)

        return flood_mask

    def _create_lookup_table(self):
        """Create lookup table based on flood events with sliding window

        This method creates samples where:
        1. For each flood event sequence:
           - In training: use sliding window to generate samples with fixed length
           - In testing: use the entire flood event sequence as one sample with its actual length
        2. Each sample covers the full sequence length without internal structure division
        """
        lookup = []

        # Calculate total sample sequence length for training/validation
        sample_seqlen = self.warmup_length + self.rho + self.horizon

        for basin_idx in range(self.ngrid):
            # Get flood events for this basin
            flood_events = self.y_origin[basin_idx, :, self.flood_event_idx]

            # Find flood event sequences (consecutive non-zero values)
            flood_sequences = self._find_flood_sequences(flood_events)

            for seq_start, seq_end in flood_sequences:
                if self.is_new_batch_way:
                    # For test period, use the entire flood event sequence as one sample
                    # But we need to ensure the sample includes enough context (sample_seqlen)
                    flood_length = seq_end - seq_start + 1

                    # Calculate the start index to include enough context before the flood
                    # We want to include some data before the flood event starts
                    context_before = min(sample_seqlen - flood_length, seq_start)
                    context_before = max(context_before, 0)
                    # The actual start index should be early enough to provide context
                    actual_start = seq_start - context_before

                    # The total length should be at least sample_seqlen or the actual flood sequence length
                    total_length = max(sample_seqlen, flood_length + context_before)

                    # Ensure we don't exceed the data bounds
                    if actual_start + total_length > self.nt:
                        total_length = self.nt - actual_start

                    lookup.append((basin_idx, actual_start, total_length))
                else:
                    # For training, use sliding window approach
                    self._create_sliding_window_samples(
                        basin_idx, seq_start, seq_end, sample_seqlen, lookup
                    )

        self.lookup_table = dict(enumerate(lookup))
        self.num_samples = len(self.lookup_table)

    def _find_flood_sequences(self, flood_events):
        """Find sequences of consecutive flood events

        Parameters
        ----------
        flood_events : np.ndarray
            1D array of flood event indicators

        Returns
        -------
        list
            List of tuples (start_idx, end_idx) for each flood sequence
        """
        sequences = []
        in_sequence = False
        start_idx = None

        for i, event in enumerate(flood_events):
            if event > 0 and not in_sequence:
                # Start of a new flood sequence
                in_sequence = True
                start_idx = i
            elif event == 0 and in_sequence:
                # End of current flood sequence
                in_sequence = False
                sequences.append((start_idx, i - 1))

        # Handle a sequence that continues to (or starts at) the last timestep
        if in_sequence:
            sequences.append((start_idx, len(flood_events) - 1))

        return sequences
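    # Illustrative note (not part of the original class): for
    # flood_events = np.array([0, 0, 1, 1, 0, 1, 1]) this method returns
    # [(2, 3), (5, 6)].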

    def _create_sliding_window_samples(
        self, basin_idx, seq_start, seq_end, sample_seqlen, lookup
    ):
        """Create samples for a flood sequence using sliding window approach with data validity check

        Parameters
        ----------
        basin_idx : int
            Index of the basin
        seq_start : int
            Start index of flood sequence
        seq_end : int
            End index of flood sequence
        sample_seqlen : int
            Maximum length of each sample (warmup_length + rho + horizon)
        lookup : list
            List to append new samples to (basin_idx, actual_start, actual_length)
        """
        # Generate sliding window samples for this flood sequence
        # Each window should include at least some flood event data

        # Calculate the range where we can place the sliding window
        # The window end should not exceed the flood sequence end
        max_window_start = min(
            seq_end - sample_seqlen + 1, self.nt - sample_seqlen
        )  # Window end should not exceed seq_end or data bounds
        min_window_start = max(
            0, seq_start - sample_seqlen + 1
        )  # Window must include at least the first flood event

        # Ensure we have a valid range
        if max_window_start < min_window_start:
            return  # Skip this flood sequence if no valid window can be created

        # Generate samples with sliding window
        for window_start in range(min_window_start, max_window_start + 1):
            window_end = window_start + sample_seqlen - 1

            # Check if the window is valid (doesn't exceed data bounds and flood sequence)
            if window_end < self.nt and window_end <= seq_end:
                # Check if this window includes at least some flood events
                window_includes_flood = (window_start <= seq_end) and (
                    window_end >= seq_start
                )

                if window_includes_flood:
                    # Find the actual valid data range within this window closest to flood
                    actual_start, actual_length = self._find_valid_data_range(
                        basin_idx, window_start, window_end, seq_start, seq_end
                    )

                    # Only add sample if we have sufficient valid data
                    if (
                        actual_length >= self.rho + self.horizon
                    ):  # At least need rho + horizon
                        lookup.append((basin_idx, actual_start, actual_length))

    def _find_valid_data_range(
        self, basin_idx, window_start, window_end, flood_start, flood_end
    ):
        """Find the continuous valid data range closest to the flood sequence

        Parameters
        ----------
        basin_idx : int
            Basin index
        window_start : int
            Start of the window to check
        window_end : int
            End of the window to check
        flood_start : int
            Start index of the flood sequence
        flood_end : int
            End index of the flood sequence

        Returns
        -------
        tuple
            (actual_start, actual_length) of the valid data range closest to flood sequence
        """
        # Get data for this basin and window
        x_window = self.x[basin_idx, window_start : window_end + 1, :]

        # Check for NaN values in the input features: a timestep is valid only if
        # none of its features are NaN
        valid_mask = ~np.isnan(x_window).any(axis=1)

        # Find the continuous valid sequence closest to the flood sequence
        closest_start, closest_length = self._find_closest_valid_sequence(
            valid_mask, window_start, flood_start, flood_end
        )

        if closest_length <= 0:
            return window_start, 0
        return closest_start, closest_length

    def _find_closest_valid_sequence(
        self, valid_mask, window_start, flood_start, flood_end
    ):
        """Find the continuous valid sequence closest to the flood sequence

        Parameters
        ----------
        valid_mask : np.ndarray
            Boolean array indicating valid positions within the window
        window_start : int
            Start index of the window in the original time series
        flood_start : int
            Start index of the flood sequence in the original time series
        flood_end : int
            End index of the flood sequence in the original time series

        Returns
        -------
        tuple
            (closest_start, closest_length) in original time series coordinates
        """
        if not valid_mask.any():
            return window_start, 0

        # Find all continuous valid sequences within the window
        sequences = []
        current_start = None

        for i, is_valid in enumerate(valid_mask):
            if is_valid and current_start is None:
                current_start = i
            elif not is_valid and current_start is not None:
                sequences.append((current_start, i - current_start))
                current_start = None

        # Handle case where sequence continues to the end
        if current_start is not None:
            sequences.append((current_start, len(valid_mask) - current_start))

        if not sequences:
            return window_start, 0

        # If only one sequence, return it directly
        if len(sequences) == 1:
            seq_start_rel, seq_length = sequences[0]
            seq_start_abs = window_start + seq_start_rel
            return seq_start_abs, seq_length

        # Find the sequence closest to the flood sequence
        flood_center = (flood_start + flood_end) / 2
        closest_sequence = None
        min_distance = float("inf")

        for seq_start_rel, seq_length in sequences:
            seq_start_abs = window_start + seq_start_rel
            seq_end_abs = seq_start_abs + seq_length - 1
            seq_center = (seq_start_abs + seq_end_abs) / 2

            # Calculate distance from sequence center to flood center
            distance = abs(seq_center - flood_center)

            if distance < min_distance:
                min_distance = distance
                closest_sequence = (seq_start_abs, seq_length)

        return closest_sequence or (window_start, 0)

    def __getitem__(self, item: int):
        """Get one sample from the dataset with flood mask

        Returns samples with:
        1. Variable length sequences (no padding)
        2. Flood mask for weighted loss computation
        """
        basin, start_idx, actual_length = self.lookup_table[item]
        warmup_length = self.warmup_length
        end_idx = start_idx + actual_length

        # Get input and target data for the actual valid range
        x = self.x[basin, start_idx:end_idx, :]
        y = self.y[basin, start_idx + warmup_length : end_idx, :]

        # Create flood mask from flood_event column
        flood_mask = self._create_flood_mask(y)

        # Replace the original flood_event column with the new flood_mask
        y_with_flood_mask = y.copy()
        y_with_flood_mask[:, self.flood_event_idx] = flood_mask.squeeze()

        # Handle constant features if available
        if self.c is None or self.c.shape[-1] == 0:
            return (
                torch.from_numpy(x).float(),
                torch.from_numpy(y_with_flood_mask).float(),
            )

        # Add constant features to input
        c = self.c[basin, :]
        c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
        xc = np.concatenate((x, c), axis=1)

        return torch.from_numpy(xc).float(), torch.from_numpy(y_with_flood_mask).float()

noutputvar property readonly

Number of output variables in the dataset; used in evaluation. For flood datasets, the number of output variables is 2, but flood_mask is not needed in evaluation.

Returns

int number of variables

__getitem__(self, item) special

Get one sample from the dataset with flood mask

Returns samples with:

1. Variable length sequences (no padding)
2. Flood mask for weighted loss computation

Source code in torchhydro/datasets/data_sets.py
def __getitem__(self, item: int):
    """Get one sample from the dataset with flood mask

    Returns samples with:
    1. Variable length sequences (no padding)
    2. Flood mask for weighted loss computation
    """
    basin, start_idx, actual_length = self.lookup_table[item]
    warmup_length = self.warmup_length
    end_idx = start_idx + actual_length

    # Get input and target data for the actual valid range
    x = self.x[basin, start_idx:end_idx, :]
    y = self.y[basin, start_idx + warmup_length : end_idx, :]

    # Create flood mask from flood_event column
    flood_mask = self._create_flood_mask(y)

    # Replace the original flood_event column with the new flood_mask
    y_with_flood_mask = y.copy()
    y_with_flood_mask[:, self.flood_event_idx] = flood_mask.squeeze()

    # Handle constant features if available
    if self.c is None or self.c.shape[-1] == 0:
        return (
            torch.from_numpy(x).float(),
            torch.from_numpy(y_with_flood_mask).float(),
        )

    # Add constant features to input
    c = self.c[basin, :]
    c = np.repeat(c, x.shape[0], axis=0).reshape(c.shape[0], -1).T
    xc = np.concatenate((x, c), axis=1)

    return torch.from_numpy(xc).float(), torch.from_numpy(y_with_flood_mask).float()
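Because samples have variable lengths, the default PyTorch DataLoader collation cannot stack them when batch_size > 1. A minimal sketch of a padding collate function, assuming plain PyTorch and a padding value of 0.0 (this helper is illustrative and not part of torchhydro):

import torch
from torch.nn.utils.rnn import pad_sequence


def flood_collate(batch):
    """Pad variable-length (x, y) samples to the longest sequence in the batch."""
    xs, ys = zip(*batch)
    lengths = torch.tensor([x.shape[0] for x in xs])  # keep true lengths for masking later
    x_padded = pad_sequence(list(xs), batch_first=True, padding_value=0.0)
    y_padded = pad_sequence(list(ys), batch_first=True, padding_value=0.0)
    return x_padded, y_padded, lengths


# Usage sketch (dataset is an already-constructed FloodEventDataset):
# loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=flood_collate)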

__init__(self, cfgs, is_tra_val_te) special

Initialize FloodEventDataset

Parameters

cfgs : dict
    Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
is_tra_val_te : str
    One of 'train', 'valid', or 'test'

Source code in torchhydro/datasets/data_sets.py
def __init__(self, cfgs: dict, is_tra_val_te: str):
    """Initialize FloodEventDataset

    Parameters
    ----------
    cfgs : dict
        Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
    is_tra_val_te : str
        One of 'train', 'valid', or 'test'
    """
    # Find flood_event column index for later processing
    target_cols = cfgs["data_cfgs"]["target_cols"]
    self.flood_event_idx = None
    for i, col in enumerate(target_cols):
        if "flood_event" in col.lower():
            self.flood_event_idx = i
            break

    if self.flood_event_idx is None:
        raise ValueError(
            "flood_event column not found in target_cols. Please ensure flood_event is included in the target columns."
        )
    super(FloodEventDataset, self).__init__(cfgs, is_tra_val_te)
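For orientation, a minimal sketch of the relevant configuration fragment; only the requirement that one target column name contains "flood_event" comes from the code above, and the other column name is a placeholder:

# Hypothetical excerpt of cfgs["data_cfgs"]
data_cfgs = {
    "target_cols": ["streamflow", "flood_event"],  # one name must contain "flood_event"
    # ... other data_cfgs keys ...
}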

FloodEventDplDataset (FloodEventDataset)

Dataset class for flood event detection and prediction with differential parameter learning support.

This dataset combines FloodEventDataset's flood event handling capabilities with DplDataset's data format for differential parameter learning (dPL) models. It handles flood event sequences and returns data in the format required for physical hydrological models with neural network components.

Source code in torchhydro/datasets/data_sets.py
class FloodEventDplDataset(FloodEventDataset):
    """Dataset class for flood event detection and prediction with differential parameter learning support.

    This dataset combines FloodEventDataset's flood event handling capabilities with
    DplDataset's data format for differential parameter learning (dPL) models.
    It handles flood event sequences and returns data in the format required for
    physical hydrological models with neural network components.
    """

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        """Initialize FloodEventDplDataset

        Parameters
        ----------
        cfgs : dict
            Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
        is_tra_val_te : str
            One of 'train', 'valid', or 'test'
        """
        super(FloodEventDplDataset, self).__init__(cfgs, is_tra_val_te)

        # Additional attributes for DPL functionality
        self.target_as_input = self.data_cfgs["target_as_input"]
        self.constant_only = self.data_cfgs["constant_only"]

        if self.target_as_input and (not self.train_mode):
            # if the target is used as input and train_mode is False,
            # we need to get the target data in training period to generate pbm params
            self.train_dataset = FloodEventDplDataset(cfgs, is_tra_val_te="train")

    def __getitem__(self, item: int):
        """Get one sample from the dataset in DPL format with flood mask

        Returns data in the format required for differential parameter learning:
        - x_train: not normalized forcing data
        - z_train: normalized data for DL model (with flood mask)
        - y_train: not normalized output data

        Parameters
        ----------
        item : int
            Index of the sample

        Returns
        -------
        tuple
            ((x_train, z_train), y_train) where:
            - x_train: torch.Tensor, not normalized forcing data
            - z_train: torch.Tensor, normalized data for DL model
            - y_train: torch.Tensor, not normalized output data with flood mask
        """
        basin, start_idx, actual_length = self.lookup_table[item]
        end_idx = start_idx + actual_length
        warmup_length = self.warmup_length
        # Get normalized data first (using parent's logic for flood mask)
        xc_norm, y_norm_with_mask = super(FloodEventDplDataset, self).__getitem__(item)

        # Get original (not normalized) data
        x_origin = self.x_origin[basin, start_idx:end_idx, :]
        y_origin = self.y_origin[basin, start_idx + warmup_length : end_idx, :]

        # Create flood mask for original y data
        flood_mask_origin = self._create_flood_mask(y_origin)
        y_origin_with_mask = y_origin.copy()
        y_origin_with_mask[:, self.flood_event_idx] = flood_mask_origin.squeeze()

        # Prepare z_train based on configuration
        if self.target_as_input:
            # y_norm and xc_norm are concatenated and used for DL model
            # the order of xc_norm and y_norm matters, please be careful!
            z_train = torch.cat((xc_norm, y_norm_with_mask), -1)
        elif self.constant_only:
            # only use attributes data for DL model
            if self.c is None or self.c.shape[-1] == 0:
                # If no constant features, use a zero tensor
                z_train = torch.zeros((actual_length, 1)).float()
            else:
                c = self.c[basin, :]
                # Repeat constants for the actual sequence length
                c_repeated = (
                    np.repeat(c, actual_length, axis=0).reshape(c.shape[0], -1).T
                )
                z_train = torch.from_numpy(c_repeated).float()
        else:
            # Use normalized input features with constants
            z_train = xc_norm.float()

        # Prepare x_train (original forcing data with constants if available)
        if self.c is None or self.c.shape[-1] == 0:
            x_train = torch.from_numpy(x_origin).float()
        else:
            c = self.c_origin[basin, :]
            c_repeated = np.repeat(c, actual_length, axis=0).reshape(c.shape[0], -1).T
            x_origin_with_c = np.concatenate((x_origin, c_repeated), axis=1)
            x_train = torch.from_numpy(x_origin_with_c).float()

        # y_train is the original output data with flood mask
        y_train = torch.from_numpy(y_origin_with_mask).float()

        return (x_train, z_train), y_train

__getitem__(self, item) special

Get one sample from the dataset in DPL format with flood mask

Returns data in the format required for differential parameter learning:

- x_train: not normalized forcing data
- z_train: normalized data for DL model (with flood mask)
- y_train: not normalized output data

Parameters

item : int
    Index of the sample

Returns

tuple
    ((x_train, z_train), y_train) where:
    - x_train: torch.Tensor, not normalized forcing data
    - z_train: torch.Tensor, normalized data for DL model
    - y_train: torch.Tensor, not normalized output data with flood mask

Source code in torchhydro/datasets/data_sets.py
def __getitem__(self, item: int):
    """Get one sample from the dataset in DPL format with flood mask

    Returns data in the format required for differential parameter learning:
    - x_train: not normalized forcing data
    - z_train: normalized data for DL model (with flood mask)
    - y_train: not normalized output data

    Parameters
    ----------
    item : int
        Index of the sample

    Returns
    -------
    tuple
        ((x_train, z_train), y_train) where:
        - x_train: torch.Tensor, not normalized forcing data
        - z_train: torch.Tensor, normalized data for DL model
        - y_train: torch.Tensor, not normalized output data with flood mask
    """
    basin, start_idx, actual_length = self.lookup_table[item]
    end_idx = start_idx + actual_length
    warmup_length = self.warmup_length
    # Get normalized data first (using parent's logic for flood mask)
    xc_norm, y_norm_with_mask = super(FloodEventDplDataset, self).__getitem__(item)

    # Get original (not normalized) data
    x_origin = self.x_origin[basin, start_idx:end_idx, :]
    y_origin = self.y_origin[basin, start_idx + warmup_length : end_idx, :]

    # Create flood mask for original y data
    flood_mask_origin = self._create_flood_mask(y_origin)
    y_origin_with_mask = y_origin.copy()
    y_origin_with_mask[:, self.flood_event_idx] = flood_mask_origin.squeeze()

    # Prepare z_train based on configuration
    if self.target_as_input:
        # y_norm and xc_norm are concatenated and used for DL model
        # the order of xc_norm and y_norm matters, please be careful!
        z_train = torch.cat((xc_norm, y_norm_with_mask), -1)
    elif self.constant_only:
        # only use attributes data for DL model
        if self.c is None or self.c.shape[-1] == 0:
            # If no constant features, use a zero tensor
            z_train = torch.zeros((actual_length, 1)).float()
        else:
            c = self.c[basin, :]
            # Repeat constants for the actual sequence length
            c_repeated = (
                np.repeat(c, actual_length, axis=0).reshape(c.shape[0], -1).T
            )
            z_train = torch.from_numpy(c_repeated).float()
    else:
        # Use normalized input features with constants
        z_train = xc_norm.float()

    # Prepare x_train (original forcing data with constants if available)
    if self.c is None or self.c.shape[-1] == 0:
        x_train = torch.from_numpy(x_origin).float()
    else:
        c = self.c_origin[basin, :]
        c_repeated = np.repeat(c, actual_length, axis=0).reshape(c.shape[0], -1).T
        x_origin_with_c = np.concatenate((x_origin, c_repeated), axis=1)
        x_train = torch.from_numpy(x_origin_with_c).float()

    # y_train is the original output data with flood mask
    y_train = torch.from_numpy(y_origin_with_mask).float()

    return (x_train, z_train), y_train
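The flood mask carried in the flood_event column of y_train is intended for weighted loss computation. A minimal, self-contained sketch of such a loss; the function name and the weighting scheme are assumptions, not torchhydro's loss API:

import torch


def flood_weighted_mse(pred, target, flood_mask, flood_weight=5.0):
    """MSE where timesteps flagged as flood events get a larger weight.

    pred, target: [seq_len, n_vars]; flood_mask: [seq_len] with 0/1 entries.
    """
    weights = 1.0 + (flood_weight - 1.0) * flood_mask  # 1 for normal steps, flood_weight for floods
    squared_error = (pred - target) ** 2
    return (weights.unsqueeze(-1) * squared_error).mean()


# Tiny usage example with dummy tensors
pred = torch.zeros(4, 1)
target = torch.tensor([[0.1], [0.2], [2.0], [1.5]])
mask = torch.tensor([0.0, 0.0, 1.0, 1.0])
loss = flood_weighted_mse(pred, target, mask)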

__init__(self, cfgs, is_tra_val_te) special

Initialize FloodEventDplDataset

Parameters

cfgs : dict
    Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
is_tra_val_te : str
    One of 'train', 'valid', or 'test'

Source code in torchhydro/datasets/data_sets.py
def __init__(self, cfgs: dict, is_tra_val_te: str):
    """Initialize FloodEventDplDataset

    Parameters
    ----------
    cfgs : dict
        Configuration dictionary containing data_cfgs, training_cfgs, evaluation_cfgs
    is_tra_val_te : str
        One of 'train', 'valid', or 'test'
    """
    super(FloodEventDplDataset, self).__init__(cfgs, is_tra_val_te)

    # Additional attributes for DPL functionality
    self.target_as_input = self.data_cfgs["target_as_input"]
    self.constant_only = self.data_cfgs["constant_only"]

    if self.target_as_input and (not self.train_mode):
        # if the target is used as input and train_mode is False,
        # we need to get the target data in training period to generate pbm params
        self.train_dataset = FloodEventDplDataset(cfgs, is_tra_val_te="train")

GNNDataset (FloodEventDataset)

Optimized GNN Dataset for hydrological Graph Neural Network tasks.

This dataset extends FloodEventDataset to support Graph Neural Networks by:

1. Integrating station data via StationHydroDataset
2. Processing adjacency matrices with flexible edge weight and attribute handling
3. Merging basin-level features (xc) with station-level features (sxc) per node
4. Returning GNN-ready format: (sxc, y, edge_index, edge_attr)

Key Features:

- Leverages BaseDataset's universal normalization and NaN handling for station data
- Supports flexible edge weight selection (specify a column or default to binary weights)
- Always constructs edge_index and edge_attr for each basin
- Merges basin and station features to create comprehensive node representations

Configuration keys in data_cfgs.gnn_cfgs:

- station_cols: List of station variable names to load
- station_rm_nan: Whether to remove/interpolate NaN values (default: True)
- station_scaler_type: Scaler type for station data normalization
- use_adjacency: Whether to load adjacency matrices (default: True)
- adjacency_src_col: Source node column name (default: "ID")
- adjacency_dst_col: Destination node column name (default: "NEXTDOWNID")
- adjacency_edge_attr_cols: Columns for edge attributes (default: ["dist_hdn", "elev_diff", "strm_slope"])
- adjacency_weight_col: Column to use as edge weights (default: None for binary weights)
- return_edge_weight: Whether to return edge_weight instead of edge_attr (default: False)

edge_attr : torch.Tensor Edge attributes [num_edges, edge_attr_dim]
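An illustrative sketch of this configuration block; note that the code below reads these keys from data_cfgs["station_cfgs"], and all values here are placeholders:

# Hypothetical data_cfgs["station_cfgs"] block; variable and column names are placeholders
station_cfgs = {
    "station_cols": ["water_level", "inflow"],
    "station_rm_nan": True,
    "use_adjacency": True,
    "adjacency_src_col": "ID",
    "adjacency_dst_col": "NEXTDOWNID",
    "adjacency_edge_attr_cols": ["dist_hdn", "elev_diff", "strm_slope"],
    "adjacency_weight_col": "dist_hdn",   # or None for binary weights
    "edge_orientation": "downstream",     # "upstream" | "downstream" | "bidirectional"
}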

Source code in torchhydro/datasets/data_sets.py
class GNNDataset(FloodEventDataset):
    """Optimized GNN Dataset for hydrological Graph Neural Network tasks.

    This dataset extends FloodEventDataset to support Graph Neural Networks by:
    1. Integrating station data via StationHydroDataset
    2. Processing adjacency matrices with flexible edge weight and attribute handling
    3. Merging basin-level features (xc) with station-level features (sxc) per node
    4. Returning GNN-ready format: (sxc, y, edge_index, edge_attr)

    Key Features:
    - Leverages BaseDataset's universal normalization and NaN handling for station data
    - Supports flexible edge weight selection (specify column or default to binary)
    - Always constructs edge_index and edge_attr for each basin
    - Merges basin and station features to create comprehensive node representations

    Configuration keys in data_cfgs.gnn_cfgs:
    - station_cols: List of station variable names to load
    - station_rm_nan: Whether to remove/interpolate NaN values (default: True)
    - station_scaler_type: Scaler type for station data normalization
    - use_adjacency: Whether to load adjacency matrices (default: True)
    - adjacency_src_col: Source node column name (default: "ID")
    - adjacency_dst_col: Destination node column name (default: "NEXTDOWNID")
    - adjacency_edge_attr_cols: Columns for edge attributes (default: ["dist_hdn", "elev_diff", "strm_slope"])
    - adjacency_weight_col: Column to use as edge weights (default: None for binary weights)
    - return_edge_weight: Whether to return edge_weight instead of edge_attr (default: False)

    edge_attr : torch.Tensor
        Edge attributes [num_edges, edge_attr_dim]
    """

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        # Extract and extend configuration for station data
        self._extend_data_cfgs_for_stations(cfgs)

        # Store GNN-specific settings
        self.gnn_cfgs = cfgs["data_cfgs"].get("station_cfgs", {})

        # Initialize parent (this will call BaseDataset._load_data() automatically)
        super(GNNDataset, self).__init__(cfgs, is_tra_val_te)

        # Load adjacency data after main data processing
        self.adjacency_data = self._load_adjacency_data()

    def _extend_data_cfgs_for_stations(self, cfgs):
        """Extend data configuration to include station data as a standard data type

        This allows BaseDataset to handle station data using its universal processing pipeline.
        """
        data_cfgs = cfgs["data_cfgs"]
        gnn_cfgs = data_cfgs.get("station_cfgs", {})

        # Add station_cols to the data configuration if specified (this does not strictly
        # require gnn_cfgs; normally these keys would simply be added to data_cfgs)
        if gnn_cfgs.get("station_cols"):
            data_cfgs["station_cols"] = gnn_cfgs["station_cols"]
            # Add station data processing settings to leverage BaseDataset pipeline
            data_cfgs["station_rm_nan"] = gnn_cfgs.get("station_rm_nan", True)

    def _read_xyc(self):
        """Read X, Y, C data including station data using unified approach

        This is the ONLY method we need to override from BaseDataset.
        All other processing (normalization, NaN handling, array conversion)
        is handled automatically by BaseDataset's pipeline.
        """
        # Read standard basin data using parent's logic
        data_dict = super(GNNDataset, self)._read_xyc()

        # Add station data if configured
        if self.data_cfgs.get("station_cols"):
            station_data = self._read_all_station_data()
            data_dict["station_cols"] = station_data

        return data_dict

    def _read_all_station_data(self):
        """Read station data for all basins using StationHydroDataset

        Creates xr.DataArray with the same structure as other data types
        so that BaseDataset can process it using the universal pipeline.
        """
        if not hasattr(self.data_source, "get_stations_by_basin"):
            LOGGER.warning(
                "Data source does not support station data, skipping station data reading"
            )
            return None

        # Convert basin IDs from "songliao_21100150" to "21100150" for StationHydroDataset
        basin_ids_with_prefix = self.t_s_dict["sites_id"]
        basin_ids = self._convert_basin_to_station_ids(basin_ids_with_prefix)
        t_range = self.t_s_dict["t_final_range"]

        # Collect station data for all basins
        all_station_data = []

        for basin_id in basin_ids:
            basin_station_data = self._read_basin_station_data(basin_id, t_range)
            all_station_data.append(basin_station_data)

        # Combine into unified xr.DataArray structure
        if all_station_data and any(data is not None for data in all_station_data):
            combined_station_data = self._combine_station_data_arrays(
                all_station_data, basin_ids
            )
            return combined_station_data
        else:
            return None

    def _read_basin_station_data(self, basin_id, t_range):
        """Read station data for a single basin, supporting multi-period case"""
        try:
            # Get stations for this basin
            station_ids = self.data_source.get_stations_by_basin(basin_id)

            if not station_ids:
                return None

            # Handle multi-period case
            if isinstance(t_range[0], (list, tuple)):
                # Validate that each period has exactly 2 elements (start and end date)
                for i, period in enumerate(t_range):
                    if not isinstance(period, (list, tuple)) or len(period) != 2:
                        raise ValueError(
                            f"Period {i} must be a list/tuple with exactly 2 elements (start_date, end_date), got: {period}"
                        )

                # Multi-period case - read and concatenate data
                all_station_data = None

                for start_date, end_date in t_range:
                    period_station_data = self.data_source.read_station_ts_xrdataset(
                        station_id_lst=station_ids,
                        t_range=[start_date, end_date],
                        var_lst=self.data_cfgs["station_cols"],
                        time_units=self.gnn_cfgs.get("station_time_units", ["1D"]),
                    )

                    if all_station_data is None:
                        all_station_data = period_station_data
                    else:
                        all_station_data = xr.concat(
                            [all_station_data, period_station_data], dim="time"
                        )

                station_data = all_station_data
            else:
                # Single period case (existing behavior)
                station_data = self.data_source.read_station_ts_xrdataset(
                    station_id_lst=station_ids,
                    t_range=t_range,
                    var_lst=self.data_cfgs["station_cols"],
                    time_units=self.gnn_cfgs.get("station_time_units", ["1D"]),
                )

            return self._process_station_xr_data(station_data)

        except Exception as e:
            LOGGER.warning(f"Could not read station data for basin {basin_id}: {e}")
            return None

    def _process_station_xr_data(self, station_data):
        """Process xarray station data into standard format"""
        if not station_data:
            return None

        # Handle multiple time units
        if isinstance(station_data, dict):
            # Use first available time unit
            time_unit = list(station_data.keys())[0]
            station_ds = station_data[time_unit]
        else:
            station_ds = station_data

        if not station_ds or not station_ds.sizes:
            return None

        # Convert to DataArray with standard format
        if isinstance(station_ds, xr.Dataset):
            station_da = station_ds.to_array(dim="variable")
            # Transpose to [time, station, variable]
            station_da = station_da.transpose("time", "station", "variable")
        else:
            station_da = station_ds

        return station_da

    def _combine_station_data_arrays(self, station_data_list, basin_ids):
        """Combine station data from all basins into a unified structure

        Creates an xr.DataArray with dimensions [basin, time, station, variable]
        similar to how other data types are structured in BaseDataset.
        """
        # Find common time dimension and data structure
        valid_data = [data for data in station_data_list if data is not None]
        if not valid_data:
            return None

        # Use time dimension from first valid dataset
        common_time = valid_data[0].coords["time"]

        # Find maximum number of stations and variables across all basins
        max_stations = max(data.sizes.get("station", 0) for data in valid_data)
        max_variables = max(data.sizes.get("variable", 0) for data in valid_data)

        # Create unified data array
        n_basins = len(basin_ids)
        n_time = len(common_time)

        # Initialize with NaN (BaseDataset will handle NaN processing)
        unified_data = np.full(
            (n_basins, n_time, max_stations, max_variables), np.nan, dtype=np.float32
        )

        # Fill with actual data
        for i, (basin_id, station_data) in enumerate(zip(basin_ids, station_data_list)):
            if station_data is not None:
                # Align time dimension
                try:
                    aligned_data = station_data.reindex(
                        time=common_time, method="nearest"
                    )
                    data_array = aligned_data.values

                    # Insert into unified array
                    n_stations_basin = data_array.shape[1]
                    n_vars_basin = data_array.shape[2]
                    unified_data[i, :, :n_stations_basin, :n_vars_basin] = data_array
                except Exception as e:
                    LOGGER.warning(
                        f"Failed to align station data for basin {basin_id}: {e}"
                    )
                    continue

        # Create xr.DataArray with proper coordinates
        station_coords = [f"station_{j}" for j in range(max_stations)]
        variable_coords = self.data_cfgs["station_cols"][:max_variables]

        station_da = xr.DataArray(
            unified_data,
            dims=["basin", "time", "station", "variable"],
            coords={
                "basin": basin_ids,
                "time": common_time,
                "station": station_coords,
                "variable": variable_coords,
            },
        )

        return station_da

    def _load_adjacency_data(self):
        """Load and process adjacency data from .nc files

        Returns
        -------
        dict
            Dictionary containing edge_index, edge_attr for each basin
        """
        if not self.gnn_cfgs.get("use_adjacency", True):
            return None

        if not hasattr(self.data_source, "read_adjacency_xrdataset"):
            LOGGER.warning("Data source does not support adjacency data")
            return None

        adjacency_data = {}
        # basin_ids = self.t_s_dict["sites_id"]
        # Convert basin IDs from "songliao_21100150" to "21100150" for StationHydroDataset
        basin_ids_with_prefix = self.t_s_dict["sites_id"]
        basin_ids = self._convert_basin_to_station_ids(basin_ids_with_prefix)

        for basin_id in basin_ids:
            try:
                # Read adjacency data from .nc file
                adj_df = self.data_source.read_adjacency_xrdataset(basin_id)

                if adj_df is None:
                    LOGGER.warning(
                        f"No adjacency data for basin {basin_id}, using self-loops"
                    )
                    adjacency_data[basin_id] = self._create_self_loop_adjacency(
                        basin_id
                    )
                else:
                    # Let _process_adjacency_dataframe handle the format checking and processing
                    adjacency_data[basin_id] = self._process_adjacency_dataframe(
                        adj_df, basin_id
                    )

            except Exception as e:
                LOGGER.warning(
                    f"Failed to load adjacency data for basin {basin_id}: {e}"
                )
                adjacency_data[basin_id] = self._create_self_loop_adjacency(basin_id)

        return adjacency_data

    def _process_adjacency_dataframe(self, adj_df, basin_id):
        """Process adjacency DataFrame into edge_index and edge_attr tensors

        Standard GNN processing: extract edges and their attributes from DataFrame or xarray Dataset.

        Parameters
        ----------
        adj_df : pd.DataFrame or xr.Dataset
            Adjacency DataFrame/Dataset with columns like ID, NEXTDOWNID, dist_hdn, elev_diff, strm_slope
        basin_id : str
            Basin identifier

        Returns
        -------
        dict
            Dictionary containing edge_index, edge_attr, edge_weight, num_nodes
        """
        import torch
        import pandas as pd
        import xarray as xr
        import numpy as np

        # Convert xarray Dataset to pandas DataFrame if needed
        if isinstance(adj_df, xr.Dataset):
            try:
                # Convert xarray Dataset to pandas DataFrame
                adj_df = adj_df.to_dataframe().reset_index()
                # LOGGER.info(f"Basin {basin_id}: Converted xarray Dataset to DataFrame with shape {adj_df.shape}")
                # LOGGER.info(f"Basin {basin_id}: DataFrame columns = {list(adj_df.columns)}")
            except Exception as e:
                LOGGER.error(
                    f"Basin {basin_id}: Failed to convert xarray Dataset to DataFrame: {e}"
                )
                return self._create_self_loop_adjacency(basin_id)

        # Configuration (simplified)
        src_col = self.gnn_cfgs.get("adjacency_src_col", "ID")
        dst_col = self.gnn_cfgs.get("adjacency_dst_col", "NEXTDOWNID")
        edge_attr_cols = self.gnn_cfgs.get(
            "adjacency_edge_attr_cols", ["dist_hdn", "elev_diff", "strm_slope"]
        )
        weight_col = self.gnn_cfgs.get("adjacency_weight_col", None)  # column to use as edge weights (optional)
        # Check if required columns exist
        if src_col not in adj_df.columns:
            LOGGER.warning(
                f"Basin {basin_id}: Source column '{src_col}' not found in adjacency data. Available columns: {list(adj_df.columns)}"
            )
            return self._create_self_loop_adjacency(basin_id)

        if dst_col not in adj_df.columns:
            LOGGER.warning(
                f"Basin {basin_id}: Destination column '{dst_col}' not found in adjacency data. Available columns: {list(adj_df.columns)}"
            )
            return self._create_self_loop_adjacency(basin_id)

        # Clean and convert numeric columns to proper dtypes in batch
        # Handle string "nan" values that may come from NetCDF files
        numeric_cols = [
            col
            for col in edge_attr_cols + ([weight_col] if weight_col else [])
            if col in adj_df.columns
        ]
        if numeric_cols:
            # Batch replace string "nan" with actual NaN and convert to numeric
            adj_df[numeric_cols] = adj_df[numeric_cols].replace(
                ["nan", "NaN", "NAN"], np.nan
            )
            adj_df[numeric_cols] = adj_df[numeric_cols].apply(
                pd.to_numeric, errors="coerce"
            )
            LOGGER.debug(
                f"Basin {basin_id}: Converted {len(numeric_cols)} numeric columns in batch"
            )

        # Create comprehensive node mapping including all stations in the basin
        # First get all nodes that appear in adjacency matrix (connected nodes)
        connected_nodes = set(adj_df[src_col].dropna()) | set(adj_df[dst_col].dropna())

        # Then get all stations in this basin (including isolated nodes)
        try:
            if hasattr(self.data_source, "get_stations_by_basin"):
                all_basin_stations = self.data_source.get_stations_by_basin(basin_id)
                if all_basin_stations:
                    # Convert station IDs to strings to match adjacency data format
                    all_basin_nodes = set(
                        str(station_id) for station_id in all_basin_stations
                    )
                    # Combine connected nodes with all basin nodes
                    all_nodes = connected_nodes | all_basin_nodes
                    isolated_nodes = all_basin_nodes - connected_nodes
                    if isolated_nodes:
                        LOGGER.info(
                            f"Basin {basin_id}: Found {len(isolated_nodes)} isolated nodes: {isolated_nodes}"
                        )
                else:
                    all_nodes = connected_nodes
            else:
                # Fallback to only connected nodes if station data unavailable
                all_nodes = connected_nodes
        except Exception as e:
            LOGGER.warning(
                f"Basin {basin_id}: Failed to get all basin stations: {e}, using connected nodes only"
            )
            all_nodes = connected_nodes

        if len(all_nodes) == 0:
            LOGGER.warning(f"Basin {basin_id}: No valid nodes found")
            return self._create_self_loop_adjacency(basin_id)

        node_to_idx = {node: idx for idx, node in enumerate(sorted(all_nodes))}
        LOGGER.info(
            f"Basin {basin_id}: Found {len(all_nodes)} total nodes ({len(connected_nodes)} connected, {len(all_nodes) - len(connected_nodes)} isolated)"
        )

        # Extract edges and attributes using vectorized operations
        # First process edges from adjacency matrix
        valid_rows = adj_df.dropna(subset=[src_col, dst_col])
        edges_from_adj = []
        edge_attrs_from_adj = []
        edge_weights_from_adj = []

        if len(valid_rows) > 0:
            # Vectorized edge creation from adjacency matrix
            src_nodes = valid_rows[src_col].map(node_to_idx).values
            dst_nodes = valid_rows[dst_col].map(node_to_idx).values
            edges_from_adj = np.column_stack([src_nodes, dst_nodes])

            # Vectorized edge attributes extraction
            edge_attrs_list = []
            for col in edge_attr_cols:
                if col in valid_rows.columns:
                    attrs = valid_rows[col].fillna(0.0).values
                else:
                    attrs = np.zeros(len(valid_rows))
                edge_attrs_list.append(attrs)
            edge_attrs_from_adj = (
                np.column_stack(edge_attrs_list)
                if edge_attrs_list
                else np.zeros((len(valid_rows), len(edge_attr_cols)))
            )

            # Vectorized edge weights extraction
            if weight_col and weight_col in valid_rows.columns:
                edge_weights_from_adj = valid_rows[weight_col].fillna(1.0).values
            else:
                edge_weights_from_adj = np.ones(len(valid_rows))

        # Add self-loops for isolated nodes (nodes not in adjacency matrix)
        isolated_nodes = all_nodes - connected_nodes
        edges_from_isolated = []
        edge_attrs_from_isolated = []
        edge_weights_from_isolated = []

        if isolated_nodes:
            # Create self-loops for isolated nodes
            isolated_indices = [node_to_idx[node] for node in isolated_nodes]
            edges_from_isolated = np.column_stack([isolated_indices, isolated_indices])
            edge_attrs_from_isolated = np.zeros(
                (len(isolated_nodes), len(edge_attr_cols))
            )
            edge_weights_from_isolated = np.ones(len(isolated_nodes))

        # Combine edges from adjacency matrix and self-loops for isolated nodes
        if len(edges_from_adj) > 0 and len(edges_from_isolated) > 0:
            all_edges = np.vstack([edges_from_adj, edges_from_isolated])
            all_edge_attrs = np.vstack([edge_attrs_from_adj, edge_attrs_from_isolated])
            all_edge_weights = np.concatenate(
                [edge_weights_from_adj, edge_weights_from_isolated]
            )
        elif len(edges_from_adj) > 0:
            all_edges = edges_from_adj
            all_edge_attrs = edge_attrs_from_adj
            all_edge_weights = edge_weights_from_adj
        elif len(edges_from_isolated) > 0:
            all_edges = edges_from_isolated
            all_edge_attrs = edge_attrs_from_isolated
            all_edge_weights = edge_weights_from_isolated
        else:
            # Fallback: create self-loops for all nodes
            # LOGGER.warning(f"Basin {basin_id}: No edges found, creating self-loops for all nodes")
            n_nodes = len(all_nodes)
            node_indices = list(range(n_nodes))
            all_edges = np.column_stack([node_indices, node_indices])
            all_edge_attrs = np.zeros((n_nodes, len(edge_attr_cols)))
            all_edge_weights = np.ones(n_nodes)

        # Convert to tensors
        edge_index = torch.tensor(all_edges.T, dtype=torch.long).contiguous()
        edge_attr = (
            torch.tensor(all_edge_attrs, dtype=torch.float)
            if all_edge_attrs is not None
            else None
        )
        edge_weight = torch.tensor(all_edge_weights, dtype=torch.float)

        return {
            "edge_index": edge_index,
            "edge_attr": edge_attr,
            "edge_weight": edge_weight,  # 新增:单独的边权重张量
            "num_nodes": len(all_nodes),
            "node_to_idx": node_to_idx,
            "weight_col": weight_col,  # 记录使用的权重列
        }

    def _create_self_loop_adjacency(self, basin_id):
        """Create self-loop adjacency as fallback"""
        import torch

        try:
            # Try to get station count for this basin
            if hasattr(self.data_source, "get_stations_by_basin"):
                station_ids = self.data_source.get_stations_by_basin(basin_id)
                n_nodes = len(station_ids) if station_ids else 1
            else:
                n_nodes = 1
        except Exception:
            n_nodes = 1

        # Create self-loops: edge_index = [[0,1,2,...], [0,1,2,...]]
        edge_index = torch.arange(n_nodes).repeat(2, 1)

        # Create default edge attributes
        edge_attr_cols = self.gnn_cfgs.get(
            "adjacency_edge_attr_cols", ["dist_hdn", "elev_diff", "strm_slope"]
        )
        if edge_attr_cols:
            edge_attr = torch.zeros((n_nodes, len(edge_attr_cols)), dtype=torch.float)
        else:
            edge_attr = None

        # Create default edge weights (1.0 for self-loops)
        edge_weight = torch.ones(n_nodes, dtype=torch.float)

        return {
            "edge_index": edge_index,
            "edge_attr": edge_attr,
            "edge_weight": edge_weight,  # 新增:边权重
            "num_nodes": n_nodes,
            "node_to_idx": {i: i for i in range(n_nodes)},
            "weight_col": None,  # 自环情况下没有指定权重列
        }

    # GNN-specific utility methods
    def get_station_data(self, basin_idx):
        """Get station data for a specific basin

        Since station data is now processed by BaseDataset pipeline,
        it's available as self.station_cols (converted to numpy array).
        """
        if hasattr(self, "station_cols") and self.station_cols is not None:
            return self.station_cols[basin_idx]
        return None

    def get_adjacency_data(self, basin_idx):
        """Get adjacency data for a specific basin

        Returns
        -------
        dict or None
            Dictionary containing edge_index, edge_attr, edge_weight, etc. or None
        """
        if self.adjacency_data is None:
            return None

        # Get the specific basin ID for this basin index
        basin_id_with_prefix = self.t_s_dict["sites_id"][basin_idx]
        # Convert single basin ID to station ID (without prefix)
        basin_id = self._convert_basin_to_station_ids([basin_id_with_prefix])[0]
        return self.adjacency_data.get(basin_id)

    def get_edge_weight(self, basin_idx):
        """Get edge weights for a specific basin

        Parameters
        ----------
        basin_idx : int
            Basin index

        Returns
        -------
        torch.Tensor or None
            Edge weights tensor [num_edges] or None
        """
        adjacency_data = self.get_adjacency_data(basin_idx)
        if adjacency_data is not None:
            return adjacency_data.get("edge_weight")
        return None

    def _convert_basin_to_station_ids(self, basin_ids):
        """Convert basin IDs (with prefix) to station IDs (without prefix) for StationHydroDataset

        Parameters
        ----------
        basin_ids : list
            List of basin IDs with prefix (e.g., ["songliao_21100150"])

        Returns
        -------
        list
            List of station IDs without prefix (e.g., ["21100150"])
        """
        station_ids = []
        for basin_id in basin_ids:
            # Remove common prefixes
            if "_" in basin_id:
                # Extract the part after the last underscore
                station_id = basin_id.split("_")[-1]
            else:
                # If no underscore, use the original ID
                station_id = basin_id
            station_ids.append(station_id)
        return station_ids

    def _convert_station_to_basin_ids(self, station_ids, prefix="songliao"):
        """Convert station IDs (without prefix) to basin IDs (with prefix) for consistency

        Parameters
        ----------
        station_ids : list
            List of station IDs without prefix (e.g., ["21100150"])
        prefix : str
            Prefix to add (default: "songliao")

        Returns
        -------
        list
            List of basin IDs with prefix (e.g., ["songliao_21100150"])
        """
        basin_ids = []
        for station_id in station_ids:
            basin_id = f"{prefix}_{station_id}"
            basin_ids.append(basin_id)
        return basin_ids

    def __getitem__(self, item: int):
        """Get one sample with GNN-specific data format.

        This method merges basin-level features (xc) into each station node's
        features (sxc), so each node's input includes both station and basin
        attributes.

        Args:
            item: The index of the sample to retrieve.

        Returns:
            A tuple of (sxc, y, edge_index, edge_weight) where:
            - sxc: Station features merged with basin features.
                   Shape: [num_stations, seq_length, feature_dim]
            - y: Target values for prediction.
                 Shape: [forecast_length, output_dim]
            - edge_index: Edge connectivity.
                          Shape: [2, num_edges]
            - edge_weight: Edge weights.
                           Shape: [num_edges]
        """
        import torch
        import numpy as np

        # Get basic sample from parent (includes flood mask if FloodEventDataset)
        basic_sample = super(GNNDataset, self).__getitem__(item)

        # Extract x, y from parent's output
        if isinstance(basic_sample, tuple):
            x, y = (
                basic_sample  # x: [seq_length, x_feature_dim], y: [full_length, y_feature_dim]
            )
        elif isinstance(basic_sample, dict):
            x = basic_sample.get("x")
            y = basic_sample.get("y")
        else:
            raise ValueError(f"Unexpected basic_sample format: {type(basic_sample)}")

        # Get sample metadata
        basin, time_idx, actual_length = self.lookup_table[item]

        # For GNN prediction, we only need the forecast part of y as target
        # The structure should be: warmup + hindcast (rho) + forecast (horizon)
        # We only predict the forecast (horizon) part

        # Get station data for current basin and time window
        station_data = self.get_station_data(basin)  # [time, station, variable]
        adjacency_data = self.get_adjacency_data(basin)

        # Extract station data for the time window (input sequence)
        if station_data is not None:
            # For station data, we need the input sequence (not just forecast part)
            seq_end = time_idx + actual_length
            sxc_raw = station_data[
                time_idx:seq_end
            ]  # [seq_length, num_stations, station_feature_dim]
        else:
            # If no station data, create dummy station data
            LOGGER.warning(
                f"No station data for basin {basin}, using single dummy station"
            )
            dummy_station_features = 1  # Number of dummy features
            sxc_raw = np.zeros(
                (actual_length, 1, dummy_station_features)
            )  # [seq_length, 1, 1]

        # Get basin-level features (xc) for merging
        # x contains basin-level features, we need to replicate it for each station
        if x is not None and x.ndim >= 2:
            xc = x  # [seq_length, basin_feature_dim]
            basin_feature_dim = xc.shape[-1]
            seq_length, num_stations, station_feature_dim = sxc_raw.shape

            # Replicate basin features for each station and concatenate with station features
            # xc expanded: [seq_length, 1, basin_feature_dim] -> [seq_length, num_stations, basin_feature_dim]
            xc_expanded = np.tile(xc[:, np.newaxis, :], (1, num_stations, 1))

            # Concatenate station features with basin features
            # sxc_temp: [seq_length, num_stations, station_feature_dim + basin_feature_dim]
            sxc_temp = np.concatenate([sxc_raw, xc_expanded], axis=-1)

            # Transpose to get desired shape: [num_stations, seq_length, feature_dim]
            sxc = sxc_temp.transpose(1, 0, 2)
        else:
            # If no basin features, use only station features and transpose
            # sxc: [num_stations, seq_length, station_feature_dim]
            sxc = sxc_raw.transpose(1, 0, 2)

        # Process adjacency data (GNN edge orientation handled here)
        # Edge orientation logic: support 'upstream', 'downstream', 'bidirectional' (default: downstream)
        edge_orientation = self.gnn_cfgs.get("edge_orientation", "downstream")
        if adjacency_data is not None:
            edge_index = adjacency_data["edge_index"]  # [2, num_edges]
            edge_attr = adjacency_data[
                "edge_attr"
            ]  # [num_edges, edge_attr_dim] or None
            edge_weight = adjacency_data.get("edge_weight")  # [num_edges]
            # If edge_weight is None, fill with ones (all edges weight=1)
            if edge_weight is None:
                num_edges = edge_index.shape[1]
                edge_weight = torch.ones(num_edges, dtype=torch.float)

            # Edge orientation handling
            if edge_orientation == "downstream":
                # Reverse all edges: swap source and target
                edge_index = edge_index[[1, 0], :]
            elif edge_orientation == "bidirectional":
                # Add reversed edges to make bidirectional
                edge_index_rev = edge_index[[1, 0], :]
                edge_index = torch.cat([edge_index, edge_index_rev], dim=1)
                if edge_attr is not None:
                    edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
                if edge_weight is not None:
                    edge_weight = torch.cat([edge_weight, edge_weight], dim=0)
            # else: downstream (default), do nothing

        else:
            # Default: self-loops for each station
            num_stations = sxc.shape[
                0
            ]  # Now sxc is [num_stations, seq_length, feature_dim]
            edge_index = torch.arange(num_stations).repeat(2, 1)
            edge_attr = None
            edge_weight = torch.ones(num_stations, dtype=torch.float)  # default weight of 1

        # Ensure edge_attr has proper shape
        if edge_attr is None:
            num_edges = edge_index.shape[1]
            edge_attr_dim = len(
                self.gnn_cfgs.get(
                    "adjacency_edge_attr_cols", ["dist_hdn", "elev_diff", "strm_slope"]
                )
            )
            edge_attr = torch.zeros((num_edges, edge_attr_dim), dtype=torch.float)

        # Convert to tensors if needed
        if not isinstance(sxc, torch.Tensor):
            sxc = torch.tensor(sxc, dtype=torch.float)
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, dtype=torch.float)

        return sxc, y, edge_index, edge_weight  # edge_attr
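A minimal sketch of the returned format using dummy tensors; shapes mirror the description above, and the GCNConv mention is only a pointer, not part of this dataset:

import torch

# Dummy tensors mimicking one GNNDataset sample; values are random, shapes are the point
num_stations, seq_length, feature_dim = 4, 24, 8
sxc = torch.randn(num_stations, seq_length, feature_dim)       # node features over time
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 3]])        # [2, num_edges]; last edge is a self-loop
edge_weight = torch.ones(edge_index.shape[1])                  # [num_edges]

# A node-level GNN layer expects x with shape [num_nodes, feature_dim], so one option
# is to apply it per timestep, e.g. on the last hindcast step:
x_t = sxc[:, -1, :]
print(x_t.shape, edge_index.shape, edge_weight.shape)
# e.g. torch_geometric.nn.GCNConv(feature_dim, 16)(x_t, edge_index, edge_weight)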

__getitem__(self, item) special

Get one sample with GNN-specific data format.

This method merges basin-level features (xc) into each station node's features (sxc), so each node's input includes both station and basin attributes.

Parameters:

item : int
    The index of the sample to retrieve (required).

Returns:

A tuple of (sxc, y, edge_index, edge_weight) where:

- sxc: Station features merged with basin features. Shape: [num_stations, seq_length, feature_dim]
- y: Target values for prediction. Shape: [forecast_length, output_dim]
- edge_index: Edge connectivity. Shape: [2, num_edges]
- edge_weight: Edge weights. Shape: [num_edges]
Source code in torchhydro/datasets/data_sets.py
def __getitem__(self, item: int):
    """Get one sample with GNN-specific data format.

    This method merges basin-level features (xc) into each station node's
    features (sxc), so each node's input includes both station and basin
    attributes.

    Args:
        item: The index of the sample to retrieve.

    Returns:
        A tuple of (sxc, y, edge_index, edge_weight) where:
        - sxc: Station features merged with basin features.
               Shape: [num_stations, seq_length, feature_dim]
        - y: Target values for prediction.
             Shape: [forecast_length, output_dim]
        - edge_index: Edge connectivity.
                      Shape: [2, num_edges]
        - edge_weight: Edge weights.
                       Shape: [num_edges]
    """
    import torch
    import numpy as np

    # Get basic sample from parent (includes flood mask if FloodEventDataset)
    basic_sample = super(GNNDataset, self).__getitem__(item)

    # Extract x, y from parent's output
    if isinstance(basic_sample, tuple):
        x, y = (
            basic_sample  # x: [seq_length, x_feature_dim], y_full: [full_length, y_feature_dim]
        )
    elif isinstance(basic_sample, dict):
        x = basic_sample.get("x")
        y = basic_sample.get("y")
    else:
        raise ValueError(f"Unexpected basic_sample format: {type(basic_sample)}")

    # Get sample metadata
    basin, time_idx, actual_length = self.lookup_table[item]

    # For GNN prediction, we only need the forecast part of y as target
    # The structure should be: warmup + hindcast (rho) + forecast (horizon)
    # We only predict the forecast (horizon) part

    # Get station data for current basin and time window
    station_data = self.get_station_data(basin)  # [time, station, variable]
    adjacency_data = self.get_adjacency_data(basin)

    # Extract station data for the time window (input sequence)
    if station_data is not None:
        # For station data, we need the input sequence (not just forecast part)
        seq_end = time_idx + actual_length
        sxc_raw = station_data[
            time_idx:seq_end
        ]  # [seq_length, num_stations, station_feature_dim]
    else:
        # If no station data, create dummy station data
        LOGGER.warning(
            f"No station data for basin {basin}, using single dummy station"
        )
        dummy_station_features = 1  # Number of dummy features
        sxc_raw = np.zeros(
            (actual_length, 1, dummy_station_features)
        )  # [seq_length, 1, 1]

    # Get basin-level features (xc) for merging
    # x contains basin-level features, we need to replicate it for each station
    if x is not None and x.ndim >= 2:
        xc = x  # [seq_length, basin_feature_dim]
        basin_feature_dim = xc.shape[-1]
        seq_length, num_stations, station_feature_dim = sxc_raw.shape

        # Replicate basin features for each station and concatenate with station features
        # xc expanded: [seq_length, 1, basin_feature_dim] -> [seq_length, num_stations, basin_feature_dim]
        xc_expanded = np.tile(xc[:, np.newaxis, :], (1, num_stations, 1))

        # Concatenate station features with basin features
        # sxc_temp: [seq_length, num_stations, station_feature_dim + basin_feature_dim]
        sxc_temp = np.concatenate([sxc_raw, xc_expanded], axis=-1)

        # Transpose to get desired shape: [num_stations, seq_length, feature_dim]
        sxc = sxc_temp.transpose(1, 0, 2)
    else:
        # If no basin features, use only station features and transpose
        # sxc: [num_stations, seq_length, station_feature_dim]
        sxc = sxc_raw.transpose(1, 0, 2)

    # Process adjacency data (GNN edge orientation handled here)
    # Edge orientation logic: support 'upstream', 'downstream', 'bidirectional' (default: downstream)
    edge_orientation = self.gnn_cfgs.get("edge_orientation", "downstream")
    if adjacency_data is not None:
        edge_index = adjacency_data["edge_index"]  # [2, num_edges]
        edge_attr = adjacency_data[
            "edge_attr"
        ]  # [num_edges, edge_attr_dim] or None
        edge_weight = adjacency_data.get("edge_weight")  # [num_edges]
        # If edge_weight is None, fill with ones (all edges weight=1)
        if edge_weight is None:
            num_edges = edge_index.shape[1]
            edge_weight = torch.ones(num_edges, dtype=torch.float)

        # Edge orientation handling
        if edge_orientation == "downstream":
            # Reverse all edges: swap source and target
            edge_index = edge_index[[1, 0], :]
        elif edge_orientation == "bidirectional":
            # Add reversed edges to make bidirectional
            edge_index_rev = edge_index[[1, 0], :]
            edge_index = torch.cat([edge_index, edge_index_rev], dim=1)
            if edge_attr is not None:
                edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
            if edge_weight is not None:
                edge_weight = torch.cat([edge_weight, edge_weight], dim=0)
        # else: "upstream", keep edge_index in its stored orientation

    else:
        # Default: self-loops for each station
        num_stations = sxc.shape[
            0
        ]  # Now sxc is [num_stations, seq_length, feature_dim]
        edge_index = torch.arange(num_stations).repeat(2, 1)
        edge_attr = None
        edge_weight = torch.ones(num_stations, dtype=torch.float)  # default weight of 1 for each self-loop

    # Ensure edge_attr has proper shape
    if edge_attr is None:
        num_edges = edge_index.shape[1]
        edge_attr_dim = len(
            self.gnn_cfgs.get(
                "adjacency_edge_attr_cols", ["dist_hdn", "elev_diff", "strm_slope"]
            )
        )
        edge_attr = torch.zeros((num_edges, edge_attr_dim), dtype=torch.float)

    # Convert to tensors if needed
    if not isinstance(sxc, torch.Tensor):
        sxc = torch.tensor(sxc, dtype=torch.float)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=torch.float)

    return sxc, y, edge_index, edge_weight  # edge_attr is prepared above but not returned for now
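
A minimal usage sketch (hypothetical: gnn_dataset is an already-constructed instance of this dataset class, and torch_geometric is assumed to be available) showing how one returned sample can be wrapped into a PyG Data object:

from torch_geometric.data import Data

# gnn_dataset: a hypothetical, already-constructed instance of this dataset class
sxc, y, edge_index, edge_weight = gnn_dataset[0]
graph = Data(x=sxc, edge_index=edge_index, edge_weight=edge_weight, y=y)
# graph.x: [num_stations, seq_length, feature_dim]; graph.edge_index: [2, num_edges]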

get_adjacency_data(self, basin_idx)

Get adjacency data for a specific basin

Returns

dict or None Dictionary containing edge_index, edge_attr, edge_weight, etc. or None

Source code in torchhydro/datasets/data_sets.py
def get_adjacency_data(self, basin_idx):
    """Get adjacency data for a specific basin

    Returns
    -------
    dict or None
        Dictionary containing edge_index, edge_attr, edge_weight, etc. or None
    """
    if self.adjacency_data is None:
        return None

    # Get the specific basin ID for this basin index
    basin_id_with_prefix = self.t_s_dict["sites_id"][basin_idx]
    # Convert single basin ID to station ID (without prefix)
    basin_id = self._convert_basin_to_station_ids([basin_id_with_prefix])[0]
    return self.adjacency_data.get(basin_id)

get_edge_weight(self, basin_idx)

Get edge weights for a specific basin

Parameters

basin_idx : int Basin index

Returns

torch.Tensor or None Edge weights tensor [num_edges] or None

Source code in torchhydro/datasets/data_sets.py
def get_edge_weight(self, basin_idx):
    """Get edge weights for a specific basin

    Parameters
    ----------
    basin_idx : int
        Basin index

    Returns
    -------
    torch.Tensor or None
        Edge weights tensor [num_edges] or None
    """
    adjacency_data = self.get_adjacency_data(basin_idx)
    if adjacency_data is not None:
        return adjacency_data.get("edge_weight")
    return None

get_station_data(self, basin_idx)

Get station data for a specific basin

Since station data is now processed by BaseDataset pipeline, it's available as self.station_cols (converted to numpy array).

Source code in torchhydro/datasets/data_sets.py
def get_station_data(self, basin_idx):
    """Get station data for a specific basin

    Since station data is now processed by BaseDataset pipeline,
    it's available as self.station_cols (converted to numpy array).
    """
    if hasattr(self, "station_cols") and self.station_cols is not None:
        return self.station_cols[basin_idx]
    return None

ObsForeDataset (BaseDataset)

A hybrid dataset of observed and forecast (lead-time) data

This class is designed for datasets whose forecast data have two independent dimensions, lead_time and time. It is suited to representing forecasts issued at different times for different target dates.

Source code in torchhydro/datasets/data_sets.py
class ObsForeDataset(BaseDataset):
    """处理观测和预见期数据的混合数据集

    这个类专门用于处理具有双维度预见期数据格式的数据集,其中 lead_time 和 time 都是独立维度。
    适合表示不同发布时间对不同目标日期的预报。
    """

    def __init__(self, cfgs: dict, is_tra_val_te: str):
        """初始化观测和预见期混合数据集

        Parameters
        ----------
        cfgs : dict
            all configs
        is_tra_val_te : str
            specify whether this is the train, validation, or test set
        """
        # call the parent class constructor
        super(ObsForeDataset, self).__init__(cfgs, is_tra_val_te)
        # For each batch, we fix the hindcast and forecast lengths.
        # Forecast data at different lead times are indexed by a lead-time number.
        # For example, if "now" is 2020-09-30, min_time_interval is 1 day, hindcast length is 30 and forecast length is 1,
        # then lead_time = 3 means obs from 2020-09-01 to 2020-09-30 plus the forecast for 2020-10-01 issued on 2020-09-28.
        # Two configurations are supported for forecast data:
        # 1st, the same lead time for every forecast step:
        # now=2020-09-30, hindcast=30, forecast=2, lead_time=3 means obs 2020-09-01 to 2020-09-30 concatenated with the 2020-10-01 forecast issued on 2020-09-28 and the 2020-10-02 forecast issued on 2020-09-29
        # 2nd, an increasing lead time for each forecast step:
        # now=2020-09-30, hindcast=30, forecast=2, lead_time=[1, 2] means obs 2020-09-01 to 2020-09-30 concatenated with the 2020-10-01 to 2020-10-02 forecasts issued on 2020-09-30
        self.lead_time_type = self.training_cfgs["lead_time_type"]
        if self.lead_time_type not in ["fixed", "increasing"]:
            raise ValueError(
                "lead_time_type must be one of 'fixed' or 'increasing', "
                f"but got {self.lead_time_type}"
            )
        self.lead_time_start = self.training_cfgs["lead_time_start"]
        horizon = self.horizon
        offset = np.zeros((horizon,), dtype=int)
        if self.lead_time_type == "fixed":
            offset = offset + self.lead_time_start
        elif self.lead_time_type == "increasing":
            offset = offset + np.arange(
                self.lead_time_start, self.lead_time_start + horizon
            )
        self.horizon_offset = offset
        feature_mapping = self.data_cfgs["feature_mapping"]
        # map each observed variable's index in x to its forecast counterpart's index in f
        xf_var_indices = {}
        for obs_var, fore_var in feature_mapping.items():
            # find the index of the variable in x that will be replaced
            x_var_indice = [
                i
                for i, var in enumerate(self.data_cfgs["relevant_cols"])
                if var == obs_var
            ][0]
            # find the index of the corresponding variable in f
            f_var_indice = [
                i
                for i, var in enumerate(self.data_cfgs["forecast_cols"])
                if var == fore_var
            ][0]
            xf_var_indices[x_var_indice] = f_var_indice
        self.xf_var_indices = xf_var_indices

    def _read_xyc_specified_time(self, start_date, end_date, **kwargs):
        """read f data from data source with specified time range and add it to the whole dict"""
        data_dict = super(ObsForeDataset, self)._read_xyc_specified_time(
            start_date, end_date
        )
        lead_time = kwargs.get("lead_time", None)
        f_origin = self.data_source.read_ts_xrdataset(
            self.t_s_dict["sites_id"],
            [start_date, end_date],
            self.data_cfgs["forecast_cols"],
            forecast_mode=True,
            lead_time=lead_time,
        )
        f_origin_ = self._rm_timeunit_key(f_origin)
        f_data = self._trans2da_and_setunits(f_origin_)
        data_dict["forecast_cols"] = f_data.transpose(
            "basin", "time", "lead_step", "variable"
        )
        return data_dict

    def __getitem__(self, item: int):
        """Get a sample from the dataset

        Parameters
        ----------
        item : int
            index of sample

        Returns
        -------
        tuple
            A pair of (x, y) data, where x contains input features and lead time flags,
            and y contains target values
        """
        # train mode
        basin, idx, _ = self.lookup_table[item]
        warmup_length = self.warmup_length
        # For x, we mainly need data before the horizon, but only some variables have forecast data.
        # To avoid NaN values in the horizon for variables without forecasts,
        # we still read data from the first time step to the end of the horizon.
        x = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, :]
        # for y, we choose data after the warmup_length
        y = self.y[basin, idx : idx + self.rho + self.horizon, :]
        # use offset to get forecast data
        offset = self.horizon_offset
        if self.lead_time_type == "fixed":
            # Fixed lead_time mode - All forecast steps use the same lead_step
            f = self.f[
                basin, idx + self.rho : idx + self.rho + self.horizon, offset[0], :
            ]
        else:
            # Increasing lead_time mode - Each forecast step uses a different lead_step
            f = self.f[basin, idx + self.rho, offset, :]
        xf = self._concat_xf(x, f)
        if self.c is None or self.c.shape[-1] == 0:
            xfc = xf
        else:
            c = self.c[basin, :]
            c = np.repeat(c, xf.shape[0], axis=0).reshape(c.shape[0], -1).T
            xfc = np.concatenate((xf, c), axis=1)

        return torch.from_numpy(xfc).float(), torch.from_numpy(y).float()

    def _concat_xf(self, x, f):
        # Create a copy of x to avoid modifying the original data
        x_combined = x.copy()

        # Iterate through the variable mapping relationship
        for x_idx, f_idx in self.xf_var_indices.items():
            # Replace the variables in the forecast period of x with the forecast variables in f
            # The forecast period of x starts at index warmup_length + rho
            x_combined[self.warmup_length + self.rho :, x_idx] = f[:, f_idx]

        return x_combined
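
The two lead-time modes differ only in how horizon_offset is built; a minimal sketch with hypothetical numbers:

import numpy as np

horizon, lead_time_start = 3, 1
# "fixed": every forecast step is read at the same lead step
fixed_offset = np.zeros((horizon,), dtype=int) + lead_time_start  # array([1, 1, 1])
# "increasing": each forecast step is one lead step further out
increasing_offset = np.arange(lead_time_start, lead_time_start + horizon)  # array([1, 2, 3])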

__getitem__(self, item) special

Get a sample from the dataset

Parameters

item : int index of sample

Returns

tuple A pair of (x, y) data, where x contains input features and lead time flags, and y contains target values

Source code in torchhydro/datasets/data_sets.py
def __getitem__(self, item: int):
    """Get a sample from the dataset

    Parameters
    ----------
    item : int
        index of sample

    Returns
    -------
    tuple
        A pair of (x, y) data, where x contains input features and lead time flags,
        and y contains target values
    """
    # train mode
    basin, idx, _ = self.lookup_table[item]
    warmup_length = self.warmup_length
    # For x, we mainly need data before the horizon, but only some variables have forecast data.
    # To avoid NaN values in the horizon for variables without forecasts,
    # we still read data from the first time step to the end of the horizon.
    x = self.x[basin, idx - warmup_length : idx + self.rho + self.horizon, :]
    # for y, we choose data after the warmup_length
    y = self.y[basin, idx : idx + self.rho + self.horizon, :]
    # use offset to get forecast data
    offset = self.horizon_offset
    if self.lead_time_type == "fixed":
        # Fixed lead_time mode - All forecast steps use the same lead_step
        f = self.f[
            basin, idx + self.rho : idx + self.rho + self.horizon, offset[0], :
        ]
    else:
        # Increasing lead_time mode - Each forecast step uses a different lead_step
        f = self.f[basin, idx + self.rho, offset, :]
    xf = self._concat_xf(x, f)
    if self.c is None or self.c.shape[-1] == 0:
        xfc = xf
    else:
        c = self.c[basin, :]
        c = np.repeat(c, xf.shape[0], axis=0).reshape(c.shape[0], -1).T
        xfc = np.concatenate((xf, c), axis=1)

    return torch.from_numpy(xfc).float(), torch.from_numpy(y).float()

__init__(self, cfgs, is_tra_val_te) special

Initialize the hybrid observation/forecast dataset

Parameters

cfgs : dict all configs is_tra_val_te : str specify whether this is the train, validation, or test set

Source code in torchhydro/datasets/data_sets.py
def __init__(self, cfgs: dict, is_tra_val_te: str):
    """初始化观测和预见期混合数据集

    Parameters
    ----------
    cfgs : dict
        all configs
    is_tra_val_te : str
        specify whether this is the train, validation, or test set
    """
    # call the parent class constructor
    super(ObsForeDataset, self).__init__(cfgs, is_tra_val_te)
    # For each batch, we fix the hindcast and forecast lengths.
    # Forecast data at different lead times are indexed by a lead-time number.
    # For example, if "now" is 2020-09-30, min_time_interval is 1 day, hindcast length is 30 and forecast length is 1,
    # then lead_time = 3 means obs from 2020-09-01 to 2020-09-30 plus the forecast for 2020-10-01 issued on 2020-09-28.
    # Two configurations are supported for forecast data:
    # 1st, the same lead time for every forecast step:
    # now=2020-09-30, hindcast=30, forecast=2, lead_time=3 means obs 2020-09-01 to 2020-09-30 concatenated with the 2020-10-01 forecast issued on 2020-09-28 and the 2020-10-02 forecast issued on 2020-09-29
    # 2nd, an increasing lead time for each forecast step:
    # now=2020-09-30, hindcast=30, forecast=2, lead_time=[1, 2] means obs 2020-09-01 to 2020-09-30 concatenated with the 2020-10-01 to 2020-10-02 forecasts issued on 2020-09-30
    self.lead_time_type = self.training_cfgs["lead_time_type"]
    if self.lead_time_type not in ["fixed", "increasing"]:
        raise ValueError(
            "lead_time_type must be one of 'fixed' or 'increasing', "
            f"but got {self.lead_time_type}"
        )
    self.lead_time_start = self.training_cfgs["lead_time_start"]
    horizon = self.horizon
    offset = np.zeros((horizon,), dtype=int)
    if self.lead_time_type == "fixed":
        offset = offset + self.lead_time_start
    elif self.lead_time_type == "increasing":
        offset = offset + np.arange(
            self.lead_time_start, self.lead_time_start + horizon
        )
    self.horizon_offset = offset
    feature_mapping = self.data_cfgs["feature_mapping"]
    # map each observed variable's index in x to its forecast counterpart's index in f
    xf_var_indices = {}
    for obs_var, fore_var in feature_mapping.items():
        # find the index of the variable in x that will be replaced
        x_var_indice = [
            i
            for i, var in enumerate(self.data_cfgs["relevant_cols"])
            if var == obs_var
        ][0]
        # find the index of the corresponding variable in f
        f_var_indice = [
            i
            for i, var in enumerate(self.data_cfgs["forecast_cols"])
            if var == fore_var
        ][0]
        xf_var_indices[x_var_indice] = f_var_indice
    self.xf_var_indices = xf_var_indices

detect_date_format(date_str)

Detect the date format; supports a single string or a list of strings

Parameters

date_str : str or list a date string or a list of date strings

Returns

str 检测到的日期格式

Source code in torchhydro/datasets/data_sets.py
def detect_date_format(date_str):
    """
    Detect the date format; supports a single string or a list of strings

    Parameters
    ----------
    date_str : str or list
        a date string or a list of date strings

    Returns
    -------
    str
        the detected date format
    """
    # If the input is a list, use its first element
    if isinstance(date_str, (list, tuple)):
        if not date_str:  # if the list is empty
            raise ValueError("Empty date list")
        date_str = date_str[0]  # use the first date string

    # Ensure the input is a string
    if not isinstance(date_str, str):
        raise ValueError(
            f"Date must be string or list of strings, got {type(date_str)}"
        )

    # Try each supported date format
    for date_format in DATE_FORMATS:
        try:
            datetime.strptime(date_str, date_format)
            return date_format
        except ValueError:
            continue

    raise ValueError(f"Unknown date format: {date_str}")

data_sources

Author: Wenyu Ouyang Date: 2024-04-02 14:37:09 LastEditTime: 2025-11-08 09:58:42 LastEditors: Wenyu Ouyang Description: A module for different data sources FilePath: orchhydro orchhydro\datasets\data_sources.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

data_utils

Author: Wenyu Ouyang Date: 2023-09-21 15:37:58 LastEditTime: 2025-07-13 15:46:09 LastEditors: Wenyu Ouyang Description: Some basic functions for dealing with data FilePath: orchhydro orchhydro\datasets\data_utils.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

choose_basins_with_area(gages, usgs_ids, smallest_area, largest_area)

choose basins whose area is neither too small nor too large

Parameters

gages : Camels, CamelsSeries, Gages or GagesPro object usgs_ids : list given sites' ids smallest_area : float lower limit; unit is km2 largest_area : float upper limit; unit is km2

Returns

list sites_chosen: [] -- ids of chosen gages

Source code in torchhydro/datasets/data_utils.py
def choose_basins_with_area(
    gages,
    usgs_ids: list,
    smallest_area: float,
    largest_area: float,
) -> list:
    """
    choose basins whose area is neither too small nor too large

    Parameters
    ----------
    gages
        Camels, CamelsSeries, Gages or GagesPro object
    usgs_ids: list
        given sites' ids
    smallest_area
        lower limit; unit is km2
    largest_area
        upper limit; unit is km2

    Returns
    -------
    list
        sites_chosen: [] -- ids of chosen gages

    """
    basins_areas = gages.read_basin_area(usgs_ids).flatten()
    sites_index = np.arange(len(usgs_ids))
    sites_chosen = np.ones(len(usgs_ids))
    for i in range(sites_index.size):
        # loop for every site
        if basins_areas[i] < smallest_area or basins_areas[i] > largest_area:
            sites_chosen[sites_index[i]] = 0
        else:
            sites_chosen[sites_index[i]] = 1
    return [usgs_ids[i] for i in range(len(sites_chosen)) if sites_chosen[i] > 0]
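
A hypothetical call; gages and usgs_ids stand for an already-loaded data source object (implementing read_basin_area) and its site id list:

from torchhydro.datasets.data_utils import choose_basins_with_area

# keep basins whose area is between 100 and 10,000 km2
chosen_ids = choose_basins_with_area(gages, usgs_ids, 100.0, 10000.0)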

choose_sites_in_ecoregion(gages, site_ids, ecoregion)

Choose sites in ecoregions

Parameters

gages : Gages Only gages dataset has ecoregion attribute site_ids : list all ids of sites ecoregion : Union[list, tuple] which ecoregions

Returns

list chosen sites' ids

Raises

NotImplementedError Please choose 'ECO2_CODE' or 'ECO3_CODE' NotImplementedError must be in the ECO2 code list NotImplementedError must be in the ECO3 code list

Source code in torchhydro/datasets/data_utils.py
def choose_sites_in_ecoregion(
    gages, site_ids: list, ecoregion: Union[list, tuple]
) -> list:
    """
    Choose sites in ecoregions

    Parameters
    ----------
    gages : Gages
        Only gages dataset has ecoregion attribute
    site_ids : list
        all ids of sites
    ecoregion : Union[list, tuple]
        which ecoregions

    Returns
    -------
    list
        chosen sites' ids

    Raises
    ------
    NotImplementedError
        Please choose 'ECO2_CODE' or 'ECO3_CODE'
    NotImplementedError
        must be in the ECO2 code list
    NotImplementedError
        must be in the ECO3 code list
    """
    if ecoregion[0] not in ["ECO2_CODE", "ECO3_CODE"]:
        raise NotImplementedError("PLease choose 'ECO2_CODE' or 'ECO3_CODE'")
    if ecoregion[0] == "ECO2_CODE":
        ec02_code_lst = [
            5.2,
            5.3,
            6.2,
            7.1,
            8.1,
            8.2,
            8.3,
            8.4,
            8.5,
            9.2,
            9.3,
            9.4,
            9.5,
            9.6,
            10.1,
            10.2,
            10.4,
            11.1,
            12.1,
            13.1,
        ]
        if ecoregion[1] not in ec02_code_lst:
            raise NotImplementedError(
                f"No such EC02 code, please choose from {ec02_code_lst}"
            )
        attr_name = "ECO2_BAS_DOM"
    elif ecoregion[1] in np.arange(1, 85):
        attr_name = "ECO3_BAS_DOM"
    else:
        raise NotImplementedError("No such EC03 code, please choose from 1 - 85")
    attr_lst = [attr_name]
    data_attr = gages.read_constant_cols(site_ids, attr_lst)
    eco_names = data_attr[:, 0]
    return [site_ids[i] for i in range(eco_names.size) if eco_names[i] == ecoregion[1]]
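
A hypothetical call; gages and site_ids stand for an already-loaded Gages data source (carrying ecoregion attributes) and its site id list:

from torchhydro.datasets.data_utils import choose_sites_in_ecoregion

# keep sites whose dominant level-2 ecoregion code is 8.1
chosen = choose_sites_in_ecoregion(gages, site_ids, ("ECO2_CODE", 8.1))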

dam_num_chosen(gages, usgs_id, dam_num)

choose basins by their number of dams (NDAMS_2009): a [low, high) range if dam_num is a list, otherwise an exact count

Source code in torchhydro/datasets/data_utils.py
def dam_num_chosen(gages, usgs_id, dam_num):
    """choose basins of dams"""
    assert all(x < y for x, y in zip(usgs_id, usgs_id[1:]))
    attr_lst = ["NDAMS_2009"]
    data_attr = gages.read_constant_cols(usgs_id, attr_lst)
    return (
        [
            usgs_id[i]
            for i in range(data_attr.size)
            if dam_num[0] <= data_attr[:, 0][i] < dam_num[1]
        ]
        if type(dam_num) == list
        else [
            usgs_id[i] for i in range(data_attr.size) if data_attr[:, 0][i] == dam_num
        ]
    )

dor_reservoirs_chosen(gages, usgs_id, dor_chosen)

choose basins by DOR (calculated as NOR_STORAGE/RUNAVE7100): a [low, high) range if dor_chosen is a list/tuple, DOR < |dor_chosen| if it is negative, otherwise DOR >= dor_chosen

Source code in torchhydro/datasets/data_utils.py
def dor_reservoirs_chosen(gages, usgs_id, dor_chosen) -> list:
    """
    choose basins of small DOR(calculated by NOR_STORAGE/RUNAVE7100)

    """

    dors = get_dor_values(gages, usgs_id)
    if type(dor_chosen) in [list, tuple]:
        # right half-open range
        chosen_id = [
            usgs_id[i]
            for i in range(dors.size)
            if dor_chosen[0] <= dors[i] < dor_chosen[1]
        ]
    elif dor_chosen < 0:
        chosen_id = [usgs_id[i] for i in range(dors.size) if dors[i] < -dor_chosen]
    else:
        chosen_id = [usgs_id[i] for i in range(dors.size) if dors[i] >= dor_chosen]

    assert all(x < y for x, y in zip(chosen_id, chosen_id[1:]))
    return chosen_id
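
A hypothetical call illustrating the dor_chosen forms; gages and usgs_id stand for an already-loaded data source and its sorted site id list:

from torchhydro.datasets.data_utils import dor_reservoirs_chosen

# basins with DOR in the half-open range [0.02, 0.1)
mildly_regulated = dor_reservoirs_chosen(gages, usgs_id, [0.02, 0.1])
# basins with DOR < 0.02 (a negative value means "less than")
near_natural = dor_reservoirs_chosen(gages, usgs_id, -0.02)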

set_unit_to_var(ds)

The returned xr.Dataset needs a units attribute on each variable (xr.DataArray); otherwise the dataset cannot be saved to a netCDF file

Parameters

ds : xr.Dataset the dataset with units as attributes

Returns

ds : xr.Dataset the dataset with a units attr set on each variable DataArray

Source code in torchhydro/datasets/data_utils.py
def set_unit_to_var(ds):
    """returned xa.Dataset need has units for each variable -- xr.DataArray
    or the dataset cannot be saved to netCDF file

    Parameters
    ----------
    ds : xr.Dataset
        the dataset with units as attributes

    Returns
    -------
    ds : xr.Dataset
        unit attrs are for each variable dataarray
    """
    units_dict = ds.attrs["units"]
    for var_name, units in units_dict.items():
        if var_name in ds:
            ds[var_name].attrs["units"] = units
    if "units" in ds.attrs:
        del ds.attrs["units"]
    return ds
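
A hypothetical example, starting from a dataset whose units are stored only at the dataset level:

import xarray as xr

from torchhydro.datasets.data_utils import set_unit_to_var

ds = xr.Dataset({"streamflow": ("time", [1.0, 2.0])})
ds.attrs["units"] = {"streamflow": "m^3/s"}
ds = set_unit_to_var(ds)
# ds["streamflow"].attrs["units"] == "m^3/s"; the dataset-level "units" attr is removed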

unify_streamflow_unit(ds, area=None, inverse=False)

Unify the streamflow unit of the dataset to mm/day per basin, or do the inverse conversion

Parameters

!!! ds "xarray dataset" description !!! area area of each basin

Returns

xr.Dataset the converted dataset

Source code in torchhydro/datasets/data_utils.py
def unify_streamflow_unit(ds: xr.Dataset, area=None, inverse=False):
    """Unify the unit of xr_dataset to be mm/day in a basin or inverse

    Parameters
    ----------
    ds: xarray dataset
        _description_
    area:
        area of each basin

    Returns
    -------
    _type_
        _description_
    """
    # use pint to convert unit
    if not inverse:
        target_unit = "mm/d"
        q = ds.pint.quantify()
        a = area.pint.quantify()
        r = q[list(q.keys())[0]] / a[list(a.keys())[0]]
        result = r.pint.to(target_unit).to_dataset(name=list(q.keys())[0])
    else:
        target_unit = "m^3/s"
        r = ds.pint.quantify()
        a = area.pint.quantify()
        q = r[list(r.keys())[0]] * a[list(a.keys())[0]]
        # q = q.pint.quantify()
        result = q.pint.to(target_unit).to_dataset(name=list(r.keys())[0])
    # dequantify to get normal xr_dataset
    return result.pint.dequantify()
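
A minimal sketch of the forward conversion, assuming pint-xarray is installed (it provides the .pint accessor used internally); the variable and basin names here are hypothetical:

import xarray as xr

from torchhydro.datasets.data_utils import unify_streamflow_unit

# hypothetical streamflow (m^3/s) and basin area (km^2) with units stored as attrs
q = xr.Dataset({"streamflow": ("basin", [10.0, 20.0])}, coords={"basin": ["A", "B"]})
q["streamflow"].attrs["units"] = "m^3/s"
area = xr.Dataset({"area": ("basin", [500.0, 800.0])}, coords={"basin": ["A", "B"]})
area["area"].attrs["units"] = "km^2"

runoff_depth = unify_streamflow_unit(q, area)  # streamflow expressed as depth in mm/d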

warn_if_nan(dataarray, max_display=5, nan_mode='any', data_name='')

Issue a warning if the dataarray contains any NaN values and display their locations.

Parameters

!!! dataarray "xr.DataArray" Input dataarray to check for NaN values. !!! max_display "int" Maximum number of NaN locations to display in the warning. !!! nan_mode "str" Mode of NaN checking: 'any' means if any NaNs exist return True, if all values are NaNs raise ValueError 'all' means if all values are NaNs return True !!! data_name "str" Name of the dataarray to be displayed in the warning message.

Source code in torchhydro/datasets/data_utils.py
def warn_if_nan(dataarray, max_display=5, nan_mode="any", data_name=""):
    """
    Issue a warning if the dataarray contains any NaN values and display their locations.

    Parameters
    -----------
    dataarray: xr.DataArray
        Input dataarray to check for NaN values.
    max_display: int
        Maximum number of NaN locations to display in the warning.
    nan_mode: str
        Mode of NaN checking:
        'any' means if any NaNs exist return True, if all values are NaNs raise ValueError
        'all' means if all values are NaNs return True
    data_name: str
        Name of the dataarray to be displayed in the warning message.
    """
    if dataarray is None:
        raise ValueError("The dataarray is None!")
    if nan_mode not in ["any", "all"]:
        raise ValueError("nan_mode must be 'any' or 'all'")

    if np.all(np.isnan(dataarray.values)):
        if nan_mode == "any":
            raise ValueError("The dataarray contains only NaN values!")
        else:
            return True

    nan_indices = np.argwhere(np.isnan(dataarray.values))
    total_nans = len(nan_indices)

    if total_nans <= 0:
        return False
    message = f"The {data_name} dataarray contains {total_nans} NaN values!"

    # Displaying only the first few NaN locations if there are too many
    display_indices = nan_indices[:max_display].tolist()
    message += (
        f" Here are the indices of the first {max_display} NaNs: {display_indices}..."
        if total_nans > max_display
        else f" Here are the indices of the NaNs: {display_indices}"
    )
    warnings.warn(message)

    return True
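
A hypothetical usage:

import numpy as np
import xarray as xr

from torchhydro.datasets.data_utils import warn_if_nan

da = xr.DataArray(np.array([1.0, np.nan, 3.0]), dims="time")
warn_if_nan(da, data_name="streamflow")  # warns about 1 NaN and returns True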

wrap_t_s_dict(data_cfgs, is_tra_val_te)

Basins and periods

Parameters

data_cfgs configs for reading from data source is_tra_val_te train, valid or test

Returns

OrderedDict OrderedDict(sites_id=basins_id, t_final_range=t_range_list)

Source code in torchhydro/datasets/data_utils.py
def wrap_t_s_dict(data_cfgs: dict, is_tra_val_te: str) -> OrderedDict:
    """
    Basins and periods

    Parameters
    ----------
    data_cfgs
        configs for reading from data source
    is_tra_val_te
        train, valid or test

    Returns
    -------
    OrderedDict
        OrderedDict(sites_id=basins_id, t_final_range=t_range_list)
    """
    basins_id = data_cfgs["object_ids"]
    if type(basins_id) is str and basins_id == "ALL":
        raise ValueError("Please specify the basins_id in configs!")
    if any(x >= y for x, y in zip(basins_id, basins_id[1:])):
        # raise a warning if the basins_id is not sorted
        warnings.warn("The basins_id is not sorted!")
    if f"t_range_{is_tra_val_te}" in data_cfgs:
        t_range_list = data_cfgs[f"t_range_{is_tra_val_te}"]
    else:
        raise KeyError(f"Error! The mode {is_tra_val_te} was not found. Please add it.")
    return OrderedDict(sites_id=basins_id, t_final_range=t_range_list)
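
A hypothetical call with the keys wrap_t_s_dict expects (the basin ids and dates are made up):

from torchhydro.datasets.data_utils import wrap_t_s_dict

data_cfgs = {
    "object_ids": ["01013500", "01022500"],
    "t_range_train": ["2000-10-01", "2010-10-01"],
}
wrap_t_s_dict(data_cfgs, "train")
# OrderedDict([('sites_id', ['01013500', '01022500']), ('t_final_range', ['2000-10-01', '2010-10-01'])])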

sampler

Author: Wenyu Ouyang Date: 2023-09-25 08:21:27 LastEditTime: 2025-07-13 15:47:53 LastEditors: Wenyu Ouyang Description: Some sampling class or functions FilePath: orchhydro orchhydro\datasets\sampler.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

BasinBatchSampler (Sampler)

A custom sampler for hydrological modeling that iterates over a dataset in a way tailored for batches of hydrological data. It ensures that each batch contains data from a single randomly selected 'basin' out of several basins, with batches constructed to respect the specified batch size and the unique characteristics of hydrological datasets. TODO: made by Xinzhuo Wu; may need more testing

Parameters

dataset : BaseDataset The dataset to sample from, expected to have a data_cfgs attribute. num_samples : Optional[int], default=None The total number of samples to draw (optional). generator : Optional[torch.Generator] A PyTorch Generator object for random number generation (optional).

The sampler divides the dataset by the number of basins, then iterates through each basin's range in shuffled order, ensuring non-overlapping, basin-specific batches suitable for models that predict hydrological outcomes.

Source code in torchhydro/datasets/sampler.py
class BasinBatchSampler(Sampler[int]):
    """
    A custom sampler for hydrological modeling that iterates over a dataset in
    a way tailored for batches of hydrological data. It ensures that each batch
    contains data from a single randomly selected 'basin' out of several basins,
    with batches constructed to respect the specified batch size and the unique
    characteristics of hydrological datasets.
    TODO: made by Xinzhuo Wu; may need more testing

    Parameters
    ----------
    dataset : BaseDataset
        The dataset to sample from, expected to have a `data_cfgs` attribute.
    num_samples : Optional[int], default=None
        The total number of samples to draw (optional).
    generator : Optional[torch.Generator]
        A PyTorch Generator object for random number generation (optional).

    The sampler divides the dataset by the number of basins, then iterates through
    each basin's range in shuffled order, ensuring non-overlapping, basin-specific
    batches suitable for models that predict hydrological outcomes.
    """

    def __init__(
        self,
        dataset,
        num_samples: Optional[int] = None,
        generator=None,
    ) -> None:
        self.dataset = dataset
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
            )

    @property
    def num_samples(self) -> int:
        return len(self.dataset)

    def __iter__(self) -> Iterator[int]:
        n = self.dataset.training_cfgs["batch_size"]
        basin_number = len(self.dataset.data_cfgs["object_ids"])
        basin_range = len(self.dataset) // basin_number
        if n > basin_range:
            raise ValueError(
                f"batch_size should equal or less than basin_range={basin_range} "
            )

        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        # basin_list = torch.randperm(basin_number)
        # for select_basin in basin_list:
        #     x = torch.randperm(basin_range)
        #     for i in range(0, basin_range, n):
        #         yield from (x[i : i + n] + basin_range * select_basin.item()).tolist()
        x = torch.randperm(self.num_samples)
        for i in range(0, self.num_samples, n):
            yield from (x[i : i + n]).tolist()

    def __len__(self) -> int:
        return self.num_samples
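
A hypothetical usage with a PyTorch DataLoader; train_dataset stands for an already-constructed BaseDataset whose training_cfgs define "batch_size":

from torch.utils.data import DataLoader

from torchhydro.datasets.sampler import BasinBatchSampler

sampler = BasinBatchSampler(train_dataset)
loader = DataLoader(
    train_dataset,
    batch_size=train_dataset.training_cfgs["batch_size"],
    sampler=sampler,
)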

KuaiSampler (RandomSampler)

Source code in torchhydro/datasets/sampler.py
class KuaiSampler(RandomSampler):
    def __init__(
        self,
        dataset,
        batch_size,
        warmup_length,
        rho_horizon,
        ngrid,
        nt,
    ):
        """a sampler from Kuai Fang's paper: https://doi.org/10.1002/2017GL075619
           He used a random pick-up that we don't need to iterate all samples.
           Then, we can train model more quickly

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            just a object of dataset class inherited from torch.utils.data.Dataset
        batch_size : int
            we need batch_size to calculate the number of samples in an epoch
        warmup_length : int
            warmup length, typically for physical hydrological models
        rho_horizon : int
            sequence length of a mini-batch, for encoder-decoder models, rho+horizon, for decoder-only models, horizon
        ngrid : int
            number of basins
        nt : int
            number of all periods
        """
        while batch_size * rho_horizon >= ngrid * nt:
            # try to use a smaller batch_size to make the model runnable
            batch_size = int(batch_size / 10)
        batch_size = max(batch_size, 1)
        # 99% chance that all periods' data are used in an epoch
        n_iter_ep = int(
            np.ceil(
                np.log(0.01)
                / np.log(1 - batch_size * rho_horizon / ngrid / (nt - warmup_length))
            )
        )
        assert n_iter_ep >= 1
        # __len__ means the number of all samples, then, the number of loops in an epoch is __len__()/batch_size = n_iter_ep
        # hence we return n_iter_ep * batch_size
        num_samples = n_iter_ep * batch_size
        super(KuaiSampler, self).__init__(dataset, num_samples=num_samples)

__init__(self, dataset, batch_size, warmup_length, rho_horizon, ngrid, nt) special

A sampler from Kuai Fang's paper: https://doi.org/10.1002/2017GL075619. He used random sampling, so we don't need to iterate over all samples, which makes training faster.

Parameters

dataset : torch.utils.data.Dataset just a object of dataset class inherited from torch.utils.data.Dataset batch_size : int we need batch_size to calculate the number of samples in an epoch warmup_length : int warmup length, typically for physical hydrological models rho_horizon : int sequence length of a mini-batch, for encoder-decoder models, rho+horizon, for decoder-only models, horizon ngrid : int number of basins nt : int number of all periods

Source code in torchhydro/datasets/sampler.py
def __init__(
    self,
    dataset,
    batch_size,
    warmup_length,
    rho_horizon,
    ngrid,
    nt,
):
    """a sampler from Kuai Fang's paper: https://doi.org/10.1002/2017GL075619
       He used a random pick-up that we don't need to iterate all samples.
       Then, we can train model more quickly

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        just a object of dataset class inherited from torch.utils.data.Dataset
    batch_size : int
        we need batch_size to calculate the number of samples in an epoch
    warmup_length : int
        warmup length, typically for physical hydrological models
    rho_horizon : int
        sequence length of a mini-batch, for encoder-decoder models, rho+horizon, for decoder-only models, horizon
    ngrid : int
        number of basins
    nt : int
        number of all periods
    """
    while batch_size * rho_horizon >= ngrid * nt:
        # try to use a smaller batch_size to make the model runnable
        batch_size = int(batch_size / 10)
    batch_size = max(batch_size, 1)
    # 99% chance that all periods' data are used in an epoch
    n_iter_ep = int(
        np.ceil(
            np.log(0.01)
            / np.log(1 - batch_size * rho_horizon / ngrid / (nt - warmup_length))
        )
    )
    assert n_iter_ep >= 1
    # __len__ means the number of all samples, then, the number of loops in an epoch is __len__()/batch_size = n_iter_ep
    # hence we return n_iter_ep * batch_size
    num_samples = n_iter_ep * batch_size
    super(KuaiSampler, self).__init__(dataset, num_samples=num_samples)
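
A worked example of the epoch-length rule above, with hypothetical numbers:

import numpy as np

batch_size, rho_horizon, ngrid, nt, warmup_length = 100, 366, 50, 7300, 0
p = batch_size * rho_horizon / (ngrid * (nt - warmup_length))  # fraction of data touched per iteration (~0.10)
n_iter_ep = int(np.ceil(np.log(0.01) / np.log(1 - p)))  # 44 iterations give ~99% coverage
num_samples = n_iter_ep * batch_size  # 4400 samples per epoch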

fl_sample_basin(dataset)

Sample one basin data as a client from a dataset for federated learning

Parameters

dataset dataset

Returns

dict a lookup_table subset for each user (one user per basin)
Source code in torchhydro/datasets/sampler.py
def fl_sample_basin(dataset: BaseDataset):
    """
    Sample one basin data as a client from a dataset for federated learning

    Parameters
    ----------
    dataset
        dataset

    Returns
    -------
        a dict of lookup_table subsets, one per user (one user per basin)
    """
    lookup_table = dataset.lookup_table
    basins = dataset.basins
    # one basin is one user
    num_users = len(basins)
    # set group for basins
    basin_groups = defaultdict(list)
    for idx, (basin, date) in lookup_table.items():
        basin_groups[basin].append(idx)

    # one user is one basin
    user_basins = defaultdict(list)
    for i, basin in enumerate(basins):
        user_id = i % num_users
        user_basins[user_id].append(basin)

    # a lookup_table subset for each user
    user_lookup_tables = {}
    for user_id, basins in user_basins.items():
        user_lookup_table = {}
        for basin in basins:
            for idx in basin_groups[basin]:
                user_lookup_table[idx] = lookup_table[idx]
        user_lookup_tables[user_id] = user_lookup_table

    return user_lookup_tables
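
A hypothetical usage; train_dataset stands for an already-constructed BaseDataset exposing basins and a lookup_table mapping sample index to (basin, date):

from torchhydro.datasets.sampler import fl_sample_basin

user_tables = fl_sample_basin(train_dataset)
client0_indices = list(user_tables[0].keys())  # sample indices belonging to the first client (basin)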

fl_sample_region(dataset)

Sample one region data as a client from a dataset for federated learning

TODO: not finished

Source code in torchhydro/datasets/sampler.py
def fl_sample_region(dataset: BaseDataset):
    """
    Sample one region data as a client from a dataset for federated learning

    TODO: not finished

    """
    num_users = 10
    num_shards, num_imgs = 200, 250
    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)
    # labels = dataset.train_labels.numpy()
    labels = np.array(dataset.train_labels)

    # sort labels
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs : (rand + 1) * num_imgs]), axis=0
            )
    return dict_users