Models API

Author: Wenyu Ouyang Date: 2023-07-11 17:39:09 LastEditTime: 2023-07-11 20:40:37 LastEditors: Wenyu Ouyang Description: FilePath: \HydroTL\hydrotl\models\__init__.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

ann

Author: Wenyu Ouyang Date: 2021-12-17 18:02:27 LastEditTime: 2025-06-25 14:07:58 LastEditors: Wenyu Ouyang Description: ANN model FilePath: \torchhydro\torchhydro\models\ann.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

Mlp (SimpleAnn)

Source code in torchhydro/models/ann.py
class Mlp(SimpleAnn):
    def __init__(
        self,
        nx: int,
        ny: int,
        hidden_size: Union[int, tuple, list] = None,
        dr: Union[float, tuple, list] = 0.0,
        activation: str = "relu",
    ):
        """
        MLP model inherited from SimpleAnn, using activation + dropout after each layer.
        The final layer also goes through activation+dropout if there's a corresponding
        dropout layer in dropout_list.
        """
        if type(dr) is float:
            if type(hidden_size) in [tuple, list]:
                dr = [dr] * (len(hidden_size) + 1)
            elif hidden_size is not None and hidden_size > 0:
                dr = [dr] * 2
        super(Mlp, self).__init__(
            nx=nx,
            ny=ny,
            hidden_size=hidden_size,
            dr=dr,
            activation=activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with activation followed by dropout for each layer in linear_list.
        The number of dropout layers must match or be exactly one less than
        the number of linear layers.
        """
        # Raise an error if the number of linear layers and dropout layers do not match
        # and do not satisfy "number of linear layers = number of dropout layers + 1"
        if len(self.linear_list) != len(self.dropout_list) and (
            len(self.linear_list) - 1
        ) != len(self.dropout_list):
            raise ValueError(
                "Mlp: linear_list and dropout_list sizes do not match. "
                "They either have the same length or linear_list has exactly one more."
            )

        out = x
        for i, layer in enumerate(self.linear_list):
            out = layer(out)
            out = self.activation(out)
            if i < len(self.dropout_list):
                out = self.dropout_list[i](out)

        return out

__init__(self, nx, ny, hidden_size=None, dr=0.0, activation='relu') special

MLP model inherited from SimpleAnn, applying the activation and then dropout after each layer. The final layer also goes through the activation, and through dropout when a corresponding dropout layer exists in dropout_list.

Source code in torchhydro/models/ann.py
def __init__(
    self,
    nx: int,
    ny: int,
    hidden_size: Union[int, tuple, list] = None,
    dr: Union[float, tuple, list] = 0.0,
    activation: str = "relu",
):
    """
    MLP model inherited from SimpleAnn, using activation + dropout after each layer.
    The final layer also goes through activation+dropout if there's a corresponding
    dropout layer in dropout_list.
    """
    if type(dr) is float:
        if type(hidden_size) in [tuple, list]:
            dr = [dr] * (len(hidden_size) + 1)
        elif hidden_size is not None and hidden_size > 0:
            dr = [dr] * 2
    super(Mlp, self).__init__(
        nx=nx,
        ny=ny,
        hidden_size=hidden_size,
        dr=dr,
        activation=activation,
    )

forward(self, x)

Forward pass with activation followed by dropout for each layer in linear_list. The number of dropout layers must match or be exactly one less than the number of linear layers.

Source code in torchhydro/models/ann.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass with activation followed by dropout for each layer in linear_list.
    The number of dropout layers must match or be exactly one less than
    the number of linear layers.
    """
    # Raise an error if the number of linear layers and dropout layers do not match
    # and do not satisfy "number of linear layers = number of dropout layers + 1"
    if len(self.linear_list) != len(self.dropout_list) and (
        len(self.linear_list) - 1
    ) != len(self.dropout_list):
        raise ValueError(
            "Mlp: linear_list and dropout_list sizes do not match. "
            "They either have the same length or linear_list has exactly one more."
        )

    out = x
    for i, layer in enumerate(self.linear_list):
        out = layer(out)
        out = self.activation(out)
        if i < len(self.dropout_list):
            out = self.dropout_list[i](out)

    return out
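A minimal usage sketch (not from the torchhydro source; it assumes Mlp can be imported as torchhydro.models.ann.Mlp, matching the source path above). A scalar dr is broadcast to one dropout layer per linear layer, so every layer, including the final one, goes through activation and dropout.

import torch
from torchhydro.models.ann import Mlp

# two hidden layers (16 and 8 neurons); the scalar dr=0.1 is broadcast to
# len(hidden_size) + 1 = 3 dropout layers, one per linear layer
model = Mlp(nx=5, ny=1, hidden_size=[16, 8], dr=0.1, activation="relu")
x = torch.randn(32, 5)  # a batch of 32 samples with 5 input features
y = model(x)            # shape: (32, 1)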

SimpleAnn (Module)

Source code in torchhydro/models/ann.py
class SimpleAnn(nn.Module):
    def __init__(
        self,
        nx: int,
        ny: int,
        hidden_size: Union[int, tuple, list] = None,
        dr: Union[float, tuple, list] = 0.0,
        activation: str = "relu",
    ):
        """
        A simple multi-layer NN model with final linear layer

        Parameters
        ----------
        nx
            number of input neurons
        ny
            number of output neurons
        hidden_size
            a list/tuple which contains number of neurons in each hidden layer;
            if int, only one hidden layer except for hidden_size=0
        dr
            dropout rate of layers, default is 0.0 which means no dropout;
            here we set number of dropout layers to (number of nn layers - 1)
        activation
            activation function for hidden layers, default is "relu"
        """
        super(SimpleAnn, self).__init__()
        linear_list = nn.ModuleList()
        dropout_list = nn.ModuleList()
        if (
            hidden_size is None
            or (type(hidden_size) is int and hidden_size == 0)
            or (type(hidden_size) in [tuple, list] and len(hidden_size) < 1)
        ):
            linear_list.add_module("linear1", nn.Linear(nx, ny))
            if type(dr) in [tuple, list]:
                dr = dr[0]
            if dr > 0.0:
                dropout_list.append(nn.Dropout(dr))
        elif type(hidden_size) is int:
            linear_list.add_module("linear1", nn.Linear(nx, hidden_size))
            linear_list.add_module("linear2", nn.Linear(hidden_size, ny))
            if type(dr) in [tuple, list]:
                # dropout layer do not have additional weights, so we do not name them here
                dropout_list.append(nn.Dropout(dr[0]))
                if len(dr) > 1:
                    dropout_list.append(nn.Dropout(dr[1]))
            else:
                # dr must be a float
                dropout_list.append(nn.Dropout(dr))
        else:
            linear_list.add_module("linear1", nn.Linear(nx, hidden_size[0]))
            if type(dr) is float:
                dr = [dr] * len(hidden_size)
            elif len(dr) > len(hidden_size) + 1:
                raise ArithmeticError(
                    "At most, we set dropout layer for each nn layer, please check the number of dropout layers"
                )
            # dropout_list.add_module("dropout1", torch.nn.Dropout(dr[0]))
            dropout_list.append(nn.Dropout(dr[0]))
            for i in range(len(hidden_size) - 1):
                linear_list.add_module(
                    "linear%d" % (i + 1 + 1),
                    nn.Linear(hidden_size[i], hidden_size[i + 1]),
                )
                dropout_list.append(
                    nn.Dropout(dr[i + 1]),
                )
            linear_list.add_module(
                "linear%d" % (len(hidden_size) + 1),
                nn.Linear(hidden_size[-1], ny),
            )
            if len(dr) == len(linear_list):
                # if final linear also need a dr
                dropout_list.append(nn.Dropout(dr[-1]))
        self.linear_list = linear_list
        self.dropout_list = dropout_list
        self.activation = self._get_activation(activation)

    def forward(self, x):
        for i, model in enumerate(self.linear_list):
            if i == 0:
                if len(self.linear_list) == 1:
                    # final layer must be a linear layer
                    return (
                        model(x)
                        if len(self.dropout_list) < len(self.linear_list)
                        else self.dropout_list[i](model(x))
                    )
                else:
                    out = self.activation(self.dropout_list[i](model(x)))
            elif i == len(self.linear_list) - 1:
                # in final layer, no relu again
                return (
                    model(out)
                    if len(self.dropout_list) < len(self.linear_list)
                    else self.dropout_list[i](model(out))
                )
            else:
                out = self.activation(self.dropout_list[i](model(out)))

    def _get_activation(self, name: str) -> nn.Module:
        """a function to get activation function by name, reference from:
        https://github.com/neuralhydrology/neuralhydrology/blob/master/neuralhydrology/modelzoo/fc.py

        Parameters
        ----------
        name : str
            _description_

        Returns
        -------
        nn.Module
            _description_

        Raises
        ------
        NotImplementedError
            _description_
        """
        if name.lower() == "tanh":
            activation = nn.Tanh()
        elif name.lower() == "sigmoid":
            activation = nn.Sigmoid()
        elif name.lower() == "relu":
            activation = nn.ReLU()
        elif name.lower() == "linear":
            activation = nn.Identity()
        else:
            raise NotImplementedError(
                f"{name} currently not supported as activation in this class"
            )
        return activation

__init__(self, nx, ny, hidden_size=None, dr=0.0, activation='relu') special

A simple multi-layer NN model with final linear layer

Parameters

nx
    number of input neurons
ny
    number of output neurons
hidden_size
    a list/tuple containing the number of neurons in each hidden layer;
    if an int, a single hidden layer is used (hidden_size=0 means no hidden layer)
dr
    dropout rate of the layers; the default 0.0 means no dropout.
    The number of dropout layers is set to (number of nn layers - 1)
activation
    activation function for the hidden layers, default is "relu"

Source code in torchhydro/models/ann.py
def __init__(
    self,
    nx: int,
    ny: int,
    hidden_size: Union[int, tuple, list] = None,
    dr: Union[float, tuple, list] = 0.0,
    activation: str = "relu",
):
    """
    A simple multi-layer NN model with final linear layer

    Parameters
    ----------
    nx
        number of input neurons
    ny
        number of output neurons
    hidden_size
        a list/tuple which contains number of neurons in each hidden layer;
        if int, only one hidden layer except for hidden_size=0
    dr
        dropout rate of layers, default is 0.0 which means no dropout;
        here we set number of dropout layers to (number of nn layers - 1)
    activation
        activation function for hidden layers, default is "relu"
    """
    super(SimpleAnn, self).__init__()
    linear_list = nn.ModuleList()
    dropout_list = nn.ModuleList()
    if (
        hidden_size is None
        or (type(hidden_size) is int and hidden_size == 0)
        or (type(hidden_size) in [tuple, list] and len(hidden_size) < 1)
    ):
        linear_list.add_module("linear1", nn.Linear(nx, ny))
        if type(dr) in [tuple, list]:
            dr = dr[0]
        if dr > 0.0:
            dropout_list.append(nn.Dropout(dr))
    elif type(hidden_size) is int:
        linear_list.add_module("linear1", nn.Linear(nx, hidden_size))
        linear_list.add_module("linear2", nn.Linear(hidden_size, ny))
        if type(dr) in [tuple, list]:
            # dropout layer do not have additional weights, so we do not name them here
            dropout_list.append(nn.Dropout(dr[0]))
            if len(dr) > 1:
                dropout_list.append(nn.Dropout(dr[1]))
        else:
            # dr must be a float
            dropout_list.append(nn.Dropout(dr))
    else:
        linear_list.add_module("linear1", nn.Linear(nx, hidden_size[0]))
        if type(dr) is float:
            dr = [dr] * len(hidden_size)
        elif len(dr) > len(hidden_size) + 1:
            raise ArithmeticError(
                "At most, we set dropout layer for each nn layer, please check the number of dropout layers"
            )
        # dropout_list.add_module("dropout1", torch.nn.Dropout(dr[0]))
        dropout_list.append(nn.Dropout(dr[0]))
        for i in range(len(hidden_size) - 1):
            linear_list.add_module(
                "linear%d" % (i + 1 + 1),
                nn.Linear(hidden_size[i], hidden_size[i + 1]),
            )
            dropout_list.append(
                nn.Dropout(dr[i + 1]),
            )
        linear_list.add_module(
            "linear%d" % (len(hidden_size) + 1),
            nn.Linear(hidden_size[-1], ny),
        )
        if len(dr) == len(linear_list):
            # if final linear also need a dr
            dropout_list.append(nn.Dropout(dr[-1]))
    self.linear_list = linear_list
    self.dropout_list = dropout_list
    self.activation = self._get_activation(activation)

forward(self, x)

Forward pass of the network.

Each hidden layer applies its linear transform, then dropout, then the activation; the final layer is linear only, followed by dropout when a dropout layer has been configured for it.

Source code in torchhydro/models/ann.py
def forward(self, x):
    for i, model in enumerate(self.linear_list):
        if i == 0:
            if len(self.linear_list) == 1:
                # final layer must be a linear layer
                return (
                    model(x)
                    if len(self.dropout_list) < len(self.linear_list)
                    else self.dropout_list[i](model(x))
                )
            else:
                out = self.activation(self.dropout_list[i](model(x)))
        elif i == len(self.linear_list) - 1:
            # in final layer, no relu again
            return (
                model(out)
                if len(self.dropout_list) < len(self.linear_list)
                else self.dropout_list[i](model(out))
            )
        else:
            out = self.activation(self.dropout_list[i](model(out)))

crits

Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 LastEditTime: 2025-07-13 16:36:07 LastEditors: Wenyu Ouyang Description: Loss functions FilePath: \torchhydro\torchhydro\models\crits.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

DynamicTaskPrior (Module)

Dynamic Task Prioritization

This method is proposed in https://openaccess.thecvf.com/content_ECCV_2018/html/Michelle_Guo_Focus_on_the_ECCV_2018_paper.html. In contrast to UW and other curriculum learning methods, where easy tasks are prioritized above difficult tasks, it shows the importance of prioritizing difficult tasks first, automatically prioritizing the more difficult tasks by adaptively adjusting the mixing weight of each task's loss. Here we choose correlation as the KPI; since the KPI must be in [0, 1], we use (corr + 1) / 2 as the KPI.

Source code in torchhydro/models/crits.py
class DynamicTaskPrior(torch.nn.Module):
    r"""Dynamic Task Prioritization

    This method is proposed in https://openaccess.thecvf.com/content_ECCV_2018/html/Michelle_Guo_Focus_on_the_ECCV_2018_paper.html
    In contrast to UW and other curriculum learning methods, where easy tasks are prioritized above difficult tasks,
    it shows the importance of prioritizing difficult tasks first.
    It automatically prioritizes more difficult tasks by adaptively adjusting the mixing weight of each task's loss.
    Here we choose correlation as the KPI. As the KPI must be in [0, 1], we use (corr + 1) / 2 as the KPI.
    """

    def __init__(
        self,
        loss_funcs: Union[torch.nn.Module, list],
        data_gap: list = None,
        device: list = None,
        limit_part: list = None,
        gamma=2,
        alpha=0.5,
    ):
        """

        Parameters
        ----------
        loss_funcs
        data_gap
        device
        limit_part
        gamma
            the example-level focusing parameter
        alpha
            weight of the newest KPI in the moving average; alpha=1 means only
            the newest KPI value is used (default is 0.5)
        """
        if data_gap is None:
            data_gap = [0, 2]
        if device is None:
            device = [0]
        super(DynamicTaskPrior, self).__init__()
        self.loss_funcs = loss_funcs
        self.data_gap = data_gap
        self.device = get_the_device(device)
        self.limit_part = limit_part
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, output, target, kpi_last=None):
        """
        Parameters
        ----------
        output
            model's prediction
        target
            observation

        kpi_last
            the KPI value of last iteration; each element for an output
            It uses a moving-average KPI as the weighting coefficient: KPI_i = alpha * KPI_i + (1 - alpha) * KPI_{i-1}

        Returns
        -------
        torch.Tensor
            multi-task loss by Dynamic Task Prioritization method
        """
        n_out = target.shape[-1]
        loss = 0
        kpis = torch.zeros(n_out).to(self.device)
        for k in range(n_out):
            if self.limit_part is not None and k in self.limit_part:
                continue
            p0 = output[:, :, k]
            t0 = target[:, :, k]
            mask = t0 == t0
            p = p0[mask]
            t = t0[mask]
            if self.data_gap[k] > 0:
                p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
            if type(self.loss_funcs) is list:
                temp = self.loss_funcs[k](p, t)
            else:
                temp = self.loss_funcs(p, t)
            # kpi must be in [0, 1], as corr's range is [-1, 1], just trans corr to  (corr+1)/2
            kpi = (torch.corrcoef(torch.stack([p, t], 1).T)[0, 1] + 1) / 2
            if self.alpha < 1:
                assert kpi_last is not None
                kpi = kpi * self.alpha + kpi_last[k] * (1 - self.alpha)
                # if we exclude kpi from the backward, it trans to a normal multi-task model
                # kpi = kpi.detach().clone() * self.alpha + kpi_last[k] * (1 - self.alpha)
            kpis[k] = kpi
            # focal loss
            fl = -((1 - kpi) ** self.gamma) * torch.log(kpi)
            loss += torch.sum(fl * temp, -1)
        # if kpi has grad_fn, backward will repeat. It won't work
        return loss, kpis.detach().clone()

__init__(self, loss_funcs, data_gap=None, device=None, limit_part=None, gamma=2, alpha=0.5) special

Parameters

loss_funcs
data_gap
device
limit_part
gamma
    the example-level focusing parameter
alpha
    weight of the newest KPI in the moving average; alpha=1 means only the newest KPI value is used (default is 0.5)

Source code in torchhydro/models/crits.py
def __init__(
    self,
    loss_funcs: Union[torch.nn.Module, list],
    data_gap: list = None,
    device: list = None,
    limit_part: list = None,
    gamma=2,
    alpha=0.5,
):
    """

    Parameters
    ----------
    loss_funcs
    data_gap
    device
    limit_part
    gamma
        the example-level focusing parameter
    alpha
        weight of the newest KPI in the moving average; alpha=1 means only
        the newest KPI value is used (default is 0.5)
    """
    if data_gap is None:
        data_gap = [0, 2]
    if device is None:
        device = [0]
    super(DynamicTaskPrior, self).__init__()
    self.loss_funcs = loss_funcs
    self.data_gap = data_gap
    self.device = get_the_device(device)
    self.limit_part = limit_part
    self.gamma = gamma
    self.alpha = alpha

forward(self, output, target, kpi_last=None)

Parameters

output
    model's prediction
target
    observation
kpi_last
    the KPI value of the last iteration, one element per output.
    A moving-average KPI is used as the weighting coefficient:
    KPI_i = alpha * KPI_i + (1 - alpha) * KPI_{i-1}

Returns

torch.Tensor
    multi-task loss computed by the Dynamic Task Prioritization method

Source code in torchhydro/models/crits.py
def forward(self, output, target, kpi_last=None):
    """
    Parameters
    ----------
    output
        model's prediction
    target
        observation

    kpi_last
        the KPI value of last iteration; each element for an output
        It uses a moving-average KPI as the weighting coefficient: KPI_i = alpha * KPI_i + (1 - alpha) * KPI_{i-1}

    Returns
    -------
    torch.Tensor
        multi-task loss by Dynamic Task Prioritization method
    """
    n_out = target.shape[-1]
    loss = 0
    kpis = torch.zeros(n_out).to(self.device)
    for k in range(n_out):
        if self.limit_part is not None and k in self.limit_part:
            continue
        p0 = output[:, :, k]
        t0 = target[:, :, k]
        mask = t0 == t0
        p = p0[mask]
        t = t0[mask]
        if self.data_gap[k] > 0:
            p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
        if type(self.loss_funcs) is list:
            temp = self.loss_funcs[k](p, t)
        else:
            temp = self.loss_funcs(p, t)
        # kpi must be in [0, 1], as corr's range is [-1, 1], just trans corr to  (corr+1)/2
        kpi = (torch.corrcoef(torch.stack([p, t], 1).T)[0, 1] + 1) / 2
        if self.alpha < 1:
            assert kpi_last is not None
            kpi = kpi * self.alpha + kpi_last[k] * (1 - self.alpha)
            # if we exclude kpi from the backward, it trans to a normal multi-task model
            # kpi = kpi.detach().clone() * self.alpha + kpi_last[k] * (1 - self.alpha)
        kpis[k] = kpi
        # focal loss
        fl = -((1 - kpi) ** self.gamma) * torch.log(kpi)
        loss += torch.sum(fl * temp, -1)
    # if kpi has grad_fn, backward will repeat. It won't work
    return loss, kpis.detach().clone()
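An illustrative sketch of the focal weighting used above (plain PyTorch, not torchhydro code): a task with a low correlation gets a large weight, so the multi-task loss focuses on the harder task.

import torch

gamma = 2
for corr in (0.0, 0.5, 0.9):
    kpi = torch.tensor((corr + 1) / 2)  # map corr in [-1, 1] to a KPI in [0, 1]
    focal_weight = -((1 - kpi) ** gamma) * torch.log(kpi)
    print(f"corr={corr:.1f} -> focal weight={focal_weight.item():.4f}")
# corr=0.0 -> 0.1733, corr=0.5 -> 0.0180, corr=0.9 -> 0.0001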

FloodBaseLoss (Module, ABC)

Abstract base class for flood-related loss functions.

All flood-related loss functions should inherit from this class. The labels tensor is expected to have the flood mask as the last column.

Source code in torchhydro/models/crits.py
class FloodBaseLoss(torch.nn.Module, ABC):
    """
    Abstract base class for flood-related loss functions.

    All flood-related loss functions should inherit from this class.
    The labels tensor is expected to have the flood mask as the last column.
    """

    def __init__(self):
        super(FloodBaseLoss, self).__init__()

    @abstractmethod
    def compute_flood_loss(
        self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the flood-aware loss.

        Parameters
        ----------
        predictions : torch.Tensor
            Model predictions [batch_size, seq_len, output_features]
        targets : torch.Tensor
            Target values [batch_size, seq_len, output_features]
        flood_mask : torch.Tensor
            Flood mask [batch_size, seq_len] (1 for flood, 0 for normal)

        Returns
        -------
        torch.Tensor
            Computed loss value
        """
        pass

    def forward(
        self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass that calls the abstract compute_flood_loss method.

        Parameters
        ----------
        predictions : torch.Tensor
            Model predictions [batch_size, seq_len, output_features]
        targets : torch.Tensor
            Target values [batch_size, seq_len, output_features]
        flood_mask : torch.Tensor
            Flood mask [batch_size, seq_len] (1 for flood, 0 for normal)

        Returns
        -------
        torch.Tensor
            Computed loss value
        """
        return self.compute_flood_loss(predictions, targets, flood_mask)

compute_flood_loss(self, predictions, targets, flood_mask)

Compute the flood-aware loss.

Parameters

predictions : torch.Tensor
    Model predictions [batch_size, seq_len, output_features]
targets : torch.Tensor
    Target values [batch_size, seq_len, output_features]
flood_mask : torch.Tensor
    Flood mask [batch_size, seq_len] (1 for flood, 0 for normal)

Returns

torch.Tensor
    Computed loss value

Source code in torchhydro/models/crits.py
@abstractmethod
def compute_flood_loss(
    self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the flood-aware loss.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions [batch_size, seq_len, output_features]
    targets : torch.Tensor
        Target values [batch_size, seq_len, output_features]
    flood_mask : torch.Tensor
        Flood mask [batch_size, seq_len] (1 for flood, 0 for normal)

    Returns
    -------
    torch.Tensor
        Computed loss value
    """
    pass

forward(self, predictions, targets, flood_mask)

Forward pass that calls the abstract compute_flood_loss method.

Parameters

predictions : torch.Tensor
    Model predictions [batch_size, seq_len, output_features]
targets : torch.Tensor
    Target values [batch_size, seq_len, output_features]
flood_mask : torch.Tensor
    Flood mask [batch_size, seq_len] (1 for flood, 0 for normal)

Returns

torch.Tensor
    Computed loss value

Source code in torchhydro/models/crits.py
def forward(
    self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass that calls the abstract compute_flood_loss method.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions [batch_size, seq_len, output_features]
    targets : torch.Tensor
        Target values [batch_size, seq_len, output_features]
    flood_mask : torch.Tensor
        Flood mask [batch_size, seq_len] (1 for flood, 0 for normal)

    Returns
    -------
    torch.Tensor
        Computed loss value
    """
    return self.compute_flood_loss(predictions, targets, flood_mask)
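A minimal illustrative subclass (hypothetical, not part of torchhydro; the import path follows the source path above) showing the contract FloodBaseLoss expects: implement compute_flood_loss and call the module like any other loss.

import torch
from torchhydro.models.crits import FloodBaseLoss

class FloodOnlyMSE(FloodBaseLoss):
    """MSE computed only on time steps where the flood mask is 1."""

    def compute_flood_loss(self, predictions, targets, flood_mask):
        mask = flood_mask.to(torch.bool)          # [batch_size, seq_len]
        diff = predictions[mask] - targets[mask]  # keep flood time steps only
        return torch.mean(diff ** 2)

loss_fn = FloodOnlyMSE()
preds = torch.randn(4, 10, 1)
obs = torch.randn(4, 10, 1)
flood_mask = torch.zeros(4, 10)
flood_mask[:, 5:] = 1.0                           # second half of each sequence is flood
print(loss_fn(preds, obs, flood_mask))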

FloodLoss (FloodBaseLoss)

Source code in torchhydro/models/crits.py
class FloodLoss(FloodBaseLoss):
    def __init__(
        self,
        loss_func: Union[torch.nn.Module, str] = "MSELoss",
        flood_weight: float = 2.0,
        non_flood_weight: float = 1.0,
        flood_strategy: str = "weight",
        flood_focus_factor: float = 2.0,
        device: list = None,
        **kwargs,
    ):
        """
        General flood-aware loss function with configurable base loss and strategy.

        Parameters
        ----------
        loss_func : Union[torch.nn.Module, str]
            Base loss function to use. Can be a PyTorch loss function or string name.
            Supported strings: "MSELoss", "MAELoss", "RMSELoss", "L1Loss"
        flood_weight : float
            Weight multiplier for flood events when using "weight" strategy, default is 2.0
        non_flood_weight : float
            Weight multiplier for non-flood events when using "weight" strategy, default is 1.0
        flood_strategy : str
            Strategy for handling flood events:
            - "weight": Apply higher weights to flood events
            - "focal": Use focal loss approach based on flood event frequency
        flood_focus_factor : float
            Factor for focal loss when using "focal" strategy, default is 2.0
        device : list
            Device configuration, default is None (auto-detect)
        """
        super(FloodLoss, self).__init__()
        self.flood_weight = flood_weight
        self.non_flood_weight = non_flood_weight
        self.flood_strategy = flood_strategy
        self.flood_focus_factor = flood_focus_factor
        self.device = get_the_device(device if device is not None else [0])

        # Initialize base loss function
        if isinstance(loss_func, str):
            loss_dict = {
                # NOTE: reduction="none" is important, otherwise the loss will be reduced to a scalar
                "MSELoss": torch.nn.MSELoss(reduction="none"),
                "MAELoss": torch.nn.L1Loss(reduction="none"),
                "L1Loss": torch.nn.L1Loss(reduction="none"),
                "RMSELoss": RMSELoss(),
                "HybridLoss": HybridLoss(
                    kwargs.get("mae_weight", 0.5), reduction="none"
                ),
            }
            if loss_func in loss_dict:
                self.base_loss_func = loss_dict[loss_func]
            else:
                raise ValueError(f"Unsupported loss function string: {loss_func}")
        else:
            self.base_loss_func = loss_func

    def compute_flood_loss(
        self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute flood-aware loss using the specified strategy.

        Parameters
        ----------
        predictions : torch.Tensor
            Model predictions [batch_size, seq_len, output_features]
        targets : torch.Tensor
            Target values [batch_size, seq_len, output_features]
        flood_mask : torch.Tensor
            Flood mask [batch_size, seq_len, 1] (1 for flood, 0 for normal)

        Returns
        -------
        torch.Tensor
            Computed loss value
        """
        # Ensure flood_mask has correct shape
        if flood_mask.dim() == 3 and flood_mask.shape[-1] == 1:
            flood_mask = flood_mask.squeeze(-1)  # Remove last dimension if it's 1

        if self.flood_strategy == "weight":
            return self._compute_weighted_loss(predictions, targets, flood_mask)
        elif self.flood_strategy == "focal":
            return self._compute_focal_loss(predictions, targets, flood_mask)
        else:
            raise ValueError(f"Unsupported flood strategy: {self.flood_strategy}")

    def _compute_weighted_loss(
        self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute loss with higher weights for flood events."""
        # Compute base loss
        if isinstance(self.base_loss_func, RMSELoss):
            # Special handling for RMSELoss which doesn't have reduction="none"
            base_loss = torch.pow(predictions - targets, 2)
        else:
            mask = ~torch.isnan(targets)
            predictions = predictions[mask]
            targets = targets[mask]
            flood_mask = flood_mask[
                mask.squeeze(-1)
            ]  # Ensure flood_mask matches predictions/targets
            base_loss = self.base_loss_func(predictions, targets)
        # return base_loss

        # Apply flood weights
        weights = torch.full_like(
            flood_mask, self.non_flood_weight, dtype=predictions.dtype
        )
        weights[flood_mask >= 1] = self.flood_weight

        # Apply weights to loss
        if base_loss.dim() == 3:  # [batch, seq, features]
            weighted_loss = base_loss * weights.unsqueeze(-1)
        else:  # [batch, seq]
            weighted_loss = base_loss * weights
        valid_mask = ~torch.isnan(weighted_loss)
        weighted_loss = weighted_loss[valid_mask]

        if isinstance(self.base_loss_func, RMSELoss):
            return torch.sqrt(weighted_loss.mean())
        else:
            return torch.mean(weighted_loss)

    def _compute_focal_loss(
        self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute focal loss that emphasizes flood events."""
        # Compute base loss
        if isinstance(self.base_loss_func, RMSELoss):
            base_loss = torch.pow(predictions - targets, 2)
        else:
            base_loss = self.base_loss_func(predictions, targets)

        # Calculate flood ratio for focal weighting
        flood_ratio = flood_mask.float().mean(dim=1, keepdim=True)  # [batch_size, 1]

        # Focal weight: higher weight when flood events are rare
        focal_weight = (1 - flood_ratio) ** self.flood_focus_factor

        # Separate flood and normal events
        flood_loss = base_loss * flood_mask.unsqueeze(-1).float()
        normal_loss = base_loss * (1 - flood_mask.unsqueeze(-1).float())

        # Apply focal weight to flood events
        weighted_flood_loss = flood_loss * focal_weight.unsqueeze(-1)

        # Combine losses
        total_loss = weighted_flood_loss + normal_loss

        if isinstance(self.base_loss_func, RMSELoss):
            return torch.sqrt(total_loss.mean())
        else:
            return total_loss.mean()

__init__(self, loss_func='MSELoss', flood_weight=2.0, non_flood_weight=1.0, flood_strategy='weight', flood_focus_factor=2.0, device=None, **kwargs) special

General flood-aware loss function with configurable base loss and strategy.

Parameters

loss_func : Union[torch.nn.Module, str]
    Base loss function to use. Can be a PyTorch loss function or a string name.
    Supported strings: "MSELoss", "MAELoss", "L1Loss", "RMSELoss", "HybridLoss"
flood_weight : float
    Weight multiplier for flood events when using the "weight" strategy, default is 2.0
non_flood_weight : float
    Weight multiplier for non-flood events when using the "weight" strategy, default is 1.0
flood_strategy : str
    Strategy for handling flood events:
    - "weight": apply higher weights to flood events
    - "focal": use a focal-loss approach based on flood event frequency
flood_focus_factor : float
    Factor for the focal loss when using the "focal" strategy, default is 2.0
device : list
    Device configuration, default is None (auto-detect)

Source code in torchhydro/models/crits.py
def __init__(
    self,
    loss_func: Union[torch.nn.Module, str] = "MSELoss",
    flood_weight: float = 2.0,
    non_flood_weight: float = 1.0,
    flood_strategy: str = "weight",
    flood_focus_factor: float = 2.0,
    device: list = None,
    **kwargs,
):
    """
    General flood-aware loss function with configurable base loss and strategy.

    Parameters
    ----------
    loss_func : Union[torch.nn.Module, str]
        Base loss function to use. Can be a PyTorch loss function or string name.
        Supported strings: "MSELoss", "MAELoss", "RMSELoss", "L1Loss"
    flood_weight : float
        Weight multiplier for flood events when using "weight" strategy, default is 2.0
    non_flood_weight : float
        Weight multiplier for non-flood events when using "weight" strategy, default is 1.0
    flood_strategy : str
        Strategy for handling flood events:
        - "weight": Apply higher weights to flood events
        - "focal": Use focal loss approach based on flood event frequency
    flood_focus_factor : float
        Factor for focal loss when using "focal" strategy, default is 2.0
    device : list
        Device configuration, default is None (auto-detect)
    """
    super(FloodLoss, self).__init__()
    self.flood_weight = flood_weight
    self.non_flood_weight = non_flood_weight
    self.flood_strategy = flood_strategy
    self.flood_focus_factor = flood_focus_factor
    self.device = get_the_device(device if device is not None else [0])

    # Initialize base loss function
    if isinstance(loss_func, str):
        loss_dict = {
            # NOTE: reduction="none" is important, otherwise the loss will be reduced to a scalar
            "MSELoss": torch.nn.MSELoss(reduction="none"),
            "MAELoss": torch.nn.L1Loss(reduction="none"),
            "L1Loss": torch.nn.L1Loss(reduction="none"),
            "RMSELoss": RMSELoss(),
            "HybridLoss": HybridLoss(
                kwargs.get("mae_weight", 0.5), reduction="none"
            ),
        }
        if loss_func in loss_dict:
            self.base_loss_func = loss_dict[loss_func]
        else:
            raise ValueError(f"Unsupported loss function string: {loss_func}")
    else:
        self.base_loss_func = loss_func

compute_flood_loss(self, predictions, targets, flood_mask)

Compute flood-aware loss using the specified strategy.

Parameters

predictions : torch.Tensor
    Model predictions [batch_size, seq_len, output_features]
targets : torch.Tensor
    Target values [batch_size, seq_len, output_features]
flood_mask : torch.Tensor
    Flood mask [batch_size, seq_len, 1] (1 for flood, 0 for normal)

Returns

torch.Tensor
    Computed loss value

Source code in torchhydro/models/crits.py
def compute_flood_loss(
    self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute flood-aware loss using the specified strategy.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions [batch_size, seq_len, output_features]
    targets : torch.Tensor
        Target values [batch_size, seq_len, output_features]
    flood_mask : torch.Tensor
        Flood mask [batch_size, seq_len, 1] (1 for flood, 0 for normal)

    Returns
    -------
    torch.Tensor
        Computed loss value
    """
    # Ensure flood_mask has correct shape
    if flood_mask.dim() == 3 and flood_mask.shape[-1] == 1:
        flood_mask = flood_mask.squeeze(-1)  # Remove last dimension if it's 1

    if self.flood_strategy == "weight":
        return self._compute_weighted_loss(predictions, targets, flood_mask)
    elif self.flood_strategy == "focal":
        return self._compute_focal_loss(predictions, targets, flood_mask)
    else:
        raise ValueError(f"Unsupported flood strategy: {self.flood_strategy}")

GaussianLoss (Module)

Source code in torchhydro/models/crits.py
class GaussianLoss(torch.nn.Module):
    def __init__(self, mu=0, sigma=0):
        """Compute the negative log likelihood of Gaussian Distribution
        From https://arxiv.org/abs/1907.00235
        """
        super(GaussianLoss, self).__init__()
        self.mu = mu
        self.sigma = sigma

    def forward(self, x: torch.Tensor):
        loss = -tdist.Normal(self.mu, self.sigma).log_prob(x)
        return torch.sum(loss) / (loss.size(0) * loss.size(1))

__init__(self, mu=0, sigma=0) special

Compute the negative log-likelihood of a Gaussian distribution. From https://arxiv.org/abs/1907.00235

Source code in torchhydro/models/crits.py
def __init__(self, mu=0, sigma=0):
    """Compute the negative log likelihood of Gaussian Distribution
    From https://arxiv.org/abs/1907.00235
    """
    super(GaussianLoss, self).__init__()
    self.mu = mu
    self.sigma = sigma

forward(self, x)

Compute the negative log-likelihood of x under Normal(mu, sigma), summed and normalized by the product of the first two dimension sizes of the loss tensor.

Source code in torchhydro/models/crits.py
def forward(self, x: torch.Tensor):
    loss = -tdist.Normal(self.mu, self.sigma).log_prob(x)
    return torch.sum(loss) / (loss.size(0) * loss.size(1))
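A small usage sketch (assuming the import path above): mu and sigma are typically tensors of distribution parameters predicted by a probabilistic model, and the loss is the mean negative log-likelihood of the observations under Normal(mu, sigma).

import torch
from torchhydro.models.crits import GaussianLoss

mu = torch.zeros(4, 10)     # predicted means
sigma = torch.ones(4, 10)   # predicted standard deviations
loss_fn = GaussianLoss(mu=mu, sigma=sigma)
x = torch.randn(4, 10)      # observations
print(loss_fn(x))           # mean negative log-likelihood per element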

HybridFloodloss (FloodBaseLoss)

Source code in torchhydro/models/crits.py
class HybridFloodloss(FloodBaseLoss):
    def __init__(self, mae_weight=0.5):
        """
        Hybrid Flood Loss: PES loss + mae_weight × MAE with flood weighting

        Combines PES loss (MSE × sigmoid(MSE)) with Mean Absolute Error,
        applying flood weighting to the loss.

        The difference from FloodLoss is that this class filter flood events first then calculate loss,
        because Hybrid does sigmoid on MSE, when the non-flood-weight is 0, which means we do not want to
        calculate loss on non-flood events, so we need to filter them out first.

        Parameters
        ----------
        mae_weight : float
            Weight for the MAE component, default is 0.5
        flood_weight : float
            Weight multiplier for flood events, default is 2.0
        """
        super(HybridFloodloss, self).__init__()
        self.mae_weight = mae_weight

    def compute_flood_loss(
        self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute flood-aware loss using the specified strategy.

        Parameters
        ----------
        predictions : torch.Tensor
            Model predictions [batch_size, seq_len, output_features]
        targets : torch.Tensor
            Target values [batch_size, seq_len, output_features]
        flood_mask : torch.Tensor
            Flood mask [batch_size, seq_len, 1] (1 for flood, 0 for normal)

        Returns
        -------
        torch.Tensor
            Computed loss value
        """
        boolean_mask = flood_mask.to(torch.bool)
        predictions = predictions[boolean_mask]
        targets = targets[boolean_mask]

        base_loss_func = HybridLoss(self.mae_weight)
        return base_loss_func(predictions, targets)

__init__(self, mae_weight=0.5) special

Hybrid Flood Loss: PES loss + mae_weight × MAE with flood weighting

Combines PES loss (MSE × sigmoid(MSE)) with Mean Absolute Error, applying flood weighting to the loss.

The difference from FloodLoss is that this class filters the flood events first and then calculates the loss. Because HybridLoss applies a sigmoid to the MSE, a non-flood weight of 0 (meaning we do not want any loss on non-flood events) requires filtering those events out before the loss is computed.

Parameters

mae_weight : float
    Weight for the MAE component, default is 0.5

Source code in torchhydro/models/crits.py
def __init__(self, mae_weight=0.5):
    """
    Hybrid Flood Loss: PES loss + mae_weight × MAE with flood weighting

    Combines PES loss (MSE × sigmoid(MSE)) with Mean Absolute Error,
    applying flood weighting to the loss.

    The difference from FloodLoss is that this class filter flood events first then calculate loss,
    because Hybrid does sigmoid on MSE, when the non-flood-weight is 0, which means we do not want to
    calculate loss on non-flood events, so we need to filter them out first.

    Parameters
    ----------
    mae_weight : float
        Weight for the MAE component, default is 0.5
    flood_weight : float
        Weight multiplier for flood events, default is 2.0
    """
    super(HybridFloodloss, self).__init__()
    self.mae_weight = mae_weight

compute_flood_loss(self, predictions, targets, flood_mask)

Compute flood-aware loss using the specified strategy.

Parameters

predictions : torch.Tensor
    Model predictions [batch_size, seq_len, output_features]
targets : torch.Tensor
    Target values [batch_size, seq_len, output_features]
flood_mask : torch.Tensor
    Flood mask [batch_size, seq_len, 1] (1 for flood, 0 for normal)

Returns

torch.Tensor
    Computed loss value

Source code in torchhydro/models/crits.py
def compute_flood_loss(
    self, predictions: torch.Tensor, targets: torch.Tensor, flood_mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute flood-aware loss using the specified strategy.

    Parameters
    ----------
    predictions : torch.Tensor
        Model predictions [batch_size, seq_len, output_features]
    targets : torch.Tensor
        Target values [batch_size, seq_len, output_features]
    flood_mask : torch.Tensor
        Flood mask [batch_size, seq_len, 1] (1 for flood, 0 for normal)

    Returns
    -------
    torch.Tensor
        Computed loss value
    """
    boolean_mask = flood_mask.to(torch.bool)
    predictions = predictions[boolean_mask]
    targets = targets[boolean_mask]

    base_loss_func = HybridLoss(self.mae_weight)
    return base_loss_func(predictions, targets)

HybridLoss (Module)

Source code in torchhydro/models/crits.py
class HybridLoss(torch.nn.Module):
    def __init__(self, mae_weight: float = 0.5, reduction: str = "mean"):
        """
        Hybrid Loss: PES loss + mae_weight × MAE

        Combines PES loss (MSE × sigmoid(MSE)) with Mean Absolute Error.

        Parameters
        ----------
        mae_weight : float
            Weight for the MAE component, default is 0.5
        reduction : str
            Reduction method for the loss, default is "mean". Can be "mean" or "none".
            If "none", returns the loss without reduction.
        """
        super(HybridLoss, self).__init__()
        self.pes_loss = PESLoss()
        self.mae = MAELoss(reduction=reduction)
        self.mae_weight = mae_weight
        self.reduction = reduction

    def forward(self, output: torch.Tensor, target: torch.Tensor):
        pes = self.pes_loss(output, target)
        mae = self.mae(output, target)
        if self.reduction == "none":
            return pes + self.mae_weight * mae
        elif self.reduction == "mean":
            loss = pes + self.mae_weight * mae
            valid_mask = ~torch.isnan(loss)
            return torch.mean(loss[valid_mask])
        else:
            raise ValueError(
                f"Unsupported reduction method: {self.reduction}. Use 'mean' or 'none'."
            )

__init__(self, mae_weight=0.5, reduction='mean') special

Hybrid Loss: PES loss + mae_weight × MAE

Combines PES loss (MSE × sigmoid(MSE)) with Mean Absolute Error.

Parameters

mae_weight : float
    Weight for the MAE component, default is 0.5
reduction : str
    Reduction method for the loss, default is "mean". Can be "mean" or "none".
    If "none", the loss is returned without reduction.

Source code in torchhydro/models/crits.py
def __init__(self, mae_weight: float = 0.5, reduction: str = "mean"):
    """
    Hybrid Loss: PES loss + mae_weight × MAE

    Combines PES loss (MSE × sigmoid(MSE)) with Mean Absolute Error.

    Parameters
    ----------
    mae_weight : float
        Weight for the MAE component, default is 0.5
    reduction : str
        Reduction method for the loss, default is "mean". Can be "mean" or "none".
        If "none", returns the loss without reduction.
    """
    super(HybridLoss, self).__init__()
    self.pes_loss = PESLoss()
    self.mae = MAELoss(reduction=reduction)
    self.mae_weight = mae_weight
    self.reduction = reduction

forward(self, output, target)

Compute the hybrid loss (PES loss + mae_weight × MAE). With reduction="mean", NaN entries are excluded before averaging; with reduction="none", the unreduced loss is returned.

Source code in torchhydro/models/crits.py
def forward(self, output: torch.Tensor, target: torch.Tensor):
    pes = self.pes_loss(output, target)
    mae = self.mae(output, target)
    if self.reduction == "none":
        return pes + self.mae_weight * mae
    elif self.reduction == "mean":
        loss = pes + self.mae_weight * mae
        valid_mask = ~torch.isnan(loss)
        return torch.mean(loss[valid_mask])
    else:
        raise ValueError(
            f"Unsupported reduction method: {self.reduction}. Use 'mean' or 'none'."
        )
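A minimal usage sketch (assuming the import path above; PESLoss, the MSE × sigmoid(MSE) term, is defined elsewhere in crits.py).

import torch
from torchhydro.models.crits import HybridLoss

loss_fn = HybridLoss(mae_weight=0.5, reduction="mean")
preds = torch.randn(16, 24)
obs = torch.randn(16, 24)
print(loss_fn(preds, obs))  # PES loss + 0.5 * MAE, averaged over valid entries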

MAELoss (Module)

Source code in torchhydro/models/crits.py
class MAELoss(torch.nn.Module):
    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, output: torch.Tensor, target: torch.Tensor):
        # Create a mask to filter out NaN values
        mask = ~torch.isnan(target)

        # Apply the mask to both target and output
        target = target[mask]
        output = output[mask]

        # Calculate MAE for the non-NaN values
        if self.reduction == "mean":  # Return the mean MAE
            return torch.mean(torch.abs(target - output))
        elif self.reduction == "none":
            return torch.abs(target - output)
        else:
            raise ValueError(
                "Reduction must be 'mean' or 'none', got {}".format(self.reduction)
            )

forward(self, output, target)

Compute the mean absolute error over non-NaN target entries. With reduction="none", the per-element absolute errors are returned instead of the mean.

Source code in torchhydro/models/crits.py
def forward(self, output: torch.Tensor, target: torch.Tensor):
    # Create a mask to filter out NaN values
    mask = ~torch.isnan(target)

    # Apply the mask to both target and output
    target = target[mask]
    output = output[mask]

    # Calculate MAE for the non-NaN values
    if self.reduction == "mean":  # Return the mean MAE
        return torch.mean(torch.abs(target - output))
    elif self.reduction == "none":
        return torch.abs(target - output)
    else:
        raise ValueError(
            "Reduction must be 'mean' or 'none', got {}".format(self.reduction)
        )

MAPELoss (Module)

Returns the MAPE using: target -> true y; output -> prediction by the model

Source code in torchhydro/models/crits.py
class MAPELoss(torch.nn.Module):
    """
    Returns MAPE using:
    target -> True y
    output -> Prediction by model
    """

    def __init__(self, variance_penalty=0.0):
        super().__init__()
        self.variance_penalty = variance_penalty

    def forward(self, output: torch.Tensor, target: torch.Tensor):
        if len(output) > 1:
            return torch.mean(
                torch.abs(torch.sub(target, output) / target)
            ) + self.variance_penalty * torch.std(torch.sub(target, output))
        else:
            return torch.mean(torch.abs(torch.sub(target, output) / target))

forward(self, output, target)

Compute the mean absolute percentage error; when more than one sample is given, variance_penalty times the standard deviation of the errors is added.

Source code in torchhydro/models/crits.py
def forward(self, output: torch.Tensor, target: torch.Tensor):
    if len(output) > 1:
        return torch.mean(
            torch.abs(torch.sub(target, output) / target)
        ) + self.variance_penalty * torch.std(torch.sub(target, output))
    else:
        return torch.mean(torch.abs(torch.sub(target, output) / target))
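A minimal usage sketch (assuming the import path above); note that targets close to zero inflate the percentage error.

import torch
from torchhydro.models.crits import MAPELoss

loss_fn = MAPELoss(variance_penalty=0.1)
obs = torch.tensor([10.0, 20.0, 40.0, 80.0])
preds = torch.tensor([12.0, 18.0, 44.0, 72.0])
# mean(|obs - preds| / obs) plus 0.1 * std(obs - preds)
print(loss_fn(preds, obs))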

MASELoss (Module)

Source code in torchhydro/models/crits.py
class MASELoss(torch.nn.Module):
    def __init__(self, baseline_method):
        """
        This implements the MASE loss function (i.e. MAE_MODEL / MAE_NAIVE)
        """
        super(MASELoss, self).__init__()
        self.method_dict = {
            "mean": lambda x, y: torch.mean(x, 1).unsqueeze(1).repeat(1, y[1], 1)
        }
        self.baseline_method = self.method_dict[baseline_method]

    def forward(
        self, target: torch.Tensor, output: torch.Tensor, train_data: torch.Tensor, m=1
    ) -> torch.Tensor:
        # Ensure all tensors have a batch dimension
        if len(train_data.shape) < 3:
            train_data = train_data.unsqueeze(0)
        if m == 1 and len(target.shape) == 1:
            output = output.unsqueeze(0)
            output = output.unsqueeze(2)
            target = target.unsqueeze(0)
            target = target.unsqueeze(2)
        if len(target.shape) == 2:
            output = output.unsqueeze(0)
            target = target.unsqueeze(0)
        result_baseline = self.baseline_method(train_data, output.shape)
        MAE = torch.nn.L1Loss()
        mae2 = MAE(output, target)
        mase4 = MAE(result_baseline, target)
        # Prevent division by zero / loss exploding
        if mase4 < 0.001:
            mase4 = 0.001
        return mae2 / mase4

__init__(self, baseline_method) special

This implements the MASE loss function (i.e. MAE_MODEL / MAE_NAIVE)

Source code in torchhydro/models/crits.py
def __init__(self, baseline_method):
    """
    This implements the MASE loss function (i.e. MAE_MODEL / MAE_NAIVE)
    """
    super(MASELoss, self).__init__()
    self.method_dict = {
        "mean": lambda x, y: torch.mean(x, 1).unsqueeze(1).repeat(1, y[1], 1)
    }
    self.baseline_method = self.method_dict[baseline_method]

forward(self, target, output, train_data, m=1)

Compute the mean absolute scaled error: the model's MAE divided by the MAE of the naive baseline built from train_data. The denominator is floored at 0.001 to prevent division by zero.

Source code in torchhydro/models/crits.py
def forward(
    self, target: torch.Tensor, output: torch.Tensor, train_data: torch.Tensor, m=1
) -> torch.Tensor:
    # Ensure all tensors have a batch dimension
    if len(train_data.shape) < 3:
        train_data = train_data.unsqueeze(0)
    if m == 1 and len(target.shape) == 1:
        output = output.unsqueeze(0)
        output = output.unsqueeze(2)
        target = target.unsqueeze(0)
        target = target.unsqueeze(2)
    if len(target.shape) == 2:
        output = output.unsqueeze(0)
        target = target.unsqueeze(0)
    result_baseline = self.baseline_method(train_data, output.shape)
    MAE = torch.nn.L1Loss()
    mae2 = MAE(output, target)
    mase4 = MAE(result_baseline, target)
    # Prevent division by zero / loss exploding
    if mase4 < 0.001:
        mase4 = 0.001
    return mae2 / mase4
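A minimal usage sketch (assuming the import path above). Note the argument order of forward: target first, then output, then the training data used to build the naive baseline.

import torch
from torchhydro.models.crits import MASELoss

loss_fn = MASELoss("mean")
target = torch.randn(4, 6, 1)       # [batch, horizon, features]
output = torch.randn(4, 6, 1)       # model forecast
train_data = torch.randn(4, 20, 1)  # history; its mean is repeated over the horizon
print(loss_fn(target, output, train_data))  # MAE(model) / MAE(naive baseline)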

MultiOutLoss (Module)

Source code in torchhydro/models/crits.py
class MultiOutLoss(torch.nn.Module):
    def __init__(
        self,
        loss_funcs: Union[torch.nn.Module, list],
        data_gap: list = None,
        device: list = None,
        limit_part: list = None,
        item_weight: list = None,
    ):
        """
        Loss function for multiple output

        Parameters
        ----------
        loss_funcs
            The loss functions for each output
        data_gap
            It belongs to the feature dim.
            If 1, then the corresponding value is uniformly-spaced with NaN values filling the gap;
            in addition, the first non-nan value means the aggregated value of the following interval,
            for example, in [5, nan, nan, nan], 5 means all four data's sum, although the next 3 values are nan
            hence the calculation is a little different;
            if 2, the first non-nan value means the average value of the following interval,
            for example, in [5, nan, nan, nan], 5 means all four data's mean value;
            default is [0, 2]
        device
            the number of device: -1 -> "cpu" or "cuda:x" (x is 0, 1 or ...)
        limit_part
            when transfer learning, we may ignore some part;
            the default is None, which means no ignorance;
            other choices are list, such as [0], [0, 1] or [1,2,..];
            0 means the first variable;
            tensor is [seq, time, var] or [time, seq, var]
        item_weight
            use different weight for each item's loss;
            for example, the default values [0.5, 0.5] means 0.5 * loss1 + 0.5 * loss2
        """
        if data_gap is None:
            data_gap = [0, 2]
        if device is None:
            device = [0]
        if item_weight is None:
            item_weight = [0.5, 0.5]
        super(MultiOutLoss, self).__init__()
        self.loss_funcs = loss_funcs
        self.data_gap = data_gap
        self.device = get_the_device(device)
        self.limit_part = limit_part
        self.item_weight = item_weight

    def forward(self, output: Tensor, target: Tensor):
        """
        Calculate the sum of losses for different variables

        When there are NaN values in observation, we will perform a "reduce" operation on prediction.
        For example, pred = [0,1,2,3,4], obs=[5, nan, nan, 6, nan]; the "reduce" is sum;
        then, pred_sum = [0+1+2, 3+4], obs_sum=[5,6], loss = loss_func(pred_sum, obs_sum).
        Notice: when "sum", actually final index is not chosen,
        because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan], 6 means sum of three elements.
        Just as the rho is 5, the final one is not chosen


        Parameters
        ----------
        output
            the prediction tensor; 3-dims are time sequence, batch and feature, respectively
        target
            the observation tensor

        Returns
        -------
        Tensor
            Whole loss
        """
        n_out = target.shape[-1]
        loss = 0
        for k in range(n_out):
            if self.limit_part is not None and k in self.limit_part:
                continue
            p0 = output[:, :, k]
            t0 = target[:, :, k]
            mask = t0 == t0
            p = p0[mask]
            t = t0[mask]
            if self.data_gap[k] > 0:
                p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
            if type(self.loss_funcs) is list:
                temp = self.item_weight[k] * self.loss_funcs[k](p, t)
            else:
                temp = self.item_weight[k] * self.loss_funcs(p, t)
            # sum of all k-th loss
            if torch.isnan(temp).any():
                continue
            loss = loss + temp
        return loss

__init__(self, loss_funcs, data_gap=None, device=None, limit_part=None, item_weight=None) special

Loss function for multiple outputs

Parameters

loss_funcs
    The loss functions for each output
data_gap
    It belongs to the feature dim.
    If 1, the corresponding value is uniformly spaced with NaN values filling the gap;
    in addition, the first non-NaN value is the aggregated value of the following interval,
    for example, in [5, nan, nan, nan], 5 is the sum of all four values although the next 3 values are NaN,
    hence the calculation is a little different;
    if 2, the first non-NaN value is the average value of the following interval,
    for example, in [5, nan, nan, nan], 5 is the mean of all four values;
    default is [0, 2]
device
    the device number: -1 -> "cpu", otherwise "cuda:x" (x is 0, 1, ...)
limit_part
    when transfer learning, we may ignore some outputs;
    the default is None, which means nothing is ignored;
    other choices are lists such as [0], [0, 1] or [1, 2, ...];
    0 means the first variable;
    the tensor is [seq, time, var] or [time, seq, var]
item_weight
    a different weight for each output's loss;
    for example, the default [0.5, 0.5] means 0.5 * loss1 + 0.5 * loss2

Source code in torchhydro/models/crits.py
def __init__(
    self,
    loss_funcs: Union[torch.nn.Module, list],
    data_gap: list = None,
    device: list = None,
    limit_part: list = None,
    item_weight: list = None,
):
    """
    Loss function for multiple output

    Parameters
    ----------
    loss_funcs
        The loss functions for each output
    data_gap
        It belongs to the feature dim.
        If 1, then the corresponding value is uniformly-spaced with NaN values filling the gap;
        in addition, the first non-nan value means the aggregated value of the following interval,
        for example, in [5, nan, nan, nan], 5 means all four data's sum, although the next 3 values are nan
        hence the calculation is a little different;
        if 2, the first non-nan value means the average value of the following interval,
        for example, in [5, nan, nan, nan], 5 means all four data's mean value;
        default is [0, 2]
    device
        the number of device: -1 -> "cpu" or "cuda:x" (x is 0, 1 or ...)
    limit_part
        when transfer learning, we may ignore some part;
        the default is None, which means no ignorance;
        other choices are list, such as [0], [0, 1] or [1,2,..];
        0 means the first variable;
        tensor is [seq, time, var] or [time, seq, var]
    item_weight
        use different weight for each item's loss;
        for example, the default values [0.5, 0.5] means 0.5 * loss1 + 0.5 * loss2
    """
    if data_gap is None:
        data_gap = [0, 2]
    if device is None:
        device = [0]
    if item_weight is None:
        item_weight = [0.5, 0.5]
    super(MultiOutLoss, self).__init__()
    self.loss_funcs = loss_funcs
    self.data_gap = data_gap
    self.device = get_the_device(device)
    self.limit_part = limit_part
    self.item_weight = item_weight

forward(self, output, target)

Calculate the sum of losses for different variables

When there are NaN values in the observation, we perform a "reduce" operation on the prediction. For example, with pred = [0, 1, 2, 3, 4], obs = [5, nan, nan, 6, nan] and "sum" as the reduce, pred_sum = [0+1+2, 3+4], obs_sum = [5, 6], and loss = loss_func(pred_sum, obs_sum). Note that for "sum" the final index is actually not chosen, because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan] (6 is the sum of three elements), so when rho is 5 the final, possibly incomplete interval is dropped.

Parameters

output
    the prediction tensor; its 3 dims are time sequence, batch and feature, respectively
target
    the observation tensor

Returns

Tensor
    the whole loss

Source code in torchhydro/models/crits.py
def forward(self, output: Tensor, target: Tensor):
    """
    Calculate the sum of losses for different variables

    When there are NaN values in observation, we will perform a "reduce" operation on prediction.
    For example, pred = [0,1,2,3,4], obs=[5, nan, nan, 6, nan]; the "reduce" is sum;
    then, pred_sum = [0+1+2, 3+4], obs_sum=[5,6], loss = loss_func(pred_sum, obs_sum).
    Notice: when "sum", actually final index is not chosen,
    because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan], 6 means sum of three elements.
    Just as the rho is 5, the final one is not chosen


    Parameters
    ----------
    output
        the prediction tensor; 3-dims are time sequence, batch and feature, respectively
    target
        the observation tensor

    Returns
    -------
    Tensor
        Whole loss
    """
    n_out = target.shape[-1]
    loss = 0
    for k in range(n_out):
        if self.limit_part is not None and k in self.limit_part:
            continue
        p0 = output[:, :, k]
        t0 = target[:, :, k]
        mask = t0 == t0
        p = p0[mask]
        t = t0[mask]
        if self.data_gap[k] > 0:
            p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
        if type(self.loss_funcs) is list:
            temp = self.item_weight[k] * self.loss_funcs[k](p, t)
        else:
            temp = self.item_weight[k] * self.loss_funcs(p, t)
        # sum of all k-th loss
        if torch.isnan(temp).any():
            continue
        loss = loss + temp
    return loss
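
A minimal usage sketch (the tensor shapes, the RmseLoss sub-loss, and making the second variable an interval-mean series are illustrative assumptions, not requirements of the class):

import torch
from torchhydro.models.crits import MultiOutLoss, RmseLoss

# hypothetical toy data: 30 time steps, 8 basins, 2 target variables (e.g. q and et)
output = torch.rand(30, 8, 2)
target = torch.rand(30, 8, 2)
# make the second variable sparse so the data_gap=2 (mean-reduce) branch is exercised
target[:, :, 1] = float("nan")
target[::3, :, 1] = torch.rand(10, 8)

loss_fn = MultiOutLoss(
    loss_funcs=RmseLoss(),   # one loss function shared by both outputs
    data_gap=[0, 2],         # first variable is continuous, second holds interval means
    device=[-1],             # -1 -> cpu
    item_weight=[0.5, 0.5],
)
loss = loss_fn(output, target)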

MultiOutWaterBalanceLoss (Module)

Source code in torchhydro/models/crits.py
class MultiOutWaterBalanceLoss(torch.nn.Module):
    def __init__(
        self,
        loss_funcs: Union[torch.nn.Module, list],
        data_gap: list = None,
        device: list = None,
        limit_part: list = None,
        item_weight: list = None,
        alpha=0.5,
        beta=0.0,
        wb_loss_func=None,
        means=None,
        stds=None,
    ):
        """
        Loss function for multiple output considering water balance

        loss = alpha * water_balance_loss + (1-alpha) * mtl_loss

        This loss function is only for p, q, et now
        we use the difference between p_obs_mean-q_obs_mean-et_obs_mean and p_pred_mean-q_pred_mean-et_pred_mean as water balance loss
        which is the difference between (q_obs_mean + et_obs_mean) and (q_pred_mean + et_pred_mean)

        Parameters
        ----------
        loss_funcs
            The loss functions for each output
        data_gap
            It belongs to the feature dim.
            If 1, then the corresponding value is uniformly-spaced with NaN values filling the gap;
            in addition, the first non-nan value means the aggregated value of the following interval,
            for example, in [5, nan, nan, nan], 5 means all four data's sum, although the next 3 values are nan
            hence the calculation is a little different;
            if 2, the first non-nan value means the average value of the following interval,
            for example, in [5, nan, nan, nan], 5 means all four data's mean value;
            default is [0, 2]
        device
            the number of device: -1 -> "cpu" or "cuda:x" (x is 0, 1 or ...)
        limit_part
            when transfer learning, we may ignore some part;
            the default is None, which means no ignorance;
            other choices are list, such as [0], [0, 1] or [1,2,..];
            0 means the first variable;
            tensor is [seq, time, var] or [time, seq, var]
        item_weight
            use different weight for each item's loss;
            for example, the default values [0.5, 0.5] means 0.5 * loss1 + 0.5 * loss2
        alpha
            the weight of the water-balance item's loss
        beta
            the weight of real water-balance item's loss, et_mean/p_mean + q_mean/p_mean = 1 can be a loss.
            It is not strictly correct as training batch only have about one year data, but still could be a constraint
        wb_loss_func
            the loss function for water balance item, by default it is None, which means we use function in loss_funcs
        """
        if data_gap is None:
            data_gap = [0, 2]
        if device is None:
            device = [0]
        if item_weight is None:
            item_weight = [0.5, 0.5]
        super(MultiOutWaterBalanceLoss, self).__init__()
        self.loss_funcs = loss_funcs
        self.data_gap = data_gap
        self.device = get_the_device(device)
        self.limit_part = limit_part
        self.item_weight = item_weight
        self.alpha = alpha
        self.beta = beta
        self.wb_loss_func = wb_loss_func
        self.means = means
        self.stds = stds

    def forward(self, output: Tensor, target: Tensor):
        """
        Calculate the sum of losses for different variables and water-balance loss

        When there are NaN values in observation, we will perform a "reduce" operation on prediction.
        For example, pred = [0,1,2,3,4], obs=[5, nan, nan, 6, nan]; the "reduce" is sum;
        then, pred_sum = [0+1+2, 3+4], obs_sum=[5,6], loss = loss_func(pred_sum, obs_sum).
        Notice: when "sum", actually final index is not chosen,
        because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan], 6 means sum of three elements.
        Just as the rho is 5, the final one is not chosen


        Parameters
        ----------
        output
            the prediction tensor; 3-dims are time sequence, batch and feature, respectively
        target
            the observation tensor

        Returns
        -------
        Tensor
            Whole loss
        """
        n_out = target.shape[-1]
        loss = 0
        p_means = []
        t_means = []
        all_means = self.means
        all_stds = self.stds
        for k in range(n_out):
            if self.limit_part is not None and k in self.limit_part:
                continue
            p0 = output[:, :, k]
            t0 = target[:, :, k]
            # for water balance loss
            if all_means is not None:
                # denormalize for q and et
                p1 = p0 * all_stds[k] + all_means[k]
                t1 = t0 * all_stds[k] + all_means[k]
                p2 = (10**p1 - 0.1) ** 2
                t2 = (10**t1 - 0.1) ** 2
                p_mean = torch.nanmean(p2, dim=0)
                t_mean = torch.nanmean(t2, dim=0)
            else:
                p_mean = torch.nanmean(p0, dim=0)
                t_mean = torch.nanmean(t0, dim=0)
            p_means.append(p_mean)
            t_means.append(t_mean)
            # for mtl normal loss
            mask = t0 == t0
            p = p0[mask]
            t = t0[mask]
            if self.data_gap[k] > 0:
                p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
            if type(self.loss_funcs) is list:
                temp = self.item_weight[k] * self.loss_funcs[k](p, t)
            else:
                temp = self.item_weight[k] * self.loss_funcs(p, t)
            # sum of all k-th loss
            loss = loss + temp
        # water balance loss
        p_mean_q_plus_et = torch.sum(torch.stack(p_means), dim=0)
        t_mean_q_plus_et = torch.sum(torch.stack(t_means), dim=0)
        wb_ones = torch.ones_like(t_mean_q_plus_et)
        if self.wb_loss_func is None:
            if type(self.loss_funcs) is list:
                # if wb_loss_func is None, we use the first loss function in loss_funcs
                wb_loss = self.loss_funcs[0](p_mean_q_plus_et, t_mean_q_plus_et)
                wb_1loss = self.loss_funcs[0](p_mean_q_plus_et, wb_ones)
            else:
                wb_loss = self.loss_funcs(p_mean_q_plus_et, t_mean_q_plus_et)
                wb_1loss = self.loss_funcs(p_mean_q_plus_et, wb_ones)
        else:
            wb_loss = self.wb_loss_func(p_mean_q_plus_et, t_mean_q_plus_et)
            wb_1loss = self.wb_loss_func(p_mean_q_plus_et, wb_ones)
        return (
            self.alpha * wb_loss
            + (1 - self.alpha - self.beta) * loss
            + self.beta * wb_1loss
        )

__init__(self, loss_funcs, data_gap=None, device=None, limit_part=None, item_weight=None, alpha=0.5, beta=0.0, wb_loss_func=None, means=None, stds=None) special

Loss function for multiple outputs considering water balance

loss = alpha * water_balance_loss + (1-alpha) * mtl_loss

This loss function currently only supports p, q and et. We use the difference between p_obs_mean - q_obs_mean - et_obs_mean and p_pred_mean - q_pred_mean - et_pred_mean as the water-balance loss, which reduces to the difference between (q_obs_mean + et_obs_mean) and (q_pred_mean + et_pred_mean).

Parameters

loss_funcs
    The loss functions for each output
data_gap
    It belongs to the feature dim.
    If 1, the corresponding value is uniformly spaced with NaN values filling the gap;
    in addition, the first non-NaN value is the aggregated value of the following interval,
    for example, in [5, nan, nan, nan], 5 is the sum of all four values although the next 3 values are NaN,
    hence the calculation is a little different;
    if 2, the first non-NaN value is the average value of the following interval,
    for example, in [5, nan, nan, nan], 5 is the mean of all four values;
    default is [0, 2]
device
    the device number: -1 -> "cpu", otherwise "cuda:x" (x is 0, 1, ...)
limit_part
    when transfer learning, we may ignore some outputs;
    the default is None, which means nothing is ignored;
    other choices are lists such as [0], [0, 1] or [1, 2, ...];
    0 means the first variable;
    the tensor is [seq, time, var] or [time, seq, var]
item_weight
    a different weight for each output's loss;
    for example, the default [0.5, 0.5] means 0.5 * loss1 + 0.5 * loss2
alpha
    the weight of the water-balance item's loss
beta
    the weight of the real water-balance item's loss; et_mean/p_mean + q_mean/p_mean = 1 can be a loss.
    It is not strictly correct because a training batch only has about one year of data, but it can still act as a constraint
wb_loss_func
    the loss function for the water-balance item; by default it is None, meaning we use a function from loss_funcs

Source code in torchhydro/models/crits.py
def __init__(
    self,
    loss_funcs: Union[torch.nn.Module, list],
    data_gap: list = None,
    device: list = None,
    limit_part: list = None,
    item_weight: list = None,
    alpha=0.5,
    beta=0.0,
    wb_loss_func=None,
    means=None,
    stds=None,
):
    """
    Loss function for multiple output considering water balance

    loss = alpha * water_balance_loss + (1-alpha) * mtl_loss

    This loss function is only for p, q, et now
    we use the difference between p_obs_mean-q_obs_mean-et_obs_mean and p_pred_mean-q_pred_mean-et_pred_mean as water balance loss
    which is the difference between (q_obs_mean + et_obs_mean) and (q_pred_mean + et_pred_mean)

    Parameters
    ----------
    loss_funcs
        The loss functions for each output
    data_gap
        It belongs to the feature dim.
        If 1, then the corresponding value is uniformly-spaced with NaN values filling the gap;
        in addition, the first non-nan value means the aggregated value of the following interval,
        for example, in [5, nan, nan, nan], 5 means all four data's sum, although the next 3 values are nan
        hence the calculation is a little different;
        if 2, the first non-nan value means the average value of the following interval,
        for example, in [5, nan, nan, nan], 5 means all four data's mean value;
        default is [0, 2]
    device
        the number of device: -1 -> "cpu" or "cuda:x" (x is 0, 1 or ...)
    limit_part
        when transfer learning, we may ignore some part;
        the default is None, which means no ignorance;
        other choices are list, such as [0], [0, 1] or [1,2,..];
        0 means the first variable;
        tensor is [seq, time, var] or [time, seq, var]
    item_weight
        use different weight for each item's loss;
        for example, the default values [0.5, 0.5] means 0.5 * loss1 + 0.5 * loss2
    alpha
        the weight of the water-balance item's loss
    beta
        the weight of real water-balance item's loss, et_mean/p_mean + q_mean/p_mean = 1 can be a loss.
        It is not strictly correct as training batch only have about one year data, but still could be a constraint
    wb_loss_func
        the loss function for water balance item, by default it is None, which means we use function in loss_funcs
    """
    if data_gap is None:
        data_gap = [0, 2]
    if device is None:
        device = [0]
    if item_weight is None:
        item_weight = [0.5, 0.5]
    super(MultiOutWaterBalanceLoss, self).__init__()
    self.loss_funcs = loss_funcs
    self.data_gap = data_gap
    self.device = get_the_device(device)
    self.limit_part = limit_part
    self.item_weight = item_weight
    self.alpha = alpha
    self.beta = beta
    self.wb_loss_func = wb_loss_func
    self.means = means
    self.stds = stds

forward(self, output, target)

Calculate the sum of losses for different variables and water-balance loss

When there are NaN values in the observation, we perform a "reduce" operation on the prediction. For example, with pred = [0, 1, 2, 3, 4], obs = [5, nan, nan, 6, nan] and "sum" as the reduce, pred_sum = [0+1+2, 3+4], obs_sum = [5, 6], and loss = loss_func(pred_sum, obs_sum). Note that for "sum" the final index is actually not chosen, because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan] (6 is the sum of three elements), so when rho is 5 the final, possibly incomplete interval is dropped.

Parameters

output
    the prediction tensor; its 3 dims are time sequence, batch and feature, respectively
target
    the observation tensor

Returns

Tensor
    the whole loss

Source code in torchhydro/models/crits.py
def forward(self, output: Tensor, target: Tensor):
    """
    Calculate the sum of losses for different variables and water-balance loss

    When there are NaN values in observation, we will perform a "reduce" operation on prediction.
    For example, pred = [0,1,2,3,4], obs=[5, nan, nan, 6, nan]; the "reduce" is sum;
    then, pred_sum = [0+1+2, 3+4], obs_sum=[5,6], loss = loss_func(pred_sum, obs_sum).
    Notice: when "sum", actually final index is not chosen,
    because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan], 6 means sum of three elements.
    Just as the rho is 5, the final one is not chosen


    Parameters
    ----------
    output
        the prediction tensor; 3-dims are time sequence, batch and feature, respectively
    target
        the observation tensor

    Returns
    -------
    Tensor
        Whole loss
    """
    n_out = target.shape[-1]
    loss = 0
    p_means = []
    t_means = []
    all_means = self.means
    all_stds = self.stds
    for k in range(n_out):
        if self.limit_part is not None and k in self.limit_part:
            continue
        p0 = output[:, :, k]
        t0 = target[:, :, k]
        # for water balance loss
        if all_means is not None:
            # denormalize for q and et
            p1 = p0 * all_stds[k] + all_means[k]
            t1 = t0 * all_stds[k] + all_means[k]
            p2 = (10**p1 - 0.1) ** 2
            t2 = (10**t1 - 0.1) ** 2
            p_mean = torch.nanmean(p2, dim=0)
            t_mean = torch.nanmean(t2, dim=0)
        else:
            p_mean = torch.nanmean(p0, dim=0)
            t_mean = torch.nanmean(t0, dim=0)
        p_means.append(p_mean)
        t_means.append(t_mean)
        # for mtl normal loss
        mask = t0 == t0
        p = p0[mask]
        t = t0[mask]
        if self.data_gap[k] > 0:
            p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
        if type(self.loss_funcs) is list:
            temp = self.item_weight[k] * self.loss_funcs[k](p, t)
        else:
            temp = self.item_weight[k] * self.loss_funcs(p, t)
        # sum of all k-th loss
        loss = loss + temp
    # water balance loss
    p_mean_q_plus_et = torch.sum(torch.stack(p_means), dim=0)
    t_mean_q_plus_et = torch.sum(torch.stack(t_means), dim=0)
    wb_ones = torch.ones_like(t_mean_q_plus_et)
    if self.wb_loss_func is None:
        if type(self.loss_funcs) is list:
            # if wb_loss_func is None, we use the first loss function in loss_funcs
            wb_loss = self.loss_funcs[0](p_mean_q_plus_et, t_mean_q_plus_et)
            wb_1loss = self.loss_funcs[0](p_mean_q_plus_et, wb_ones)
        else:
            wb_loss = self.loss_funcs(p_mean_q_plus_et, t_mean_q_plus_et)
            wb_1loss = self.loss_funcs(p_mean_q_plus_et, wb_ones)
    else:
        wb_loss = self.wb_loss_func(p_mean_q_plus_et, t_mean_q_plus_et)
        wb_1loss = self.wb_loss_func(p_mean_q_plus_et, wb_ones)
    return (
        self.alpha * wb_loss
        + (1 - self.alpha - self.beta) * loss
        + self.beta * wb_1loss
    )
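
A minimal sketch of how the three terms are combined (the toy tensors, the shared RmseLoss, and treating both outputs as continuous are assumptions for illustration):

import torch
from torchhydro.models.crits import MultiOutWaterBalanceLoss, RmseLoss

output = torch.rand(30, 8, 2)   # hypothetical (time, batch, [q, et]) predictions
target = torch.rand(30, 8, 2)

wb_loss_fn = MultiOutWaterBalanceLoss(
    loss_funcs=RmseLoss(),
    data_gap=[0, 0],            # both variables continuous in this toy case
    device=[-1],
    item_weight=[0.5, 0.5],
    alpha=0.3,                  # weight of the water-balance mismatch term
    beta=0.0,                   # weight of the "(q+et)/p == 1" term, disabled here
)
# returns 0.3 * wb_loss + 0.7 * mtl_loss (+ 0.0 * wb_1loss)
loss = wb_loss_fn(output, target)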

NSELoss (Module)

Source code in torchhydro/models/crits.py
class NSELoss(torch.nn.Module):
    # Same as Fredrick 2019
    def __init__(self):
        super(NSELoss, self).__init__()

    def forward(self, output, target):
        Ngage = target.shape[1]
        losssum = 0
        nsample = 0
        for ii in range(Ngage):
            t0 = target[:, ii, 0]
            mask = t0 == t0
            if len(mask[mask]) > 0:
                p0 = output[:, ii, 0]
                p = p0[mask]
                t = t0[mask]
                tmean = t.mean()
                SST = torch.sum((t - tmean) ** 2)
                SSRes = torch.sum((t - p) ** 2)
                temp = SSRes / ((torch.sqrt(SST) + 0.1) ** 2)
                # original NSE
                # temp = SSRes / SST
                losssum = losssum + temp
                nsample = nsample + 1
        return losssum / nsample

forward(self, output, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, output, target):
    Ngage = target.shape[1]
    losssum = 0
    nsample = 0
    for ii in range(Ngage):
        t0 = target[:, ii, 0]
        mask = t0 == t0
        if len(mask[mask]) > 0:
            p0 = output[:, ii, 0]
            p = p0[mask]
            t = t0[mask]
            tmean = t.mean()
            SST = torch.sum((t - tmean) ** 2)
            SSRes = torch.sum((t - p) ** 2)
            temp = SSRes / ((torch.sqrt(SST) + 0.1) ** 2)
            # original NSE
            # temp = SSRes / SST
            losssum = losssum + temp
            nsample = nsample + 1
    return losssum / nsample
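
Note the per-basin term is SSRes / (sqrt(SST) + 0.1)**2 rather than the textbook SSRes / SST, which keeps the denominator away from zero for basins with near-constant observations. A minimal toy call (the (time, basin, 1) shapes and the NaN pattern are assumptions):

import torch
from torchhydro.models.crits import NSELoss

output = torch.rand(100, 4, 1)    # hypothetical (time, basin, variable) predictions
target = torch.rand(100, 4, 1)
target[::7, :, 0] = float("nan")  # NaNs are masked out per basin

loss = NSELoss()(output, target)  # average over basins of SSRes / (sqrt(SST) + 0.1) ** 2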

NegativeLogLikelihood (Module)

target -> True y
output -> predicted distribution

Source code in torchhydro/models/crits.py
class NegativeLogLikelihood(torch.nn.Module):
    """
    target -> True y
    output -> predicted distribution
    """

    def __init__(self):
        super().__init__()

    def forward(self, output: torch.distributions, target: torch.Tensor):
        """
        calculates NegativeLogLikelihood
        """
        return -output.log_prob(target).sum()

forward(self, output, target)

calculates NegativeLogLikelihood

Source code in torchhydro/models/crits.py
def forward(self, output: torch.distributions, target: torch.Tensor):
    """
    calculates NegativeLogLikelihood
    """
    return -output.log_prob(target).sum()
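
A minimal sketch, assuming the model head parameterizes a torch.distributions.Normal over the targets (the Gaussian choice and the shapes are illustrative; any distribution with a log_prob method works):

import torch
from torch.distributions import Normal
from torchhydro.models.crits import NegativeLogLikelihood

target = torch.rand(30, 8, 1)
mu = torch.rand(30, 8, 1)            # hypothetical predicted mean
sigma = torch.rand(30, 8, 1) + 0.1   # hypothetical predicted std, kept positive
dist = Normal(mu, sigma)

nll = NegativeLogLikelihood()(dist, target)  # -sum of log_prob over all elements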

PESLoss (Module)

Source code in torchhydro/models/crits.py
class PESLoss(torch.nn.Module):
    def __init__(self):
        """
        PES Loss: MSE × sigmoid(MSE)

        This loss function applies a sigmoid activation to the element-wise MSE and multiplies the two,
        creating a non-linear penalty that down-weights small errors (factor close to 0.5) and approaches plain MSE for large errors.
        """
        super(PESLoss, self).__init__()
        self.mse = torch.nn.MSELoss(reduction="none")

    def forward(self, output: torch.Tensor, target: torch.Tensor):
        mse_value = self.mse(output, target)
        sigmoid_mse = torch.sigmoid(mse_value)
        return mse_value * sigmoid_mse

__init__(self) special

PES Loss: MSE × sigmoid(MSE)

This loss function applies a sigmoid activation to the element-wise MSE and multiplies the two, creating a non-linear penalty that down-weights small errors (factor close to 0.5) and approaches plain MSE for large errors.

Source code in torchhydro/models/crits.py
def __init__(self):
    """
    PES Loss: MSE × sigmoid(MSE)

    This loss function applies a sigmoid activation to the element-wise MSE and multiplies the two,
    creating a non-linear penalty that down-weights small errors (factor close to 0.5) and approaches plain MSE for large errors.
    """
    super(PESLoss, self).__init__()
    self.mse = torch.nn.MSELoss(reduction="none")

forward(self, output, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, output: torch.Tensor, target: torch.Tensor):
    mse_value = self.mse(output, target)
    sigmoid_mse = torch.sigmoid(mse_value)
    return mse_value * sigmoid_mse
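
Because the internal MSELoss uses reduction="none", forward returns an element-wise tensor rather than a scalar; the .mean() below is an assumed aggregation step, not part of the class:

import torch
from torchhydro.models.crits import PESLoss

output = torch.rand(30, 8, 1)
target = torch.rand(30, 8, 1)

elementwise = PESLoss()(output, target)  # per-element MSE * sigmoid(MSE)
scalar = elementwise.mean()              # hypothetical reduction to a scalar training loss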

PenalizedMSELoss (Module)

Returns MSE using:
target -> True y
output -> Prediction by model
source: https://discuss.pytorch.org/t/rmse-loss-function/16540/3

Source code in torchhydro/models/crits.py
class PenalizedMSELoss(torch.nn.Module):
    """
    Returns MSE using:
    target -> True y
    output -> Prediction by model
    source: https://discuss.pytorch.org/t/rmse-loss-function/16540/3
    """

    def __init__(self, variance_penalty=0.0):
        super().__init__()
        self.mse = torch.nn.MSELoss()
        self.variance_penalty = variance_penalty

    def forward(self, output: torch.Tensor, target: torch.Tensor):
        return self.mse(target, output) + self.variance_penalty * torch.std(
            torch.sub(target, output)
        )

forward(self, output, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, output: torch.Tensor, target: torch.Tensor):
    return self.mse(target, output) + self.variance_penalty * torch.std(
        torch.sub(target, output)
    )

QuantileLoss (Module)

From https://medium.com/the-artificial-impostor/quantile-regression-part-2-6fdbc26b2629

Source code in torchhydro/models/crits.py
class QuantileLoss(torch.nn.Module):
    """From https://medium.com/the-artificial-impostor/quantile-regression-part-2-6fdbc26b2629"""

    def __init__(self, quantiles):
        super().__init__()
        self.quantiles = quantiles

    def forward(self, preds, target):
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        losses = []
        for i, q in enumerate(self.quantiles):
            mask = ~torch.isnan(target[:, :, i])
            errors = target[:, :, i][mask] - preds[:, :, i][mask]
            losses.append(torch.max((q - 1) * errors, q * errors))
        return torch.mean(torch.cat(losses, dim=0), dim=0)

forward(self, preds, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, preds, target):
    assert not target.requires_grad
    assert preds.size(0) == target.size(0)
    losses = []
    for i, q in enumerate(self.quantiles):
        mask = ~torch.isnan(target[:, :, i])
        errors = target[:, :, i][mask] - preds[:, :, i][mask]
        losses.append(torch.max((q - 1) * errors, q * errors))
    return torch.mean(torch.cat(losses, dim=0), dim=0)
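
A minimal sketch, assuming the model emits one output channel per requested quantile (so preds[:, :, i] corresponds to quantiles[i]) and that the observation is repeated along that channel dimension:

import torch
from torchhydro.models.crits import QuantileLoss

quantiles = [0.1, 0.5, 0.9]
preds = torch.rand(30, 8, 3)                   # hypothetical (time, batch, quantile) predictions
obs = torch.rand(30, 8, 1)
target = obs.repeat(1, 1, 3)                   # same observation per quantile channel (assumption)

loss = QuantileLoss(quantiles)(preds, target)  # pinball loss averaged over all masked elements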

RMSELoss (Module)

Source code in torchhydro/models/crits.py
class RMSELoss(torch.nn.Module):
    def __init__(self, variance_penalty=0.0):
        """
        Calculate RMSE

        using:
            target -> True y
            output -> Prediction by model
            source: https://discuss.pytorch.org/t/rmse-loss-function/16540/3

        Parameters
        ----------
        variance_penalty
            penalty for big variance; default is 0
        """
        super().__init__()
        self.mse = torch.nn.MSELoss()
        self.variance_penalty = variance_penalty

    def forward(self, output: torch.Tensor, target: torch.Tensor):
        if len(output) <= 1 or self.variance_penalty <= 0.0:
            return torch.sqrt(self.mse(target, output))
        diff = torch.sub(target, output)
        std_dev = torch.std(diff)
        var_penalty = self.variance_penalty * std_dev

        return torch.sqrt(self.mse(target, output)) + var_penalty

__init__(self, variance_penalty=0.0) special

Calculate RMSE

Using:
target -> True y
output -> Prediction by model
source: https://discuss.pytorch.org/t/rmse-loss-function/16540/3

Parameters

variance_penalty
    penalty for big variance; default is 0

Source code in torchhydro/models/crits.py
def __init__(self, variance_penalty=0.0):
    """
    Calculate RMSE

    using:
        target -> True y
        output -> Prediction by model
        source: https://discuss.pytorch.org/t/rmse-loss-function/16540/3

    Parameters
    ----------
    variance_penalty
        penalty for big variance; default is 0
    """
    super().__init__()
    self.mse = torch.nn.MSELoss()
    self.variance_penalty = variance_penalty

forward(self, output, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, output: torch.Tensor, target: torch.Tensor):
    if len(output) <= 1 or self.variance_penalty <= 0.0:
        return torch.sqrt(self.mse(target, output))
    diff = torch.sub(target, output)
    std_dev = torch.std(diff)
    var_penalty = self.variance_penalty * std_dev

    return torch.sqrt(self.mse(target, output)) + var_penalty
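
A minimal toy call; unlike RmseLoss below, this class uses a plain torch.nn.MSELoss, so the tensors are assumed to be NaN-free:

import torch
from torchhydro.models.crits import RMSELoss

output = torch.rand(30, 8, 1)
target = torch.rand(30, 8, 1)

loss = RMSELoss(variance_penalty=0.1)(output, target)
# = sqrt(MSE(target, output)) + 0.1 * std(target - output)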

RmseLoss (Module)

Source code in torchhydro/models/crits.py
class RmseLoss(torch.nn.Module):
    def __init__(self):
        """
        RMSE loss which could ignore NaN values

        Now we only support 3-d tensor and 1-d tensor
        """
        super(RmseLoss, self).__init__()

    def forward(self, output, target):
        if target.dim() == 1:
            mask = target == target
            p = output[mask]
            t = target[mask]
            return torch.sqrt(((p - t) ** 2).mean())
        ny = target.shape[2]
        loss = 0
        for k in range(ny):
            p0 = output[:, :, k]
            t0 = target[:, :, k]
            mask = t0 == t0
            p = p0[mask]
            p = torch.where(torch.isnan(p), torch.full_like(p, 0), p)
            t = t0[mask]
            t = torch.where(torch.isnan(t), torch.full_like(t, 0), t)
            temp = torch.sqrt(((p - t) ** 2).mean())
            loss = loss + temp
        return loss

__init__(self) special

RMSE loss which could ignore NaN values

Now we only support 3-d tensor and 1-d tensor

Source code in torchhydro/models/crits.py
def __init__(self):
    """
    RMSE loss which could ignore NaN values

    Now we only support 3-d tensor and 1-d tensor
    """
    super(RmseLoss, self).__init__()

forward(self, output, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, output, target):
    if target.dim() == 1:
        mask = target == target
        p = output[mask]
        t = target[mask]
        return torch.sqrt(((p - t) ** 2).mean())
    ny = target.shape[2]
    loss = 0
    for k in range(ny):
        p0 = output[:, :, k]
        t0 = target[:, :, k]
        mask = t0 == t0
        p = p0[mask]
        p = torch.where(torch.isnan(p), torch.full_like(p, 0), p)
        t = t0[mask]
        t = torch.where(torch.isnan(t), torch.full_like(t, 0), t)
        temp = torch.sqrt(((p - t) ** 2).mean())
        loss = loss + temp
    return loss
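
A minimal toy call showing the NaN-masked 3-D branch (the shapes and the NaN pattern are illustrative):

import torch
from torchhydro.models.crits import RmseLoss

output = torch.rand(30, 8, 2)
target = torch.rand(30, 8, 2)
target[::5, :, 1] = float("nan")   # NaNs are masked out per variable

loss = RmseLoss()(output, target)  # sum over variables of the per-variable RMSE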

SigmaLoss (Module)

Source code in torchhydro/models/crits.py
class SigmaLoss(torch.nn.Module):
    def __init__(self, prior="gauss"):
        super(SigmaLoss, self).__init__()
        self.reduction = "elementwise_mean"
        self.prior = None if prior == "" else prior.split("+")

    def forward(self, output, target):
        ny = target.shape[-1]
        lossMean = 0
        for k in range(ny):
            p0 = output[:, :, k * 2]
            s0 = output[:, :, k * 2 + 1]
            t0 = target[:, :, k]
            mask = t0 == t0
            p = p0[mask]
            s = s0[mask]
            t = t0[mask]
            if self.prior[0] == "gauss":
                loss = torch.exp(-s).mul((p - t) ** 2) / 2 + s / 2
            elif self.prior[0] == "invGamma":
                c1 = float(self.prior[1])
                c2 = float(self.prior[2])
                nt = p.shape[0]
                loss = (
                    torch.exp(-s).mul((p - t) ** 2 + c2 / nt) / 2
                    + (1 / 2 + c1 / nt) * s
                )
            lossMean = lossMean + torch.mean(loss)
        return lossMean

forward(self, output, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/crits.py
def forward(self, output, target):
    ny = target.shape[-1]
    lossMean = 0
    for k in range(ny):
        p0 = output[:, :, k * 2]
        s0 = output[:, :, k * 2 + 1]
        t0 = target[:, :, k]
        mask = t0 == t0
        p = p0[mask]
        s = s0[mask]
        t = t0[mask]
        if self.prior[0] == "gauss":
            loss = torch.exp(-s).mul((p - t) ** 2) / 2 + s / 2
        elif self.prior[0] == "invGamma":
            c1 = float(self.prior[1])
            c2 = float(self.prior[2])
            nt = p.shape[0]
            loss = (
                torch.exp(-s).mul((p - t) ** 2 + c2 / nt) / 2
                + (1 / 2 + c1 / nt) * s
            )
        lossMean = lossMean + torch.mean(loss)
    return lossMean
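
With the "gauss" prior this is a Gaussian negative log-likelihood, so the model is expected to emit two channels per target variable: output[..., 2k] is the predicted mean and output[..., 2k+1] the predicted log-variance. A minimal toy call under that assumption:

import torch
from torchhydro.models.crits import SigmaLoss

target = torch.rand(30, 8, 1)
mean_channel = torch.rand(30, 8, 1)     # hypothetical predicted mean
logvar_channel = torch.zeros(30, 8, 1)  # hypothetical predicted log-variance
output = torch.cat([mean_channel, logvar_channel], dim=-1)

loss = SigmaLoss(prior="gauss")(output, target)  # mean of exp(-s) * (p - t)**2 / 2 + s / 2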

UncertaintyWeights (Module)

Uncertainty Weights (UW).

This method is proposed in Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018, https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf) and implemented by us.

Source code in torchhydro/models/crits.py
class UncertaintyWeights(torch.nn.Module):
    r"""Uncertainty Weights (UW).

    This method is proposed in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018) <https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf>`_ \
    and implemented by us.

    """

    def __init__(
        self,
        loss_funcs: Union[torch.nn.Module, list],
        data_gap: list = None,
        device: list = None,
        limit_part: list = None,
    ):
        if data_gap is None:
            data_gap = [0, 2]
        if device is None:
            device = [0]
        super(UncertaintyWeights, self).__init__()
        self.loss_funcs = loss_funcs
        self.data_gap = data_gap
        self.device = get_the_device(device)
        self.limit_part = limit_part

    def forward(self, output, target, log_vars):
        """

        Parameters
        ----------
        output
        target
        log_vars
            sigma in uncertainty weighting;
            default is None, meaning we manually set weights for different target's loss;
            more info could be seen in
            https://libmtl.readthedocs.io/en/latest/docs/_autoapi/LibMTL/weighting/index.html#LibMTL.weighting.UW

        Returns
        -------
        torch.Tensor
            multi-task loss by uncertainty weighting method
        """
        n_out = target.shape[-1]
        loss = 0
        for k in range(n_out):
            precision = torch.exp(-log_vars[k])
            if self.limit_part is not None and k in self.limit_part:
                continue
            p0 = output[:, :, k]
            t0 = target[:, :, k]
            mask = t0 == t0
            p = p0[mask]
            t = t0[mask]
            if self.data_gap[k] > 0:
                p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
            if type(self.loss_funcs) is list:
                temp = self.loss_funcs[k](p, t)
            else:
                temp = self.loss_funcs(p, t)
            loss += torch.sum(precision * temp + log_vars[k], -1)
        return loss

forward(self, output, target, log_vars)

Parameters

output
    the prediction tensor
target
    the observation tensor
log_vars
    sigma in uncertainty weighting;
    default is None, meaning we manually set weights for the different targets' losses;
    more info can be found at
    https://libmtl.readthedocs.io/en/latest/docs/_autoapi/LibMTL/weighting/index.html#LibMTL.weighting.UW

Returns

torch.Tensor
    the multi-task loss computed by the uncertainty weighting method

Source code in torchhydro/models/crits.py
def forward(self, output, target, log_vars):
    """

    Parameters
    ----------
    output
    target
    log_vars
        sigma in uncertainty weighting;
        default is None, meaning we manually set weights for different target's loss;
        more info could be seen in
        https://libmtl.readthedocs.io/en/latest/docs/_autoapi/LibMTL/weighting/index.html#LibMTL.weighting.UW

    Returns
    -------
    torch.Tensor
        multi-task loss by uncertainty weighting method
    """
    n_out = target.shape[-1]
    loss = 0
    for k in range(n_out):
        precision = torch.exp(-log_vars[k])
        if self.limit_part is not None and k in self.limit_part:
            continue
        p0 = output[:, :, k]
        t0 = target[:, :, k]
        mask = t0 == t0
        p = p0[mask]
        t = t0[mask]
        if self.data_gap[k] > 0:
            p, t = deal_gap_data(p0, t0, self.data_gap[k], self.device)
        if type(self.loss_funcs) is list:
            temp = self.loss_funcs[k](p, t)
        else:
            temp = self.loss_funcs(p, t)
        loss += torch.sum(precision * temp + log_vars[k], -1)
    return loss
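
A minimal sketch of supplying log_vars; registering them as learnable nn.Parameters that the optimizer updates alongside the model weights is an assumption consistent with the cited paper, not something this class does by itself:

import torch
from torchhydro.models.crits import UncertaintyWeights, RmseLoss

output = torch.rand(30, 8, 2)
target = torch.rand(30, 8, 2)

# one learnable log-variance per task; normally registered in the training module
log_vars = torch.nn.Parameter(torch.zeros(2))

uw = UncertaintyWeights(loss_funcs=RmseLoss(), data_gap=[0, 0], device=[-1])
loss = uw(output, target, log_vars)  # sum over k of exp(-log_vars[k]) * loss_k + log_vars[k]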

deal_gap_data(output, target, data_gap, device)

How to handle gap data

When there are NaN values in the observation, we perform a "reduce" operation on the prediction. For example, with pred = [0, 1, 2, 3, 4], obs = [5, nan, nan, 6, nan] and "sum" as the reduce, pred_sum = [0+1+2, 3+4], obs_sum = [5, 6], and loss = loss_func(pred_sum, obs_sum). Note that for "sum" the final index is actually not chosen, because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan] (6 is the sum of three elements), so when rho is 5 the final, possibly incomplete interval is dropped.

Parameters

output
    model output for the k-th variable
target
    target for the k-th variable
data_gap
    data_gap=1: reduce is sum
    data_gap=2: reduce is mean
device
    where to save the data

Returns

tuple[tensor, tensor]
    output and target after dealing with the gaps

Source code in torchhydro/models/crits.py
def deal_gap_data(output, target, data_gap, device):
    """
    How to handle gap data

    When there are NaN values in observation, we will perform a "reduce" operation on prediction.
    For example, pred = [0,1,2,3,4], obs=[5, nan, nan, 6, nan]; the "reduce" is sum;
    then, pred_sum = [0+1+2, 3+4], obs_sum=[5,6], loss = loss_func(pred_sum, obs_sum).
    Notice: when "sum", actually final index is not chosen,
    because the whole observation may be [5, nan, nan, 6, nan, nan, 7, nan, nan], 6 means sum of three elements.
    Just as the rho is 5, the final one is not chosen

    Parameters
    ----------
    output
        model output for k-th variable
    target
        target for k-th variable
    data_gap
        data_gap=1: reduce is sum
        data_gap=2: reduce is mean
    device
        where to save the data

    Returns
    -------
    tuple[tensor, tensor]
        output and target after dealing with gap
    """
    # all members in a batch have different NaN gaps, so we need a loop
    seg_p_lst = []
    seg_t_lst = []
    for j in range(target.shape[1]):
        non_nan_idx = torch.nonzero(
            ~torch.isnan(target[:, j]), as_tuple=False
        ).squeeze()
        if len(non_nan_idx) < 1:
            raise ArithmeticError("All NaN elements, please check your data")

        # use cumsum to generate scatter_index
        is_not_nan = ~torch.isnan(target[:, j])
        cumsum_is_not_nan = torch.cumsum(is_not_nan.to(torch.int), dim=0)
        first_non_nan = non_nan_idx[0]
        scatter_index = torch.full_like(
            target[:, j], fill_value=-1, dtype=torch.long
        )  # initialize all values to -1
        scatter_index[first_non_nan:] = cumsum_is_not_nan[first_non_nan:] - 1
        scatter_index = scatter_index.to(device=device)

        # create a mask that keeps only the valid indices
        valid_mask = scatter_index >= 0

        if data_gap == 1:
            seg = torch.zeros(
                len(non_nan_idx), device=device, dtype=output.dtype
            ).scatter_add_(0, scatter_index[valid_mask], output[valid_mask, j])
            # for sum, better exclude final non-nan value as it didn't include all necessary periods
            seg_p_lst.append(seg[:-1])
            seg_t_lst.append(target[non_nan_idx[:-1], j])

        elif data_gap == 2:
            counts = torch.zeros(
                len(non_nan_idx), device=device, dtype=output.dtype
            ).scatter_add_(
                0,
                scatter_index[valid_mask],
                torch.ones_like(output[valid_mask, j], dtype=output.dtype),
            )
            seg = torch.zeros(
                len(non_nan_idx), device=device, dtype=output.dtype
            ).scatter_add_(0, scatter_index[valid_mask], output[valid_mask, j])
            seg = seg / counts.clamp(min=1)
            # for mean, we can include all periods
            seg_p_lst.append(seg)
            seg_t_lst.append(target[non_nan_idx, j])
        else:
            raise NotImplementedError(
                "We have not provided this reduce way now!! Please choose 1 or 2!!"
            )

    p = torch.cat(seg_p_lst)
    t = torch.cat(seg_t_lst)
    return p, t
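
A worked toy example of the two reduce modes (one basin, six time steps; running on CPU is an assumption):

import torch
from torchhydro.models.crits import deal_gap_data

nan = float("nan")
output = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]])  # (time, batch=1)
target = torch.tensor([[6.0], [nan], [nan], [7.0], [nan], [nan]])

# data_gap=1: sum-reduce; the final interval is dropped because it may be incomplete
p, t = deal_gap_data(output, target, 1, torch.device("cpu"))
# p == [0+1+2] == [3.], t == [6.]

# data_gap=2: mean-reduce; all intervals are kept
p, t = deal_gap_data(output, target, 2, torch.device("cpu"))
# p == [(0+1+2)/3, (3+4+5)/3] == [1., 4.], t == [6., 7.]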

l1_regularizer(model, lambda_l1=0.01)

source: https://stackoverflow.com/questions/58172188/how-to-add-l1-regularization-to-pytorch-nn-model

Source code in torchhydro/models/crits.py
def l1_regularizer(model, lambda_l1=0.01):
    """
    source: https://stackoverflow.com/questions/58172188/how-to-add-l1-regularization-to-pytorch-nn-model
    """
    lossl1 = 0
    for model_param_name, model_param_value in model.named_parameters():
        if model_param_name.endswith("weight"):
            lossl1 += lambda_l1 * model_param_value.abs().sum()
    # return after the loop so that all weight matrices contribute to the penalty
    return lossl1
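
A sketch of adding the penalty to a training loss (the linear model, the base loss and the lambda value are illustrative):

import torch
from torchhydro.models.crits import RmseLoss, l1_regularizer

model = torch.nn.Linear(4, 1)                  # hypothetical model
output = model(torch.rand(16, 4)).squeeze(-1)
target = torch.rand(16)

base_loss = RmseLoss()(output, target)         # 1-D branch of RmseLoss
total = base_loss + l1_regularizer(model, lambda_l1=1e-4)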

orth_regularizer(model, lambda_orth=0.01)

source: https://stackoverflow.com/questions/58172188/how-to-add-l1-regularization-to-pytorch-nn-model

Source code in torchhydro/models/crits.py
def orth_regularizer(model, lambda_orth=0.01):
    """
    source: https://stackoverflow.com/questions/58172188/how-to-add-l1-regularization-to-pytorch-nn-model
    """
    lossorth = 0
    for model_param_name, model_param_value in model.named_parameters():
        if model_param_name.endswith("weight"):
            param_flat = model_param_value.view(model_param_value.shape[0], -1)
            sym = torch.mm(param_flat, torch.t(param_flat))
            sym -= torch.eye(param_flat.shape[0])
            lossorth += lambda_orth * sym.sum()

    # return after the loop so that all weight matrices contribute to the penalty
    return lossorth

cudnnlstm

Author: MHPI group, Wenyu Ouyang Date: 2021-12-31 11:08:29 LastEditTime: 2024-10-09 16:36:34 LastEditors: Wenyu Ouyang Description: LSTM with dropout implemented by Kuai Fang and more LSTMs using it FilePath: \torchhydro\torchhydro\models\cudnnlstm.py Copyright (c) 2021-2022 MHPI group, Wenyu Ouyang. All rights reserved.

CNN1dKernel (Module)

Source code in torchhydro/models/cudnnlstm.py
class CNN1dKernel(torch.nn.Module):
    def __init__(self, *, ninchannel=1, nkernel=3, kernelSize=3, stride=1, padding=0):
        super(CNN1dKernel, self).__init__()
        self.cnn1d = torch.nn.Conv1d(
            in_channels=ninchannel,
            out_channels=nkernel,
            kernel_size=kernelSize,
            padding=padding,
            stride=stride,
        )
        self.name = "CNN1dkernel"
        self.is_legacy = True

    def forward(self, x):
        return F.relu(self.cnn1d(x))

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x):
    return F.relu(self.cnn1d(x))
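
A minimal shape sketch (the batch size, channel count and sequence length are illustrative):

import torch
from torchhydro.models.cudnnlstm import CNN1dKernel

conv = CNN1dKernel(ninchannel=1, nkernel=3, kernelSize=3, stride=1)
x = torch.rand(8, 1, 20)   # (batch, in_channels, length)
y = conv(x)                # ReLU(Conv1d) -> shape (8, 3, 18)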

CNN1dLCmodel (Module)

Source code in torchhydro/models/cudnnlstm.py
class CNN1dLCmodel(nn.Module):
    # Directly add the CNN extracted features into LSTM inputSize
    def __init__(
        self,
        nx,
        ny,
        nobs,
        hidden_size,
        n_kernel: Union[list, tuple] = (10, 5),
        kernel_size: Union[list, tuple] = (3, 3),
        stride: Union[list, tuple] = (2, 1),
        dr=0.5,
        pool_opt=None,
        cnn_dr=0.5,
        cat_first=True,
    ):
        """cat_first means: we will concatenate the CNN output with the x, then input them to the CudnnLstm model;
        if not cat_first, it is relu_first, meaning we will relu the CNN output firstly, then concatenate it with x
        """
        # two convolutional layer
        super(CNN1dLCmodel, self).__init__()
        # N_cnn_out is the number of features output by the CNN
        # nx is the number of historical input features
        # ny is the output dimension of the final linear layer; 1 if only streamflow is predicted
        # nobs is the length of the series fed into the CNN
        # hidden_size is the number of hidden units of the linear layer
        self.nx = nx
        self.ny = ny
        self.obs = nobs
        self.hiddenSize = hidden_size
        n_layer = len(n_kernel)
        self.features = nn.Sequential()
        n_in_chan = 1
        lout = nobs
        for ii in range(n_layer):
            conv_layer = CNN1dKernel(
                ninchannel=n_in_chan,
                nkernel=n_kernel[ii],
                kernelSize=kernel_size[ii],
                stride=stride[ii],
            )
            self.features.add_module("CnnLayer%d" % (ii + 1), conv_layer)
            if cnn_dr != 0.0:
                self.features.add_module("dropout%d" % (ii + 1), nn.Dropout(p=cnn_dr))
            n_in_chan = n_kernel[ii]
            lout = cal_conv_size(lin=lout, kernel=kernel_size[ii], stride=stride[ii])
            self.features.add_module("Relu%d" % (ii + 1), nn.ReLU())
            if pool_opt is not None:
                self.features.add_module(
                    "Pooling%d" % (ii + 1), nn.MaxPool1d(pool_opt[ii])
                )
                lout = cal_pool_size(lin=lout, kernel=pool_opt[ii])
        self.N_cnn_out = int(
            lout * n_kernel[-1]
        )  # total CNN feature number after convolution
        self.cat_first = cat_first
        # concatenate first or not?
        # if cat_first, the input dimension of the linear layer is the CNN feature dimension of the future inputs (e.g. precipitation)
        # plus the number of features of the historical observation series; they are merged by one linear layer and then passed on
        # if not cat_first, the historical observations first go through a linear layer of their own
        if cat_first:
            nf = self.N_cnn_out + nx
            self.linearIn = torch.nn.Linear(nf, hidden_size)
            # besides the basic LSTM, CudnnLstm mainly adds handling for empty h and c states,
            # which the paper attributes to possibly missing inputs that should not be interpolated,
            # since interpolation would leak future information; the missing points are set to zero instead,
            # and the paper argues they are few, so the zero-filling barely affects the output once training has updated the parameters
            self.lstm = CudnnLstm(
                input_size=hidden_size, hidden_size=hidden_size, dr=dr
            )
        else:
            nf = self.N_cnn_out + hidden_size
            self.linearIn = torch.nn.Linear(nx, hidden_size)
            self.lstm = CudnnLstm(input_size=nf, hidden_size=hidden_size, dr=dr)
        self.linearOut = torch.nn.Linear(hidden_size, ny)
        self.gpu = 1

    def forward(self, x, z, do_drop_mc=False):
        # z = n_grid*nVar add a channel dimension
        # z = z.t()
        n_grid, nobs, _ = z.shape
        z = z.reshape(n_grid * nobs, 1)
        n_t, bs, n_var = x.shape
        # add a channel dimension
        z = torch.unsqueeze(z, dim=1)
        z0 = self.features(z)
        # z0 = (n_grid) * n_kernel * sizeafterconv
        z0 = z0.view(n_grid, self.N_cnn_out).repeat(n_t, 1, 1)
        if self.cat_first:
            x = torch.cat((x, z0), dim=2)
            x0 = F.relu(self.linearIn(x))
        else:
            x = F.relu(self.linearIn(x))
            x0 = torch.cat((x, z0), dim=2)
        out_lstm, (hn, cn) = self.lstm(x0, do_drop_mc=do_drop_mc)
        return self.linearOut(out_lstm)

__init__(self, nx, ny, nobs, hidden_size, n_kernel=(10, 5), kernel_size=(3, 3), stride=(2, 1), dr=0.5, pool_opt=None, cnn_dr=0.5, cat_first=True) special

cat_first means: we will concatenate the CNN output with the x, then input them to the CudnnLstm model; if not cat_first, it is relu_first, meaning we will relu the CNN output firstly, then concatenate it with x

Source code in torchhydro/models/cudnnlstm.py
def __init__(
    self,
    nx,
    ny,
    nobs,
    hidden_size,
    n_kernel: Union[list, tuple] = (10, 5),
    kernel_size: Union[list, tuple] = (3, 3),
    stride: Union[list, tuple] = (2, 1),
    dr=0.5,
    pool_opt=None,
    cnn_dr=0.5,
    cat_first=True,
):
    """cat_first means: we will concatenate the CNN output with the x, then input them to the CudnnLstm model;
    if not cat_first, it is relu_first, meaning we will relu the CNN output firstly, then concatenate it with x
    """
    # two convolutional layer
    super(CNN1dLCmodel, self).__init__()
    # N_cnn_out is the number of features output by the CNN
    # nx is the number of historical input features
    # ny is the output dimension of the final linear layer; 1 if only streamflow is predicted
    # nobs is the length of the series fed into the CNN
    # hidden_size is the number of hidden units of the linear layer
    self.nx = nx
    self.ny = ny
    self.obs = nobs
    self.hiddenSize = hidden_size
    n_layer = len(n_kernel)
    self.features = nn.Sequential()
    n_in_chan = 1
    lout = nobs
    for ii in range(n_layer):
        conv_layer = CNN1dKernel(
            ninchannel=n_in_chan,
            nkernel=n_kernel[ii],
            kernelSize=kernel_size[ii],
            stride=stride[ii],
        )
        self.features.add_module("CnnLayer%d" % (ii + 1), conv_layer)
        if cnn_dr != 0.0:
            self.features.add_module("dropout%d" % (ii + 1), nn.Dropout(p=cnn_dr))
        n_in_chan = n_kernel[ii]
        lout = cal_conv_size(lin=lout, kernel=kernel_size[ii], stride=stride[ii])
        self.features.add_module("Relu%d" % (ii + 1), nn.ReLU())
        if pool_opt is not None:
            self.features.add_module(
                "Pooling%d" % (ii + 1), nn.MaxPool1d(pool_opt[ii])
            )
            lout = cal_pool_size(lin=lout, kernel=pool_opt[ii])
    self.N_cnn_out = int(
        lout * n_kernel[-1]
    )  # total CNN feature number after convolution
    self.cat_first = cat_first
    # concatenate first or not?
    # if cat_first, the input dimension of the linear layer is the CNN feature dimension of the future inputs (e.g. precipitation)
    # plus the number of features of the historical observation series; they are merged by one linear layer and then passed on
    # if not cat_first, the historical observations first go through a linear layer of their own
    if cat_first:
        nf = self.N_cnn_out + nx
        self.linearIn = torch.nn.Linear(nf, hidden_size)
        # besides the basic LSTM, CudnnLstm mainly adds handling for empty h and c states,
        # which the paper attributes to possibly missing inputs that should not be interpolated,
        # since interpolation would leak future information; the missing points are set to zero instead,
        # and the paper argues they are few, so the zero-filling barely affects the output once training has updated the parameters
        self.lstm = CudnnLstm(
            input_size=hidden_size, hidden_size=hidden_size, dr=dr
        )
    else:
        nf = self.N_cnn_out + hidden_size
        self.linearIn = torch.nn.Linear(nx, hidden_size)
        self.lstm = CudnnLstm(input_size=nf, hidden_size=hidden_size, dr=dr)
    self.linearOut = torch.nn.Linear(hidden_size, ny)
    self.gpu = 1

forward(self, x, z, do_drop_mc=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, z, do_drop_mc=False):
    # z = n_grid*nVar add a channel dimension
    # z = z.t()
    n_grid, nobs, _ = z.shape
    z = z.reshape(n_grid * nobs, 1)
    n_t, bs, n_var = x.shape
    # add a channel dimension
    z = torch.unsqueeze(z, dim=1)
    z0 = self.features(z)
    # z0 = (n_grid) * n_kernel * sizeafterconv
    z0 = z0.view(n_grid, self.N_cnn_out).repeat(n_t, 1, 1)
    if self.cat_first:
        x = torch.cat((x, z0), dim=2)
        x0 = F.relu(self.linearIn(x))
    else:
        x = F.relu(self.linearIn(x))
        x0 = torch.cat((x, z0), dim=2)
    out_lstm, (hn, cn) = self.lstm(x0, do_drop_mc=do_drop_mc)
    return self.linearOut(out_lstm)

CpuLstmModel (Module)

Cpu version of CudnnLstmModel

Source code in torchhydro/models/cudnnlstm.py
class CpuLstmModel(nn.Module):
    """Cpu version of CudnnLstmModel"""

    def __init__(self, *, n_input_features, n_output_features, n_hidden_states, dr=0.5):
        super(CpuLstmModel, self).__init__()
        self.nx = n_input_features
        self.ny = n_output_features
        self.hiddenSize = n_hidden_states
        self.ct = 0
        self.nLayer = 1
        self.linearIn = torch.nn.Linear(n_input_features, n_hidden_states)
        self.lstm = LstmCellTied(
            input_size=n_hidden_states,
            hidden_size=n_hidden_states,
            dr=dr,
            dr_method="drW",
            gpu=-1,
        )
        self.linearOut = torch.nn.Linear(n_hidden_states, n_output_features)
        self.gpu = -1

    def forward(self, x, do_drop_mc=False):
        # x0 = F.relu(self.linearIn(x))
        # outLSTM, (hn, cn) = self.lstm(x0, do_drop_mc=do_drop_mc)
        # out = self.linearOut(outLSTM)
        # return out
        nt, ngrid, nx = x.shape
        yt = torch.zeros(ngrid, 1)
        out = torch.zeros(nt, ngrid, self.ny)
        ht = None
        ct = None
        reset_mask = True
        for t in range(nt):
            xt = x[t, :, :]
            xt = torch.where(torch.isnan(xt), torch.full_like(xt, 0), xt)
            x0 = F.relu(self.linearIn(xt))
            ht, ct = self.lstm(x0, hidden=(ht, ct), do_reset_mask=reset_mask)
            yt = self.linearOut(ht)
            reset_mask = False
            out[t, :, :] = yt
        return out

forward(self, x, do_drop_mc=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, do_drop_mc=False):
    # x0 = F.relu(self.linearIn(x))
    # outLSTM, (hn, cn) = self.lstm(x0, do_drop_mc=do_drop_mc)
    # out = self.linearOut(outLSTM)
    # return out
    nt, ngrid, nx = x.shape
    yt = torch.zeros(ngrid, 1)
    out = torch.zeros(nt, ngrid, self.ny)
    ht = None
    ct = None
    reset_mask = True
    for t in range(nt):
        xt = x[t, :, :]
        xt = torch.where(torch.isnan(xt), torch.full_like(xt, 0), xt)
        x0 = F.relu(self.linearIn(xt))
        ht, ct = self.lstm(x0, hidden=(ht, ct), do_reset_mask=reset_mask)
        yt = self.linearOut(ht)
        reset_mask = False
        out[t, :, :] = yt
    return out
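
Below is a minimal usage sketch (the feature, basin, and hidden sizes are made up); the model steps LstmCellTied through time internally, so a plain CPU tensor of shape [time, basin, feature] is all that is needed:

import torch

from torchhydro.models.cudnnlstm import CpuLstmModel

# hypothetical sizes: 10 input features, 1 output, 64 hidden units
model = CpuLstmModel(n_input_features=10, n_output_features=1, n_hidden_states=64)
model.eval()

x = torch.randn(30, 5, 10)  # [time, basin, feature], sequence-first
with torch.no_grad():
    q = model(x)
print(q.shape)  # torch.Size([30, 5, 1])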

CudnnLstm (Module)

LSTM with dropout implemented by Kuai Fang: https://github.com/mhpi/hydroDL/blob/release/hydroDL/model/rnn.py

It runs only on GPU; the CPU version is LstmCellTied in this file

Source code in torchhydro/models/cudnnlstm.py
class CudnnLstm(nn.Module):
    """
    LSTM with dropout implemented by Kuai Fang: https://github.com/mhpi/hydroDL/blob/release/hydroDL/model/rnn.py

    Only run in GPU; the CPU version is LstmCellTied in this file
    """

    def __init__(self, *, input_size, hidden_size, dr=0.5):
        """

        Parameters
        ----------
        input_size
            number of neurons in input layer
        hidden_size
            number of neurons in hidden layer
        dr
            dropout rate
        """
        super(CudnnLstm, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dr = dr
        self.w_ih = Parameter(torch.Tensor(hidden_size * 4, input_size))
        self.w_hh = Parameter(torch.Tensor(hidden_size * 4, hidden_size))
        self.b_ih = Parameter(torch.Tensor(hidden_size * 4))
        self.b_hh = Parameter(torch.Tensor(hidden_size * 4))
        self._all_weights = [["w_ih", "w_hh", "b_ih", "b_hh"]]
        # self.cuda()
        # set the mask
        self.reset_mask()
        # initialize the weights and bias of the model
        self.reset_parameters()

    def _apply(self, fn):
        """just use the default _apply function

        Parameters
        ----------
        fn : function
            _description_

        Returns
        -------
        _type_
            _description_
        """
        # _apply is always recursively applied to all submodules and the module itself such as move all to GPU
        return super()._apply(fn)

    def __setstate__(self, d):
        """a python magic function to set the state of the object used for deserialization

        Parameters
        ----------
        d : _type_
            _description_
        """
        super().__setstate__(d)
        # set a default value for _data_ptrs
        self.__dict__.setdefault("_data_ptrs", [])
        if "all_weights" in d:
            self._all_weights = d["all_weights"]
        if isinstance(self._all_weights[0][0], str):
            return
        self._all_weights = [["w_ih", "w_hh", "b_ih", "b_hh"]]

    def reset_mask(self):
        """generate mask for dropout"""
        self.mask_w_ih = create_mask(self.w_ih, self.dr)
        self.mask_w_hh = create_mask(self.w_hh, self.dr)

    def reset_parameters(self):
        """initialize the weights and bias of the model using Xavier initialization"""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            # uniform distribution between -stdv and stdv for the weights and bias initialization
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx=None, cx=None, do_drop_mc=False, dropout_false=False):
        # dropout_false: it will ensure do_drop is false, unless do_drop_mc is true
        if dropout_false and (not do_drop_mc):
            do_drop = False
        elif self.dr > 0 and (do_drop_mc is True or self.training is True):
            # if train mode and set self.dr > 0, then do_drop is true
            # so each time the model forward function is called, the dropout is applied
            do_drop = True
        else:
            do_drop = False
        # input must be a tensor with shape (seq_len, batch, input_size)
        batch_size = input.size(1)

        if hx is None:
            hx = input.new_zeros(1, batch_size, self.hidden_size, requires_grad=False)
        if cx is None:
            cx = input.new_zeros(1, batch_size, self.hidden_size, requires_grad=False)

        # handle = torch.backends.cudnn.get_handle()
        if do_drop is True:
            # cuDNN backend - disabled flat weight
            # NOTE: each time the mask is newly generated, so for each batch the mask is different
            self.reset_mask()
            # apply the mask to the weights
            weight = [
                DropMask.apply(self.w_ih, self.mask_w_ih, True),
                DropMask.apply(self.w_hh, self.mask_w_hh, True),
                self.b_ih,
                self.b_hh,
            ]
        else:
            weight = [self.w_ih, self.w_hh, self.b_ih, self.b_hh]
        if torch.__version__ < "1.8":
            output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
                input,
                weight,
                4,
                None,
                hx,
                cx,
                2,  # 2 means LSTM
                self.hidden_size,
                1,
                False,
                0,
                self.training,
                False,
                (),
                None,
            )
        else:
            output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
                input,
                weight,
                4,
                None,
                hx,
                cx,
                2,  # 2 means LSTM
                self.hidden_size,
                0,
                1,
                False,
                0,
                self.training,
                False,
                (),
                None,
            )
        return output, (hy, cy)

    @property
    def all_weights(self):
        """return all weights and bias of the model as a list"""
        # getattr() is used to get the value of an object's attribute
        return [
            [getattr(self, weight) for weight in weights]
            for weights in self._all_weights
        ]

all_weights property readonly

return all weights and bias of the model as a list

__init__(self, *, input_size, hidden_size, dr=0.5) special

Parameters

input_size
    number of neurons in the input layer
hidden_size
    number of neurons in the hidden layer
dr
    dropout rate

Source code in torchhydro/models/cudnnlstm.py
def __init__(self, *, input_size, hidden_size, dr=0.5):
    """

    Parameters
    ----------
    input_size
        number of neurons in input layer
    hidden_size
        number of neurons in hidden layer
    dr
        dropout rate
    """
    super(CudnnLstm, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.dr = dr
    self.w_ih = Parameter(torch.Tensor(hidden_size * 4, input_size))
    self.w_hh = Parameter(torch.Tensor(hidden_size * 4, hidden_size))
    self.b_ih = Parameter(torch.Tensor(hidden_size * 4))
    self.b_hh = Parameter(torch.Tensor(hidden_size * 4))
    self._all_weights = [["w_ih", "w_hh", "b_ih", "b_hh"]]
    # self.cuda()
    # set the mask
    self.reset_mask()
    # initialize the weights and bias of the model
    self.reset_parameters()

__setstate__(self, d) special

a python magic function to set the state of the object used for deserialization

Parameters

d
    the state dictionary used to restore the object during deserialization

Source code in torchhydro/models/cudnnlstm.py
def __setstate__(self, d):
    """a python magic function to set the state of the object used for deserialization

    Parameters
    ----------
    d : _type_
        _description_
    """
    super().__setstate__(d)
    # set a default value for _data_ptrs
    self.__dict__.setdefault("_data_ptrs", [])
    if "all_weights" in d:
        self._all_weights = d["all_weights"]
    if isinstance(self._all_weights[0][0], str):
        return
    self._all_weights = [["w_ih", "w_hh", "b_ih", "b_hh"]]

forward(self, input, hx=None, cx=None, do_drop_mc=False, dropout_false=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, input, hx=None, cx=None, do_drop_mc=False, dropout_false=False):
    # dropout_false: it will ensure do_drop is false, unless do_drop_mc is true
    if dropout_false and (not do_drop_mc):
        do_drop = False
    elif self.dr > 0 and (do_drop_mc is True or self.training is True):
        # if train mode and set self.dr > 0, then do_drop is true
        # so each time the model forward function is called, the dropout is applied
        do_drop = True
    else:
        do_drop = False
    # input must be a tensor with shape (seq_len, batch, input_size)
    batch_size = input.size(1)

    if hx is None:
        hx = input.new_zeros(1, batch_size, self.hidden_size, requires_grad=False)
    if cx is None:
        cx = input.new_zeros(1, batch_size, self.hidden_size, requires_grad=False)

    # handle = torch.backends.cudnn.get_handle()
    if do_drop is True:
        # cuDNN backend - disabled flat weight
        # NOTE: each time the mask is newly generated, so for each batch the mask is different
        self.reset_mask()
        # apply the mask to the weights
        weight = [
            DropMask.apply(self.w_ih, self.mask_w_ih, True),
            DropMask.apply(self.w_hh, self.mask_w_hh, True),
            self.b_ih,
            self.b_hh,
        ]
    else:
        weight = [self.w_ih, self.w_hh, self.b_ih, self.b_hh]
    if torch.__version__ < "1.8":
        output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
            input,
            weight,
            4,
            None,
            hx,
            cx,
            2,  # 2 means LSTM
            self.hidden_size,
            1,
            False,
            0,
            self.training,
            False,
            (),
            None,
        )
    else:
        output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
            input,
            weight,
            4,
            None,
            hx,
            cx,
            2,  # 2 means LSTM
            self.hidden_size,
            0,
            1,
            False,
            0,
            self.training,
            False,
            (),
            None,
        )
    return output, (hy, cy)

reset_mask(self)

generate mask for dropout

Source code in torchhydro/models/cudnnlstm.py
def reset_mask(self):
    """generate mask for dropout"""
    self.mask_w_ih = create_mask(self.w_ih, self.dr)
    self.mask_w_hh = create_mask(self.w_hh, self.dr)

reset_parameters(self)

initialize the weights and biases of the model with a uniform distribution in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)] (the standard PyTorch LSTM initialization)

Source code in torchhydro/models/cudnnlstm.py
def reset_parameters(self):
    """initialize the weights and bias of the model using Xavier initialization"""
    stdv = 1.0 / math.sqrt(self.hidden_size)
    for weight in self.parameters():
        # uniform distribution between -stdv and stdv for the weights and bias initialization
        weight.data.uniform_(-stdv, stdv)
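
A minimal sketch of the two dropout modes, assuming a CUDA device is available (the layer calls torch._cudnn_rnn, so it cannot run on CPU); sizes are hypothetical:

import torch

from torchhydro.models.cudnnlstm import CudnnLstm

lstm = CudnnLstm(input_size=16, hidden_size=64, dr=0.5).cuda()
lstm.eval()
x = torch.randn(30, 8, 16, device="cuda")  # [seq_len, batch, input_size]

# deterministic inference: no dropout mask is applied
out, (hy, cy) = lstm(x, dropout_false=True)

# Monte Carlo dropout at inference: a fresh weight mask is drawn on every call
mc_runs = torch.stack([lstm(x, do_drop_mc=True)[0] for _ in range(10)])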

CudnnLstmModel (Module)

Source code in torchhydro/models/cudnnlstm.py
class CudnnLstmModel(nn.Module):
    def __init__(self, n_input_features, n_output_features, n_hidden_states, dr=0.5):
        """
        An LSTM model written by Kuai Fang from this paper: https://doi.org/10.1002/2017GL075619

        only gpu version

        Parameters
        ----------
        n_input_features
            the number of input features
        n_output_features
            the number of output features
        n_hidden_states
            the number of hidden features
        dr
            dropout rate and its default is 0.5
        """
        super(CudnnLstmModel, self).__init__()
        self.nx = n_input_features
        self.ny = n_output_features
        self.hidden_size = n_hidden_states
        self.ct = 0
        self.nLayer = 1
        self.linearIn = torch.nn.Linear(self.nx, self.hidden_size)
        self.lstm = CudnnLstm(
            input_size=self.hidden_size, hidden_size=self.hidden_size, dr=dr
        )
        self.linearOut = torch.nn.Linear(self.hidden_size, self.ny)

    def forward(self, x, do_drop_mc=False, dropout_false=False, return_h_c=False):
        x0 = F.relu(self.linearIn(x))
        out_lstm, (hn, cn) = self.lstm(
            x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
        )
        out = self.linearOut(out_lstm)
        return (out, (hn, cn)) if return_h_c else out

__init__(self, n_input_features, n_output_features, n_hidden_states, dr=0.5) special

An LSTM model written by Kuai Fang, from this paper: https://doi.org/10.1002/2017GL075619

GPU-only version

Parameters

n_input_features
    the number of input features
n_output_features
    the number of output features
n_hidden_states
    the number of hidden features
dr
    dropout rate; default is 0.5

Source code in torchhydro/models/cudnnlstm.py
def __init__(self, n_input_features, n_output_features, n_hidden_states, dr=0.5):
    """
    An LSTM model written by Kuai Fang from this paper: https://doi.org/10.1002/2017GL075619

    only gpu version

    Parameters
    ----------
    n_input_features
        the number of input features
    n_output_features
        the number of output features
    n_hidden_states
        the number of hidden features
    dr
        dropout rate and its default is 0.5
    """
    super(CudnnLstmModel, self).__init__()
    self.nx = n_input_features
    self.ny = n_output_features
    self.hidden_size = n_hidden_states
    self.ct = 0
    self.nLayer = 1
    self.linearIn = torch.nn.Linear(self.nx, self.hidden_size)
    self.lstm = CudnnLstm(
        input_size=self.hidden_size, hidden_size=self.hidden_size, dr=dr
    )
    self.linearOut = torch.nn.Linear(self.hidden_size, self.ny)

forward(self, x, do_drop_mc=False, dropout_false=False, return_h_c=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, do_drop_mc=False, dropout_false=False, return_h_c=False):
    x0 = F.relu(self.linearIn(x))
    out_lstm, (hn, cn) = self.lstm(
        x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
    )
    out = self.linearOut(out_lstm)
    return (out, (hn, cn)) if return_h_c else out
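
A minimal usage sketch with made-up sizes; like CudnnLstm itself, the model only runs on a CUDA device:

import torch

from torchhydro.models.cudnnlstm import CudnnLstmModel

model = CudnnLstmModel(
    n_input_features=10, n_output_features=1, n_hidden_states=64, dr=0.5
).cuda()

x = torch.randn(365, 32, 10, device="cuda")   # [time, basin, feature]
q = model(x)                                  # torch.Size([365, 32, 1])
q, (hn, cn) = model(x, return_h_c=True)       # also return the final hidden/cell states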

CudnnLstmModelLstmKernel (Module)

use a trained/un-trained CudnnLstm as a kernel generator before another CudnnLstm.

Source code in torchhydro/models/cudnnlstm.py
class CudnnLstmModelLstmKernel(nn.Module):
    """use a trained/un-trained CudnnLstm as a kernel generator before another CudnnLstm."""

    def __init__(
        self,
        nx,
        ny,
        hidden_size,
        nk=None,
        hidden_size_later=None,
        cut=False,
        dr=0.5,
        delta_s=False,
    ):
        """delta_s means we will use the difference of the first lstm's output and the second's as the final output"""
        super(CudnnLstmModelLstmKernel, self).__init__()
        # These three layers are same with CudnnLstmModel to be used for transfer learning or just vanilla-use
        self.linearIn = torch.nn.Linear(nx, hidden_size)
        self.lstm = CudnnLstm(input_size=hidden_size, hidden_size=hidden_size, dr=dr)
        self.linearOut = torch.nn.Linear(hidden_size, ny)
        # if cut is True, we will only select the final index in nk, and repeat it, then concatenate with x
        self.cut = cut
        # the second lstm has more input than the previous
        if nk is None:
            nk = ny
        if hidden_size_later is None:
            hidden_size_later = hidden_size
        self.linear_in_later = torch.nn.Linear(nx + nk, hidden_size_later)
        self.lstm_later = CudnnLstm(
            input_size=hidden_size_later, hidden_size=hidden_size_later, dr=dr
        )
        self.linear_out_later = torch.nn.Linear(hidden_size_later, ny)

        self.delta_s = delta_s
        # when delta_s is true, cut cannot be true, because they have to have same number params
        assert not (cut and delta_s)

    def forward(self, x, do_drop_mc=False, dropout_false=False):
        x0 = F.relu(self.linearIn(x))
        out_lstm1, (hn1, cn1) = self.lstm(
            x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
        )
        gen = self.linearOut(out_lstm1)
        if self.cut:
            gen = gen[-1, :, :].repeat(x.shape[0], 1, 1)
        x1 = torch.cat((x, gen), dim=len(gen.shape) - 1)
        x2 = F.relu(self.linear_in_later(x1))
        out_lstm2, (hn2, cn2) = self.lstm_later(
            x2, do_drop_mc=do_drop_mc, dropout_false=dropout_false
        )
        out = self.linear_out_later(out_lstm2)
        return gen - out if self.delta_s else (out, gen)

__init__(self, nx, ny, hidden_size, nk=None, hidden_size_later=None, cut=False, dr=0.5, delta_s=False) special

delta_s means we will use the difference of the first lstm's output and the second's as the final output

Source code in torchhydro/models/cudnnlstm.py
def __init__(
    self,
    nx,
    ny,
    hidden_size,
    nk=None,
    hidden_size_later=None,
    cut=False,
    dr=0.5,
    delta_s=False,
):
    """delta_s means we will use the difference of the first lstm's output and the second's as the final output"""
    super(CudnnLstmModelLstmKernel, self).__init__()
    # These three layers are same with CudnnLstmModel to be used for transfer learning or just vanilla-use
    self.linearIn = torch.nn.Linear(nx, hidden_size)
    self.lstm = CudnnLstm(input_size=hidden_size, hidden_size=hidden_size, dr=dr)
    self.linearOut = torch.nn.Linear(hidden_size, ny)
    # if cut is True, we will only select the final index in nk, and repeat it, then concatenate with x
    self.cut = cut
    # the second lstm has more input than the previous
    if nk is None:
        nk = ny
    if hidden_size_later is None:
        hidden_size_later = hidden_size
    self.linear_in_later = torch.nn.Linear(nx + nk, hidden_size_later)
    self.lstm_later = CudnnLstm(
        input_size=hidden_size_later, hidden_size=hidden_size_later, dr=dr
    )
    self.linear_out_later = torch.nn.Linear(hidden_size_later, ny)

    self.delta_s = delta_s
    # when delta_s is true, cut cannot be true, because they have to have same number params
    assert not (cut and delta_s)

forward(self, x, do_drop_mc=False, dropout_false=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, do_drop_mc=False, dropout_false=False):
    x0 = F.relu(self.linearIn(x))
    out_lstm1, (hn1, cn1) = self.lstm(
        x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
    )
    gen = self.linearOut(out_lstm1)
    if self.cut:
        gen = gen[-1, :, :].repeat(x.shape[0], 1, 1)
    x1 = torch.cat((x, gen), dim=len(gen.shape) - 1)
    x2 = F.relu(self.linear_in_later(x1))
    out_lstm2, (hn2, cn2) = self.lstm_later(
        x2, do_drop_mc=do_drop_mc, dropout_false=dropout_false
    )
    out = self.linear_out_later(out_lstm2)
    return gen - out if self.delta_s else (out, gen)
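
A minimal sketch with hypothetical sizes showing the default two-output behaviour (CUDA required); with delta_s=True the model would instead return the difference gen - out:

import torch

from torchhydro.models.cudnnlstm import CudnnLstmModelLstmKernel

model = CudnnLstmModelLstmKernel(nx=10, ny=1, hidden_size=64).cuda()

x = torch.randn(365, 32, 10, device="cuda")   # [time, basin, feature]
out, gen = model(x)  # out: second LSTM's prediction; gen: the first ("kernel") LSTM's output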

CudnnLstmModelMultiOutput (Module)

Source code in torchhydro/models/cudnnlstm.py
class CudnnLstmModelMultiOutput(nn.Module):
    def __init__(
        self,
        n_input_features,
        n_output_features,
        n_hidden_states,
        layer_hidden_size=(128, 64),
        dr=0.5,
        dr_hidden=0.0,
    ):
        """
        Multiple output CudnnLSTM.

        It has multiple output layers, each for one output, so that we can easily freeze any output layer.

        Parameters
        ----------
        n_input_features
            the size of input features
        n_output_features
            the size of output features; in this model, we set different nonlinear layer for each output
        n_hidden_states
            the size of LSTM's hidden features
        layer_hidden_size
            hidden_size for multi-layers
        dr
            dropout rate
        dr_hidden
            dropout rates of hidden layers
        """
        super(CudnnLstmModelMultiOutput, self).__init__()
        self.ct = 0
        multi_layers = torch.nn.ModuleList()
        for i in range(n_output_features):
            multi_layers.add_module(
                "layer%d" % (i + 1),
                SimpleAnn(n_hidden_states, 1, layer_hidden_size, dr=dr_hidden),
            )
        self.multi_layers = multi_layers
        self.linearIn = torch.nn.Linear(n_input_features, n_hidden_states)
        self.lstm = CudnnLstm(
            input_size=n_hidden_states, hidden_size=n_hidden_states, dr=dr
        )

    def forward(self, x, do_drop_mc=False, dropout_false=False, return_h_c=False):
        x0 = F.relu(self.linearIn(x))
        out_lstm, (hn, cn) = self.lstm(
            x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
        )
        outs = [mod(out_lstm) for mod in self.multi_layers]
        final = torch.cat(outs, dim=-1)
        return (final, (hn, cn)) if return_h_c else final

__init__(self, n_input_features, n_output_features, n_hidden_states, layer_hidden_size=(128, 64), dr=0.5, dr_hidden=0.0) special

Multiple output CudnnLSTM.

It has multiple output layers, each for one output, so that we can easily freeze any output layer.

Parameters

n_input_features
    the size of input features
n_output_features
    the size of output features; in this model, we set different nonlinear layers for each output
n_hidden_states
    the size of the LSTM's hidden features
layer_hidden_size
    hidden_size for the multi-layers
dr
    dropout rate
dr_hidden
    dropout rates of the hidden layers

Source code in torchhydro/models/cudnnlstm.py
def __init__(
    self,
    n_input_features,
    n_output_features,
    n_hidden_states,
    layer_hidden_size=(128, 64),
    dr=0.5,
    dr_hidden=0.0,
):
    """
    Multiple output CudnnLSTM.

    It has multiple output layers, each for one output, so that we can easily freeze any output layer.

    Parameters
    ----------
    n_input_features
        the size of input features
    n_output_features
        the size of output features; in this model, we set different nonlinear layer for each output
    n_hidden_states
        the size of LSTM's hidden features
    layer_hidden_size
        hidden_size for multi-layers
    dr
        dropout rate
    dr_hidden
        dropout rates of hidden layers
    """
    super(CudnnLstmModelMultiOutput, self).__init__()
    self.ct = 0
    multi_layers = torch.nn.ModuleList()
    for i in range(n_output_features):
        multi_layers.add_module(
            "layer%d" % (i + 1),
            SimpleAnn(n_hidden_states, 1, layer_hidden_size, dr=dr_hidden),
        )
    self.multi_layers = multi_layers
    self.linearIn = torch.nn.Linear(n_input_features, n_hidden_states)
    self.lstm = CudnnLstm(
        input_size=n_hidden_states, hidden_size=n_hidden_states, dr=dr
    )

forward(self, x, do_drop_mc=False, dropout_false=False, return_h_c=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, do_drop_mc=False, dropout_false=False, return_h_c=False):
    x0 = F.relu(self.linearIn(x))
    out_lstm, (hn, cn) = self.lstm(
        x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
    )
    outs = [mod(out_lstm) for mod in self.multi_layers]
    final = torch.cat(outs, dim=-1)
    return (final, (hn, cn)) if return_h_c else final
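
Because each output variable has its own head in multi_layers, an individual head can be frozen for transfer learning. A minimal sketch with hypothetical sizes (two output variables; CUDA required):

import torch

from torchhydro.models.cudnnlstm import CudnnLstmModelMultiOutput

model = CudnnLstmModelMultiOutput(
    n_input_features=10, n_output_features=2, n_hidden_states=64
).cuda()

# freeze the head of the first output variable ("layer1")
for name, head in model.multi_layers.named_children():
    if name == "layer1":
        for p in head.parameters():
            p.requires_grad = False

x = torch.randn(365, 32, 10, device="cuda")
out = model(x)  # torch.Size([365, 32, 2]); one column per output head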

LinearCudnnLstmModel (CudnnLstmModel)

This model is nonlinear layer + CudnnLSTM/CudnnLstm-MultiOutput-Model. kai_tl: model from this paper by Ma et al. -- https://doi.org/10.1029/2020WR028600

Source code in torchhydro/models/cudnnlstm.py
class LinearCudnnLstmModel(CudnnLstmModel):
    """This model is nonlinear layer + CudnnLSTM/CudnnLstm-MultiOutput-Model.
    kai_tl: model from this paper by Ma et al. -- https://doi.org/10.1029/2020WR028600
    """

    def __init__(self, linear_size, **kwargs):
        """

        Parameters
        ----------
        linear_size
            the number of input features for the first input linear layer
        """
        super(LinearCudnnLstmModel, self).__init__(**kwargs)
        self.former_linear = torch.nn.Linear(linear_size, kwargs["n_input_features"])

    def forward(self, x, do_drop_mc=False, dropout_false=False):
        x0 = F.relu(self.former_linear(x))
        return super(LinearCudnnLstmModel, self).forward(
            x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
        )

__init__(self, linear_size, **kwargs) special

Parameters

linear_size
    the number of input features for the first input linear layer

Source code in torchhydro/models/cudnnlstm.py
def __init__(self, linear_size, **kwargs):
    """

    Parameters
    ----------
    linear_size
        the number of input features for the first input linear layer
    """
    super(LinearCudnnLstmModel, self).__init__(**kwargs)
    self.former_linear = torch.nn.Linear(linear_size, kwargs["n_input_features"])

forward(self, x, do_drop_mc=False, dropout_false=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, do_drop_mc=False, dropout_false=False):
    x0 = F.relu(self.former_linear(x))
    return super(LinearCudnnLstmModel, self).forward(
        x0, do_drop_mc=do_drop_mc, dropout_false=dropout_false
    )

LstmCellTied (Module)

LSTM with dropout implemented by Kuai Fang: https://github.com/mhpi/hydroDL/blob/release/hydroDL/model/rnn.py

the name of "Tied" comes from this paper: http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks.pdf which means the weights of all gates will be tied together to be used (eq. 6 in this paper). this code is mainly used as a CPU version of CudnnLstm

Source code in torchhydro/models/cudnnlstm.py
class LstmCellTied(nn.Module):
    """
    LSTM with dropout implemented by Kuai Fang: https://github.com/mhpi/hydroDL/blob/release/hydroDL/model/rnn.py

    the name of "Tied" comes from this paper:
    http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks.pdf
    which means the weights of all gates will be tied together to be used (eq. 6 in this paper).
    this code is mainly used as a CPU version of CudnnLstm
    """

    def __init__(
        self,
        *,
        input_size,
        hidden_size,
        mode="train",
        dr=0.5,
        dr_method="drX+drW+drC",
        gpu=1,
    ):
        super(LstmCellTied, self).__init__()

        self.inputSize = input_size
        self.hiddenSize = hidden_size
        self.dr = dr

        self.w_ih = Parameter(torch.Tensor(hidden_size * 4, input_size))
        self.w_hh = Parameter(torch.Tensor(hidden_size * 4, hidden_size))
        self.b_ih = Parameter(torch.Tensor(hidden_size * 4))
        self.b_hh = Parameter(torch.Tensor(hidden_size * 4))

        self.drMethod = dr_method.split("+")
        self.gpu = gpu
        self.mode = mode
        if mode == "train":
            self.train(mode=True)
        elif mode in ["test", "drMC"]:
            self.train(mode=False)
        if gpu >= 0:
            self = self.cuda()
            self.is_cuda = True
        else:
            self.is_cuda = False
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hiddenSize)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def reset_mask(self, x, h, c):
        self.mask_x = create_mask(x, self.dr)
        self.mask_h = create_mask(h, self.dr)
        self.mask_c = create_mask(c, self.dr)
        self.mask_w_ih = create_mask(self.w_ih, self.dr)
        self.mask_w_hh = create_mask(self.w_hh, self.dr)

    def forward(self, x, hidden, *, do_reset_mask=True, do_drop_mc=False):
        do_drop = self.dr > 0 and (do_drop_mc is True or self.training is True)
        batch_size = x.size(0)
        h0, c0 = hidden
        if h0 is None:
            h0 = x.new_zeros(batch_size, self.hiddenSize, requires_grad=False)
        if c0 is None:
            c0 = x.new_zeros(batch_size, self.hiddenSize, requires_grad=False)

        if self.dr > 0 and self.training is True and do_reset_mask is True:
            self.reset_mask(x, h0, c0)

        if do_drop and "drH" in self.drMethod:
            h0 = DropMask.apply(h0, self.mask_h, True)

        if do_drop and "drX" in self.drMethod:
            x = DropMask.apply(x, self.mask_x, True)

        if do_drop and "drW" in self.drMethod:
            w_ih = DropMask.apply(self.w_ih, self.mask_w_ih, True)
            w_hh = DropMask.apply(self.w_hh, self.mask_w_hh, True)
        else:
            # self.w are parameters, while w are not
            w_ih = self.w_ih
            w_hh = self.w_hh

        gates = F.linear(x, w_ih, self.b_ih) + F.linear(h0, w_hh, self.b_hh)
        gate_i, gate_f, gate_c, gate_o = gates.chunk(4, 1)

        gate_i = torch.sigmoid(gate_i)
        gate_f = torch.sigmoid(gate_f)
        gate_c = torch.tanh(gate_c)
        gate_o = torch.sigmoid(gate_o)

        if self.training is True and "drC" in self.drMethod:
            gate_c = gate_c.mul(self.mask_c)

        c1 = (gate_f * c0) + (gate_i * gate_c)
        h1 = gate_o * torch.tanh(c1)

        return h1, c1

forward(self, x, hidden, *, do_reset_mask=True, do_drop_mc=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/cudnnlstm.py
def forward(self, x, hidden, *, do_reset_mask=True, do_drop_mc=False):
    do_drop = self.dr > 0 and (do_drop_mc is True or self.training is True)
    batch_size = x.size(0)
    h0, c0 = hidden
    if h0 is None:
        h0 = x.new_zeros(batch_size, self.hiddenSize, requires_grad=False)
    if c0 is None:
        c0 = x.new_zeros(batch_size, self.hiddenSize, requires_grad=False)

    if self.dr > 0 and self.training is True and do_reset_mask is True:
        self.reset_mask(x, h0, c0)

    if do_drop and "drH" in self.drMethod:
        h0 = DropMask.apply(h0, self.mask_h, True)

    if do_drop and "drX" in self.drMethod:
        x = DropMask.apply(x, self.mask_x, True)

    if do_drop and "drW" in self.drMethod:
        w_ih = DropMask.apply(self.w_ih, self.mask_w_ih, True)
        w_hh = DropMask.apply(self.w_hh, self.mask_w_hh, True)
    else:
        # self.w are parameters, while w are not
        w_ih = self.w_ih
        w_hh = self.w_hh

    gates = F.linear(x, w_ih, self.b_ih) + F.linear(h0, w_hh, self.b_hh)
    gate_i, gate_f, gate_c, gate_o = gates.chunk(4, 1)

    gate_i = torch.sigmoid(gate_i)
    gate_f = torch.sigmoid(gate_f)
    gate_c = torch.tanh(gate_c)
    gate_o = torch.sigmoid(gate_o)

    if self.training is True and "drC" in self.drMethod:
        gate_c = gate_c.mul(self.mask_c)

    c1 = (gate_f * c0) + (gate_i * gate_c)
    h1 = gate_o * torch.tanh(c1)

    return h1, c1

dpl4gr4j

Simulates streamflow over time using the GR4J model logic implemented in PyTorch, so that GR4J's functionality is available with gradient information (for differentiable parameter learning).

DplAnnGr4j (Module)

Source code in torchhydro/models/dpl4gr4j.py
class DplAnnGr4j(nn.Module):
    def __init__(
        self,
        n_input_features: int,
        n_output_features: int,
        n_hidden_states: Union[int, tuple, list],
        warmup_length: int,
        param_limit_func="sigmoid",
        param_test_way="final",
    ):
        """
        Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of ANN
        n_output_features
            the number of output features of ANN, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of ANN; it could be Union[int, tuple, list]
        warmup_length
            the length of warmup periods;
            hydrologic models need a warmup period to generate reasonable initial state values
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        """
        super(DplAnnGr4j, self).__init__()
        self.dl_model = SimpleAnn(n_input_features, n_output_features, n_hidden_states)
        self.pb_model = Gr4j4Dpl(warmup_length)
        self.param_func = param_limit_func
        self.param_test_way = param_test_way

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q
        Parameters will be denormalized in gr4j model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a 2-dim tensor. [batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        return ann_pbm(self.dl_model, self.pb_model, self.param_func, x, z)

__init__(self, n_input_features, n_output_features, n_hidden_states, warmup_length, param_limit_func='sigmoid', param_test_way='final') special

Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of the ANN
n_output_features
    the number of output features of the ANN; it should be equal to the number of learnable parameters in GR4J
n_hidden_states
    the number of hidden features of the ANN; it could be Union[int, tuple, list]
warmup_length
    the length of the warmup period; hydrologic models need a warmup period to generate reasonable initial state values
param_limit_func
    function used to limit the range of params; now it is the sigmoid or clamp function
param_test_way
    how we use parameters from the DL model when testing; three ways are supported:
    1. "final" -- the final period's parameters are used for each period
    2. "mean_time" -- the mean of all periods' parameters is used
    3. "mean_basin" -- the mean of all basins' final-period parameters is used

Source code in torchhydro/models/dpl4gr4j.py
def __init__(
    self,
    n_input_features: int,
    n_output_features: int,
    n_hidden_states: Union[int, tuple, list],
    warmup_length: int,
    param_limit_func="sigmoid",
    param_test_way="final",
):
    """
    Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of ANN
    n_output_features
        the number of output features of ANN, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of ANN; it could be Union[int, tuple, list]
    warmup_length
        the length of warmup periods;
        hydrologic models need a warmup period to generate reasonable initial state values
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
    """
    super(DplAnnGr4j, self).__init__()
    self.dl_model = SimpleAnn(n_input_features, n_output_features, n_hidden_states)
    self.pb_model = Gr4j4Dpl(warmup_length)
    self.param_func = param_limit_func
    self.param_test_way = param_test_way

forward(self, x, z)

Differential parameter learning

z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q Parameters will be denormalized in gr4j model

Parameters

x
    not normalized data used for the physical model; a sequence-first 3-dim tensor: [sequence, batch, feature]
z
    normalized data used for the DL model; a 2-dim tensor: [batch, feature]

Returns

torch.Tensor one time forward result

Source code in torchhydro/models/dpl4gr4j.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q
    Parameters will be denormalized in gr4j model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a 2-dim tensor. [batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    return ann_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
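
A minimal sketch with made-up data; note that z is a 2-dim [basin, attribute] tensor here (static attributes only), unlike the LSTM variant below, which takes a sequence-first 3-dim z:

import torch

from torchhydro.models.dpl4gr4j import DplAnnGr4j

# hypothetical sizes: 7 static attributes -> 4 GR4J parameters, 30-step warmup
model = DplAnnGr4j(
    n_input_features=7,
    n_output_features=4,
    n_hidden_states=(32, 16),
    warmup_length=30,
)

x = torch.rand(365, 10, 2)  # raw precipitation and PET, [time, basin, 2]
z = torch.rand(10, 7)       # normalized static attributes, [basin, attribute]
q = model(x, z)             # simulated streamflow for the steps after the warmup period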

DplLstmGr4j (Module)

Source code in torchhydro/models/dpl4gr4j.py
class DplLstmGr4j(nn.Module):
    def __init__(
        self,
        n_input_features,
        n_output_features,
        n_hidden_states,
        warmup_length,
        param_limit_func="sigmoid",
        param_test_way="final",
    ):
        """
        Differential Parameter learning model: LSTM -> Param -> Gr4j

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of LSTM
        n_output_features
            the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of LSTM
        warmup_length
            the length of warmup periods;
            hydrologic models need a warmup period to generate reasonable initial state values
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        """
        super(DplLstmGr4j, self).__init__()
        self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
        self.pb_model = Gr4j4Dpl(warmup_length)
        self.param_func = param_limit_func
        self.param_test_way = param_test_way

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> lstm -> param -> + x (not normalized) -> gr4j -> q
        Parameters will be denormalized in gr4j model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        return lstm_pbm(self.dl_model, self.pb_model, self.param_func, x, z)

__init__(self, n_input_features, n_output_features, n_hidden_states, warmup_length, param_limit_func='sigmoid', param_test_way='final') special

Differential Parameter learning model: LSTM -> Param -> Gr4j

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of the LSTM
n_output_features
    the number of output features of the LSTM; it should be equal to the number of learnable parameters in GR4J
n_hidden_states
    the number of hidden features of the LSTM
warmup_length
    the length of the warmup period; hydrologic models need a warmup period to generate reasonable initial state values
param_limit_func
    function used to limit the range of params; now it is the sigmoid or clamp function
param_test_way
    how we use parameters from the DL model when testing; three ways are supported:
    1. "final" -- the final period's parameters are used for each period
    2. "mean_time" -- the mean of all periods' parameters is used
    3. "mean_basin" -- the mean of all basins' final-period parameters is used

Source code in torchhydro/models/dpl4gr4j.py
def __init__(
    self,
    n_input_features,
    n_output_features,
    n_hidden_states,
    warmup_length,
    param_limit_func="sigmoid",
    param_test_way="final",
):
    """
    Differential Parameter learning model: LSTM -> Param -> Gr4j

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of LSTM
    n_output_features
        the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of LSTM
    warmup_length
        the length of warmup periods;
        hydrologic models need a warmup period to generate reasonable initial state values
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
    """
    super(DplLstmGr4j, self).__init__()
    self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
    self.pb_model = Gr4j4Dpl(warmup_length)
    self.param_func = param_limit_func
    self.param_test_way = param_test_way

forward(self, x, z)

Differential parameter learning

z (normalized input) -> lstm -> param -> + x (not normalized) -> gr4j -> q Parameters will be denormalized in gr4j model

Parameters

x
    not normalized data used for the physical model; a sequence-first 3-dim tensor: [sequence, batch, feature]
z
    normalized data used for the DL model; a sequence-first 3-dim tensor: [sequence, batch, feature]

Returns

torch.Tensor one time forward result

Source code in torchhydro/models/dpl4gr4j.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> lstm -> param -> + x (not normalized) -> gr4j -> q
    Parameters will be denormalized in gr4j model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    return lstm_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
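
A minimal end-to-end sketch with made-up data (all sizes hypothetical). The LSTM sees normalized forcings/attributes z and produces the four GR4J parameters (so n_output_features must be 4); the GR4J core is then driven by the raw precipitation and PET in x:

import torch

from torchhydro.models.dpl4gr4j import DplLstmGr4j

model = DplLstmGr4j(
    n_input_features=5,   # normalized forcing/attribute inputs for the LSTM
    n_output_features=4,  # one output per GR4J parameter (X1..X4)
    n_hidden_states=64,
    warmup_length=30,
)

nt, nbasin = 365, 10
x = torch.rand(nt, nbasin, 2)  # raw precipitation and PET, [time, basin, 2]
z = torch.rand(nt, nbasin, 5)  # normalized inputs for the LSTM, [time, basin, feature]
q = model(x, z)                # streamflow for the time steps after the warmup period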

Gr4j4Dpl (Module)

the nn.Module style GR4J model

Source code in torchhydro/models/dpl4gr4j.py
class Gr4j4Dpl(nn.Module):
    """
    the nn.Module style GR4J model
    """

    def __init__(self, warmup_length: int):
        """
        Parameters
        ----------
        warmup_length
            length of warmup period
        """
        super(Gr4j4Dpl, self).__init__()
        self.params_names = ["X1", "X2", "X3", "X4"]
        self.x1_scale = [100.0, 1200.0]
        self.x2_sacle = [-5.0, 3.0]
        self.x3_scale = [20.0, 300.0]
        self.x4_scale = [1.1, 2.9]
        self.warmup_length = warmup_length
        self.feature_size = 2

    def forward(self, p_and_e, parameters, return_state=False):
        gr4j_device = p_and_e.device
        x1 = self.x1_scale[0] + parameters[:, 0] * (self.x1_scale[1] - self.x1_scale[0])
        x2 = self.x2_sacle[0] + parameters[:, 1] * (self.x2_sacle[1] - self.x2_sacle[0])
        x3 = self.x3_scale[0] + parameters[:, 2] * (self.x3_scale[1] - self.x3_scale[0])
        x4 = self.x4_scale[0] + parameters[:, 3] * (self.x4_scale[1] - self.x4_scale[0])

        warmup_length = self.warmup_length
        if warmup_length > 0:
            # set no_grad for warmup periods
            with torch.no_grad():
                p_and_e_warmup = p_and_e[0:warmup_length, :, :]
                cal_init = Gr4j4Dpl(0)
                if cal_init.warmup_length > 0:
                    raise RuntimeError("Please set init model's warmup length to 0!!!")
                _, s0, r0 = cal_init(p_and_e_warmup, parameters, return_state=True)
        else:
            # use detach func to make wu0 no_grad as it is an initial value
            s0 = 0.5 * x1.detach()
            r0 = 0.5 * x3.detach()
        inputs = p_and_e[warmup_length:, :, :]
        streamflow_ = torch.full(inputs.shape[:2], 0.0).to(gr4j_device)
        prs = torch.full(inputs.shape[:2], 0.0).to(gr4j_device)
        for i in range(inputs.shape[0]):
            if i == 0:
                pr, s = production(inputs[i, :, :], x1, s0)
            else:
                pr, s = production(inputs[i, :, :], x1, s)
            prs[i, :] = pr
        prs_x = torch.unsqueeze(prs, dim=2)
        conv_q9, conv_q1 = uh_gr4j(x4)
        q9 = torch.full([inputs.shape[0], inputs.shape[1], 1], 0.0).to(gr4j_device)
        q1 = torch.full([inputs.shape[0], inputs.shape[1], 1], 0.0).to(gr4j_device)
        for j in range(inputs.shape[1]):
            q9[:, j : j + 1, :] = uh_conv(
                prs_x[:, j : j + 1, :], conv_q9[j].reshape(-1, 1, 1)
            )
            q1[:, j : j + 1, :] = uh_conv(
                prs_x[:, j : j + 1, :], conv_q1[j].reshape(-1, 1, 1)
            )
        for i in range(inputs.shape[0]):
            if i == 0:
                q, r = routing(q9[i, :, 0], q1[i, :, 0], x2, x3, r0)
            else:
                q, r = routing(q9[i, :, 0], q1[i, :, 0], x2, x3, r)
            streamflow_[i, :] = q
        streamflow = torch.unsqueeze(streamflow_, dim=2)
        return (streamflow, s, r) if return_state else streamflow

__init__(self, warmup_length) special

Parameters

warmup_length
    length of the warmup period

Source code in torchhydro/models/dpl4gr4j.py
def __init__(self, warmup_length: int):
    """
    Parameters
    ----------
    warmup_length
        length of warmup period
    """
    super(Gr4j4Dpl, self).__init__()
    self.params_names = ["X1", "X2", "X3", "X4"]
    self.x1_scale = [100.0, 1200.0]
    self.x2_sacle = [-5.0, 3.0]
    self.x3_scale = [20.0, 300.0]
    self.x4_scale = [1.1, 2.9]
    self.warmup_length = warmup_length
    self.feature_size = 2

forward(self, p_and_e, parameters, return_state=False)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/dpl4gr4j.py
def forward(self, p_and_e, parameters, return_state=False):
    gr4j_device = p_and_e.device
    x1 = self.x1_scale[0] + parameters[:, 0] * (self.x1_scale[1] - self.x1_scale[0])
    x2 = self.x2_sacle[0] + parameters[:, 1] * (self.x2_sacle[1] - self.x2_sacle[0])
    x3 = self.x3_scale[0] + parameters[:, 2] * (self.x3_scale[1] - self.x3_scale[0])
    x4 = self.x4_scale[0] + parameters[:, 3] * (self.x4_scale[1] - self.x4_scale[0])

    warmup_length = self.warmup_length
    if warmup_length > 0:
        # set no_grad for warmup periods
        with torch.no_grad():
            p_and_e_warmup = p_and_e[0:warmup_length, :, :]
            cal_init = Gr4j4Dpl(0)
            if cal_init.warmup_length > 0:
                raise RuntimeError("Please set init model's warmup length to 0!!!")
            _, s0, r0 = cal_init(p_and_e_warmup, parameters, return_state=True)
    else:
        # use detach func to make wu0 no_grad as it is an initial value
        s0 = 0.5 * x1.detach()
        r0 = 0.5 * x3.detach()
    inputs = p_and_e[warmup_length:, :, :]
    streamflow_ = torch.full(inputs.shape[:2], 0.0).to(gr4j_device)
    prs = torch.full(inputs.shape[:2], 0.0).to(gr4j_device)
    for i in range(inputs.shape[0]):
        if i == 0:
            pr, s = production(inputs[i, :, :], x1, s0)
        else:
            pr, s = production(inputs[i, :, :], x1, s)
        prs[i, :] = pr
    prs_x = torch.unsqueeze(prs, dim=2)
    conv_q9, conv_q1 = uh_gr4j(x4)
    q9 = torch.full([inputs.shape[0], inputs.shape[1], 1], 0.0).to(gr4j_device)
    q1 = torch.full([inputs.shape[0], inputs.shape[1], 1], 0.0).to(gr4j_device)
    for j in range(inputs.shape[1]):
        q9[:, j : j + 1, :] = uh_conv(
            prs_x[:, j : j + 1, :], conv_q9[j].reshape(-1, 1, 1)
        )
        q1[:, j : j + 1, :] = uh_conv(
            prs_x[:, j : j + 1, :], conv_q1[j].reshape(-1, 1, 1)
        )
    for i in range(inputs.shape[0]):
        if i == 0:
            q, r = routing(q9[i, :, 0], q1[i, :, 0], x2, x3, r0)
        else:
            q, r = routing(q9[i, :, 0], q1[i, :, 0], x2, x3, r)
        streamflow_[i, :] = q
    streamflow = torch.unsqueeze(streamflow_, dim=2)
    return (streamflow, s, r) if return_state else streamflow
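
The GR4J core can also be called directly with parameters already normalized to [0, 1] (they are rescaled to the X1..X4 ranges inside forward). A minimal sketch with made-up data:

import torch

from torchhydro.models.dpl4gr4j import Gr4j4Dpl

gr4j = Gr4j4Dpl(warmup_length=20)

p_and_e = torch.rand(100, 2, 2)  # [time, basin, (precipitation, PET)]
params = torch.rand(2, 4)        # normalized X1..X4 in [0, 1], one row per basin

q = gr4j(p_and_e, params)                           # streamflow for the 80 post-warmup steps
q, s, r = gr4j(p_and_e, params, return_state=True)  # also return the final store levels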

calculate_evap_store(s, evap_net, x1)

Calculates the amount of evaporation out of the storage reservoir.

Source code in torchhydro/models/dpl4gr4j.py
def calculate_evap_store(s, evap_net, x1):
    """Calculates the amount of evaporation out of the storage reservoir."""
    n = s * (2.0 - s / x1) * torch.tanh(evap_net / x1)
    d = 1.0 + (1.0 - s / x1) * torch.tanh(evap_net / x1)
    return n / d

calculate_perc(current_store, x1)

Calculates the percolation from the storage reservoir into streamflow.

Source code in torchhydro/models/dpl4gr4j.py
def calculate_perc(current_store, x1):
    """Calculates the percolation from the storage reservoir into streamflow."""
    return current_store * (
        1.0 - (1.0 + (4.0 / 9.0 * current_store / x1) ** 4) ** -0.25
    )

calculate_precip_store(s, precip_net, x1)

Calculates the amount of rainfall which enters the storage reservoir.

Source code in torchhydro/models/dpl4gr4j.py
def calculate_precip_store(s, precip_net, x1):
    """Calculates the amount of rainfall which enters the storage reservoir."""
    n = x1 * (1.0 - (s / x1) ** 2) * torch.tanh(precip_net / x1)
    d = 1.0 + (s / x1) * torch.tanh(precip_net / x1)
    return n / d
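
For reference, the three helpers above implement the standard GR4J production-store relations, with $P_n$ and $E_n$ the net precipitation and evapotranspiration, $S$ the current store level, and $x_1$ the store capacity:

$$
P_s = \frac{x_1\left(1 - (S/x_1)^2\right)\tanh(P_n/x_1)}{1 + (S/x_1)\tanh(P_n/x_1)}, \qquad
E_s = \frac{S\,(2 - S/x_1)\tanh(E_n/x_1)}{1 + (1 - S/x_1)\tanh(E_n/x_1)}, \qquad
\mathrm{Perc} = S\left(1 - \left(1 + \left(\tfrac{4S}{9x_1}\right)^4\right)^{-1/4}\right)
$$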

production(p_and_e, x1, s_level=None)

a one-step calculation for the production store in GR4J; the dimension of the cell is [batch, feature]

Parameters

p_and_e
    P is pe[:, 0] and E is pe[:, 1]; similar to the "input" in an RNNCell
x1
    storage reservoir parameter
s_level
    S in the GR4J model; similar to the "hx" in an RNNCell. Initial value of storage in the storage reservoir.

Returns

tuple contains the Pr and updated S

Source code in torchhydro/models/dpl4gr4j.py
def production(
    p_and_e: Tensor, x1: Tensor, s_level: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
    """
    an one-step calculation for production store in GR4J
    the dimension of the cell: [batch, feature]

    Parameters
    ----------
    p_and_e
        P is pe[:, 0] and E is pe[:, 1]; similar with the "input" in the RNNCell
    x1:
        Storage reservoir parameter;
    s_level
        s_level means S in the GR4J Model; similar with the "hx" in the RNNCell
        Initial value of storage in the storage reservoir.

    Returns
    -------
    tuple
        contains the Pr and updated S
    """
    gr4j_device = p_and_e.device
    # Calculate net precipitation and evapotranspiration
    precip_difference = p_and_e[:, 0] - p_and_e[:, 1]
    precip_net = torch.maximum(precip_difference, Tensor([0.0]).to(gr4j_device))
    evap_net = torch.maximum(-precip_difference, Tensor([0.0]).to(gr4j_device))

    if s_level is None:
        s_level = 0.6 * (x1.detach())

    # s_level should not be larger than x1
    s_level = torch.clamp(s_level, torch.full(s_level.shape, 0.0).to(gr4j_device), x1)

    # Calculate the fraction of net precipitation that is stored
    precip_store = calculate_precip_store(s_level, precip_net, x1)

    # Calculate the amount of evaporation from storage
    evap_store = calculate_evap_store(s_level, evap_net, x1)

    # Update the storage by adding effective precipitation and
    # removing evaporation
    s_update = s_level - evap_store + precip_store
    # s_level should not be larger than self.x1
    s_update = torch.clamp(
        s_update, torch.full(s_update.shape, 0.0).to(gr4j_device), x1
    )

    # Update the storage again to reflect percolation out of the store
    perc = calculate_perc(s_update, x1)
    s_update = s_update - perc
    # perc is always lower than S because of the calculation itself, so we don't need clamp here anymore.

    # The precip. for routing is the sum of the rainfall which
    # did not make it to storage and the percolation from the store
    current_runoff = perc + (precip_net - precip_store)
    return current_runoff, s_update

routing(q9, q1, x2, x3, r_level=None)

the GR4J routing-module unit cell for time-sequence loop

Parameters

q9

q1

x2
    Catchment water exchange parameter
x3
    Routing reservoir parameter
r_level
    Beginning value of storage in the routing reservoir.

Returns
Source code in torchhydro/models/dpl4gr4j.py
def routing(q9: Tensor, q1: Tensor, x2, x3, r_level: Optional[Tensor] = None):
    """
    the GR4J routing-module unit cell for time-sequence loop

    Parameters
    ----------
    q9

    q1

    x2
        Catchment water exchange parameter
    x3
        Routing reservoir parameters
    r_level
        Beginning value of storage in the routing reservoir.

    Returns
    -------

    """
    gr4j_device = q9.device
    if r_level is None:
        r_level = 0.7 * (x3.detach())
    # r_level should not be larger than self.x3
    r_level = torch.clamp(r_level, torch.full(r_level.shape, 0.0).to(gr4j_device), x3)
    groundwater_ex = x2 * (r_level / x3) ** 3.5
    r_updated = torch.maximum(
        torch.full(r_level.shape, 0.0).to(gr4j_device), r_level + q9 + groundwater_ex
    )

    qr = r_updated * (1.0 - (1.0 + (r_updated / x3) ** 4) ** -0.25)
    r_updated = r_updated - qr

    qd = torch.maximum(
        torch.full(groundwater_ex.shape, 0.0).to(gr4j_device), q1 + groundwater_ex
    )
    q = qr + qd
    return q, r_updated
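
For reference, one routing step implements the standard GR4J relations, with $R$ the routing-store level, $F$ the groundwater exchange, and $Q_9$, $Q_1$ the two unit-hydrograph outputs:

$$
F = x_2\,(R/x_3)^{3.5}, \qquad
R' = \max(0,\, R + Q_9 + F), \qquad
Q_r = R'\left(1 - \left(1 + (R'/x_3)^4\right)^{-1/4}\right),
$$
$$
R_{\text{new}} = R' - Q_r, \qquad
Q_d = \max(0,\, Q_1 + F), \qquad
Q = Q_r + Q_d
$$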

uh_gr4j(x4)

Generate the convolution kernel for the convolution operation in routing module of GR4J

Parameters

x4
    the dim of x4 is [batch]

Returns

list UH1s and UH2s for all basins

Source code in torchhydro/models/dpl4gr4j.py
def uh_gr4j(x4):
    """
    Generate the convolution kernel for the convolution operation in routing module of GR4J

    Parameters
    ----------
    x4
        the dim of x4 is [batch]

    Returns
    -------
    list
        UH1s and UH2s for all basins
    """
    gr4j_device = x4.device
    uh1_ordinates = []
    uh2_ordinates = []
    for i in range(len(x4)):
        # for SH1, the pieces are: 0, 0<t<x4, t>=x4
        uh1_ordinates_t1 = torch.arange(
            0.0, torch.ceil(x4[i]).detach().cpu().numpy().item()
        ).to(gr4j_device)
        uh1_ordinates_t = torch.arange(
            1.0, torch.ceil(x4[i] + 1.0).detach().cpu().numpy().item()
        ).to(gr4j_device)
        # for SH2, the pieces are: 0, 0<t<=x4, x4<t<2x4, t>=2x4
        uh2_ords_t1_seq_x4 = torch.arange(
            0.0, torch.floor(x4[i] + 1).detach().cpu().numpy().item()
        ).to(gr4j_device)
        uh2_ords_t1_larger_x4 = torch.arange(
            torch.floor(x4[i] + 1).detach().cpu().numpy().item(),
            torch.ceil(2 * x4[i]).detach().cpu().numpy().item(),
        ).to(gr4j_device)
        uh2_ords_t_seq_x4 = torch.arange(
            1.0, torch.floor(x4[i] + 1).detach().cpu().numpy().item()
        ).to(gr4j_device)
        uh2_ords_t_larger_x4 = torch.arange(
            torch.floor(x4[i] + 1).detach().cpu().numpy().item(),
            torch.ceil(2 * x4[i] + 1.0).detach().cpu().numpy().item(),
        ).to(gr4j_device)
        s_curve1t1 = (uh1_ordinates_t1 / x4[i]) ** 2.5
        s_curve21t1 = 0.5 * (uh2_ords_t1_seq_x4 / x4[i]) ** 2.5
        s_curve22t1 = 1.0 - 0.5 * (2 - uh2_ords_t1_larger_x4 / x4[i]) ** 2.5
        s_curve2t1 = torch.cat([s_curve21t1, s_curve22t1])
        # t1 cannot be larger than x4, but t can, so we should set (uh1_ordinates_t / x4[i]) <=1
        # we don't use torch.clamp, because it seems we have to use mask, or we will get nan for grad. More details
        # could be seen here: https://github.com/waterDLut/hydro-dl-basic/tree/dev/3-more-knowledge/5-grad-problem.ipynb
        uh1_x4 = uh1_ordinates_t / x4[i]
        limit_uh1_x4 = 1 - F.relu(1 - uh1_x4)
        limit_uh2_smaller_x4 = uh2_ords_t_seq_x4 / x4[i]
        uh2_larger_x4 = 2 - uh2_ords_t_larger_x4 / x4[i]
        limit_uh2_larger_x4 = F.relu(uh2_larger_x4)
        s_curve1t = limit_uh1_x4**2.5
        s_curve21t = 0.5 * limit_uh2_smaller_x4**2.5
        s_curve22t = 1.0 - 0.5 * limit_uh2_larger_x4**2.5
        s_curve2t = torch.cat([s_curve21t, s_curve22t])
        uh1_ordinate = s_curve1t - s_curve1t1
        uh2_ordinate = s_curve2t - s_curve2t1
        uh1_ordinates.append(uh1_ordinate)
        uh2_ordinates.append(uh2_ordinate)

    return uh1_ordinates, uh2_ordinates
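
A short sketch of generating ordinates for two basins with made-up x4 values. Per the code above, basin i gets ceil(x4[i]) UH1 ordinates and ceil(2 * x4[i]) UH2 ordinates, and each ordinate set sums to 1.

import torch
from torchhydro.models.dpl4gr4j import uh_gr4j

x4 = torch.tensor([1.7, 3.0])  # unit-hydrograph time parameter, one value per basin
uh1_ords, uh2_ords = uh_gr4j(x4)

for uh1, uh2 in zip(uh1_ords, uh2_ords):
    # lengths are ceil(x4) and ceil(2 * x4); each ordinate set sums to 1
    print(len(uh1), len(uh2), float(uh1.sum()), float(uh2.sum()))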

dpl4hbv

DplAnnHbv (Module)

Source code in torchhydro/models/dpl4hbv.py
class DplAnnHbv(nn.Module):
    def __init__(
        self,
        n_input_features: int,
        n_output_features: int,
        n_hidden_states: Union[int, tuple, list],
        kernel_size,
        warmup_length: int,
        param_limit_func="sigmoid",
        param_test_way="final",
    ):
        """
        Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of ANN
        n_output_features
            the number of output features of ANN, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of ANN; it could be Union[int, tuple, list]
        kernel_size
            size for unit hydrograph
        warmup_length
            the length of warmup periods;
            hydrologic models need a warmup period to generate reasonable initial state values
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        """
        super(DplAnnHbv, self).__init__()
        self.dl_model = SimpleAnn(n_input_features, n_output_features, n_hidden_states)
        self.pb_model = Hbv4Dpl(warmup_length, kernel_size)
        self.param_func = param_limit_func
        self.param_test_way = param_test_way

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q
        Parameters will be denormalized in gr4j model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a 2-dim tensor. [batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        return ann_pbm(self.dl_model, self.pb_model, self.param_func, x, z)

__init__(self, n_input_features, n_output_features, n_hidden_states, kernel_size, warmup_length, param_limit_func='sigmoid', param_test_way='final') special

Differential Parameter learning model with only attributes as the DL model's input: ANN -> Param -> HBV

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of the ANN
n_output_features
    the number of output features of the ANN; it should be equal to the number of learned parameters in HBV
n_hidden_states
    the number of hidden features of the ANN; it could be Union[int, tuple, list]
kernel_size
    size of the unit hydrograph
warmup_length
    the length of the warmup period;
    hydrologic models need a warmup period to generate reasonable initial state values
param_limit_func
    function used to limit the range of params; currently the sigmoid or clamp function
param_test_way
    how parameters from the DL model are used when testing; there are three ways:
    1. "final" -- the final period's parameters are used for every period
    2. "mean_time" -- the mean of all periods' parameters is used
    3. "mean_basin" -- the mean of all basins' final-period parameters is used

Source code in torchhydro/models/dpl4hbv.py
def __init__(
    self,
    n_input_features: int,
    n_output_features: int,
    n_hidden_states: Union[int, tuple, list],
    kernel_size,
    warmup_length: int,
    param_limit_func="sigmoid",
    param_test_way="final",
):
    """
    Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of ANN
    n_output_features
        the number of output features of ANN, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of ANN; it could be Union[int, tuple, list]
    kernel_size
        size for unit hydrograph
    warmup_length
        the length of warmup periods;
        hydrologic models need a warmup period to generate reasonable initial state values
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
    """
    super(DplAnnHbv, self).__init__()
    self.dl_model = SimpleAnn(n_input_features, n_output_features, n_hidden_states)
    self.pb_model = Hbv4Dpl(warmup_length, kernel_size)
    self.param_func = param_limit_func
    self.param_test_way = param_test_way

forward(self, x, z)

Differential parameter learning

z (normalized input) -> ANN -> param -> + x (not normalized) -> HBV -> q. Parameters will be denormalized in the HBV model.

Parameters

x
    not normalized data used for the physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for the DL model; a 2-dim tensor. [batch, feature]

Returns

torch.Tensor
    one-time forward result

Source code in torchhydro/models/dpl4hbv.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q
    Parameters will be denormalized in gr4j model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a 2-dim tensor. [batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    return ann_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
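
For orientation, a minimal sketch of constructing and calling DplAnnHbv follows. The attribute count, hidden sizes, and basin/sequence sizes are illustrative assumptions, and n_output_features=14 assumes the HBV parameter set matches the 14 columns indexed in Hbv4Dpl.forward.

import torch
from torchhydro.models.dpl4hbv import DplAnnHbv

model = DplAnnHbv(
    n_input_features=10,    # number of basin attributes fed to the ANN (assumed)
    n_output_features=14,   # one output per learned HBV parameter (assumed)
    n_hidden_states=(64, 32),
    kernel_size=15,
    warmup_length=30,
)

seq_len, n_basin = 365 + 30, 8
x = torch.rand(seq_len, n_basin, 3)  # not normalized P, PET, T: [sequence, batch, feature]
z = torch.rand(n_basin, 10)          # normalized basin attributes: [batch, feature]
q = model(x, z)                      # simulated streamflow for the period after the warmup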

DplLstmHbv (Module)

Source code in torchhydro/models/dpl4hbv.py
class DplLstmHbv(nn.Module):
    def __init__(
        self,
        n_input_features,
        n_output_features,
        n_hidden_states,
        kernel_size,
        warmup_length,
        param_limit_func="sigmoid",
        param_test_way="final",
    ):
        """
        Differential Parameter learning model: LSTM -> Param -> XAJ

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of LSTM
        n_output_features
            the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of LSTM
        kernel_size
            size for unit hydrograph
        warmup_length
            the time length of warmup period
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        """
        super(DplLstmHbv, self).__init__()
        self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
        self.pb_model = Hbv4Dpl(warmup_length, kernel_size)
        self.param_func = param_limit_func
        self.param_test_way = param_test_way

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
        Parameters will be denormalized in xaj model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        return lstm_pbm(self.dl_model, self.pb_model, self.param_func, x, z)

__init__(self, n_input_features, n_output_features, n_hidden_states, kernel_size, warmup_length, param_limit_func='sigmoid', param_test_way='final') special

Differential Parameter learning model: LSTM -> Param -> HBV

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of the LSTM
n_output_features
    the number of output features of the LSTM; it should be equal to the number of learned parameters in HBV
n_hidden_states
    the number of hidden features of the LSTM
kernel_size
    size of the unit hydrograph
warmup_length
    the time length of the warmup period
param_limit_func
    function used to limit the range of params; currently the sigmoid or clamp function
param_test_way
    how parameters from the DL model are used when testing; there are three ways:
    1. "final" -- the final period's parameters are used for every period
    2. "mean_time" -- the mean of all periods' parameters is used
    3. "mean_basin" -- the mean of all basins' final-period parameters is used

Source code in torchhydro/models/dpl4hbv.py
def __init__(
    self,
    n_input_features,
    n_output_features,
    n_hidden_states,
    kernel_size,
    warmup_length,
    param_limit_func="sigmoid",
    param_test_way="final",
):
    """
    Differential Parameter learning model: LSTM -> Param -> XAJ

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of LSTM
    n_output_features
        the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of LSTM
    kernel_size
        size for unit hydrograph
    warmup_length
        the time length of warmup period
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
    """
    super(DplLstmHbv, self).__init__()
    self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
    self.pb_model = Hbv4Dpl(warmup_length, kernel_size)
    self.param_func = param_limit_func
    self.param_test_way = param_test_way

forward(self, x, z)

Differential parameter learning

z (normalized input) -> LSTM -> param -> + x (not normalized) -> HBV -> q. Parameters will be denormalized in the HBV model.

Parameters

x
    not normalized data used for the physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for the DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

Returns

torch.Tensor
    one-time forward result

Source code in torchhydro/models/dpl4hbv.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
    Parameters will be denormalized in xaj model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    return lstm_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
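
The LSTM variant is called the same way except that z is a sequence-first 3-dim tensor; the sizes below are the same illustrative assumptions as in the DplAnnHbv sketch.

import torch
from torchhydro.models.dpl4hbv import DplLstmHbv

model = DplLstmHbv(
    n_input_features=10, n_output_features=14, n_hidden_states=64,
    kernel_size=15, warmup_length=30,
)
seq_len, n_basin = 365 + 30, 8
x = torch.rand(seq_len, n_basin, 3)   # not normalized P, PET, T
z = torch.rand(seq_len, n_basin, 10)  # normalized forcings/attributes for the LSTM
q = model(x, z)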

Hbv4Dpl (Module)

HBV Model Pytorch version

Source code in torchhydro/models/dpl4hbv.py
class Hbv4Dpl(torch.nn.Module):
    """HBV Model Pytorch version"""

    def __init__(self, warmup_length, kernel_size=15):
        """Initiate a HBV instance

        Parameters
        ----------
        warmup_length : _type_
            _description_
        kernel_size : int, optional
            conv kernel for unit hydrograph, by default 15
        """
        super(Hbv4Dpl, self).__init__()
        self.name = "HBV"
        self.params_names = MODEL_PARAM_DICT["hbv"]["param_name"]
        parasca_lst = MODEL_PARAM_DICT["hbv"]["param_range"]
        self.beta_scale = parasca_lst["BETA"]
        self.fc_scale = parasca_lst["FC"]
        self.k0_scale = parasca_lst["K0"]
        self.k1_scale = parasca_lst["K1"]
        self.k2_scale = parasca_lst["K2"]
        self.lp_scale = parasca_lst["LP"]
        self.perc_scale = parasca_lst["PERC"]
        self.uzl_scale = parasca_lst["UZL"]
        self.tt_scale = parasca_lst["TT"]
        self.cfmax_scale = parasca_lst["CFMAX"]
        self.cfr_scale = parasca_lst["CFR"]
        self.cwh_scale = parasca_lst["CWH"]
        self.a_scale = parasca_lst["A"]
        self.theta_scale = parasca_lst["THETA"]
        self.warmup_length = warmup_length
        self.kernel_size = kernel_size
        # there are 3 input vars in HBV: P, PET and TEMPERATURE
        self.feature_size = 3

    def forward(
        self, x, parameters, out_state=False, rout_opt=True
    ) -> Union[tuple, torch.Tensor]:
        """
        Runs the HBV-light hydrological model (Seibert, 2005).

        The code comes from mhpi/hydro-dev.
        NaN values have to be removed from the inputs.

        Parameters
        ----------
        x
            p_all = array with daily values of precipitation (mm/d)
            pet_all = array with daily values of potential evapotranspiration (mm/d)
            t_all = array with daily values of air temperature (deg C)
        parameters
            array with parameter values having the following structure and scales
            BETA: parameter in soil routine
            FC: maximum soil moisture content
            K0: recession coefficient
            K1: recession coefficient
            K2: recession coefficient
            LP: limit for potential evapotranspiration
            PERC: percolation from upper to lower response box
            UZL: upper zone limit
            TT: temperature limit for snow/rain; distinguish rainfall from snowfall
            CFMAX: degree day factor; used for melting calculation
            CFR: refreezing factor
            CWH: liquid water holding capacity of the snowpack
            A: parameter of mizuRoute
            THETA: parameter of mizuRoute
        out_state
            if True, the state variables' value will be output
        rout_opt
            if True, route module will be performed

        Returns
        -------
        Union[tuple, torch.Tensor]
            q_sim = daily values of simulated streamflow (mm)
            sm = soil storage (mm)
            suz = upper zone storage (mm)
            slz = lower zone storage (mm)
            snowpack = snow depth (mm)
            et_act = actual evaporation (mm)
        """
        hbv_device = x.device
        precision = 1e-5
        buffer_time = self.warmup_length
        # Initialization
        if buffer_time > 0:
            with torch.no_grad():
                x_init = x[0:buffer_time, :, :]
                warmup_length = 0
                init_model = Hbv4Dpl(warmup_length, kernel_size=self.kernel_size)
                if init_model.warmup_length > 0:
                    raise RuntimeError(
                        "Please set warmup_length as 0 when initializing HBV model"
                    )
                _, snowpack, meltwater, sm, suz, slz = init_model(
                    x_init, parameters, out_state=True, rout_opt=False
                )
        else:

            # Without buff time, initialize state variables with zeros
            n_grid = x.shape[1]
            snowpack = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)
            meltwater = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(
                hbv_device
            )
            sm = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)
            suz = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)
            slz = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)

        # the sequence must be p, pet and t
        p_all = x[buffer_time:, :, 0]
        pet_all = x[buffer_time:, :, 1]
        t_all = x[buffer_time:, :, 2]

        # scale the parameters
        par_beta = self.beta_scale[0] + parameters[:, 0] * (
            self.beta_scale[1] - self.beta_scale[0]
        )
        # parCET = parameters[:,1]
        par_fc = self.fc_scale[0] + parameters[:, 1] * (
            self.fc_scale[1] - self.fc_scale[0]
        )
        par_k0 = self.k0_scale[0] + parameters[:, 2] * (
            self.k0_scale[1] - self.k0_scale[0]
        )
        par_k1 = self.k1_scale[0] + parameters[:, 3] * (
            self.k1_scale[1] - self.k1_scale[0]
        )
        par_k2 = self.k2_scale[0] + parameters[:, 4] * (
            self.k2_scale[1] - self.k2_scale[0]
        )
        par_lp = self.lp_scale[0] + parameters[:, 5] * (
            self.lp_scale[1] - self.lp_scale[0]
        )
        # parMAXBAS = parameters[:,7]
        par_perc = self.perc_scale[0] + parameters[:, 6] * (
            self.perc_scale[1] - self.perc_scale[0]
        )
        par_uzl = self.uzl_scale[0] + parameters[:, 7] * (
            self.uzl_scale[1] - self.uzl_scale[0]
        )
        # parPCORR = parameters[:,10]
        par_tt = self.tt_scale[0] + parameters[:, 8] * (
            self.tt_scale[1] - self.tt_scale[0]
        )
        par_cfmax = self.cfmax_scale[0] + parameters[:, 9] * (
            self.cfmax_scale[1] - self.cfmax_scale[0]
        )
        # parSFCF = parameters[:,13]
        par_cfr = self.cfr_scale[0] + parameters[:, 10] * (
            self.cfr_scale[1] - self.cfr_scale[0]
        )
        par_cwh = self.cwh_scale[0] + parameters[:, 11] * (
            self.cwh_scale[1] - self.cwh_scale[0]
        )

        n_step, n_grid = p_all.size()
        # Apply correction factor to precipitation
        # p_all = parPCORR.repeat(n_step, 1) * p_all

        # Initialize time series of model variables
        q_sim = (torch.zeros(p_all.size(), dtype=torch.float32) + 0.001).to(hbv_device)
        # # Debug for the state variables
        # SMlog = np.zeros(p_all.size())
        log_sm = np.zeros(p_all.size())
        log_ps = np.zeros(p_all.size())
        log_swet = np.zeros(p_all.size())
        log_re = np.zeros(p_all.size())

        for i in range(n_step):
            # Separate precipitation into liquid and solid components
            precip = p_all[i, :]
            tempre = t_all[i, :]
            potent = pet_all[i, :]
            rain = torch.mul(precip, (tempre >= par_tt).type(torch.float32))
            snow = torch.mul(precip, (tempre < par_tt).type(torch.float32))
            # snow = snow * parSFCF

            # Snow
            snowpack = snowpack + snow
            melt = par_cfmax * (tempre - par_tt)
            # melt[melt < 0.0] = 0.0
            melt = torch.clamp(melt, min=0.0)
            # melt[melt > snowpack] = snowpack[melt > snowpack]
            melt = torch.min(melt, snowpack)
            meltwater = meltwater + melt
            snowpack = snowpack - melt
            refreezing = par_cfr * par_cfmax * (par_tt - tempre)
            # refreezing[refreezing < 0.0] = 0.0
            # refreezing[refreezing > meltwater] = meltwater[refreezing > meltwater]
            refreezing = torch.clamp(refreezing, min=0.0)
            refreezing = torch.min(refreezing, meltwater)
            snowpack = snowpack + refreezing
            meltwater = meltwater - refreezing
            to_soil = meltwater - (par_cwh * snowpack)
            # to_soil[to_soil < 0.0] = 0.0
            to_soil = torch.clamp(to_soil, min=0.0)
            meltwater = meltwater - to_soil

            # Soil and evaporation
            soil_wetness = (sm / par_fc) ** par_beta
            # soil_wetness[soil_wetness < 0.0] = 0.0
            # soil_wetness[soil_wetness > 1.0] = 1.0
            soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
            recharge = (rain + to_soil) * soil_wetness

            # log for displaying
            log_sm[i, :] = sm.detach().cpu().numpy()
            log_ps[i, :] = (rain + to_soil).detach().cpu().numpy()
            log_swet[i, :] = (sm / par_fc).detach().cpu().numpy()
            log_re[i, :] = recharge.detach().cpu().numpy()

            sm = sm + rain + to_soil - recharge
            excess = sm - par_fc
            # excess[excess < 0.0] = 0.0
            excess = torch.clamp(excess, min=0.0)
            sm = sm - excess
            evap_factor = sm / (par_lp * par_fc)
            # evap_factor[evap_factor < 0.0] = 0.0
            # evap_factor[evap_factor > 1.0] = 1.0
            evap_factor = torch.clamp(evap_factor, min=0.0, max=1.0)
            et_act = potent * evap_factor
            et_act = torch.min(sm, et_act)
            sm = torch.clamp(
                sm - et_act, min=precision
            )  # sm can not be zero for gradient tracking

            # Groundwater boxes
            suz = suz + recharge + excess
            perc = torch.min(suz, par_perc)
            suz = suz - perc
            q0 = par_k0 * torch.clamp(suz - par_uzl, min=0.0)
            suz = suz - q0
            q1 = par_k1 * suz
            suz = suz - q1
            slz = slz + perc
            q2 = par_k2 * slz
            slz = slz - q2
            q_sim[i, :] = q0 + q1 + q2

            # # for debug state variables
            # SMlog[t,:] = sm.detach().cpu().numpy()

        if rout_opt is True:  # routing
            temp_a = self.a_scale[0] + parameters[:, -2] * (
                self.a_scale[1] - self.a_scale[0]
            )
            temp_b = self.theta_scale[0] + parameters[:, -1] * (
                self.theta_scale[1] - self.theta_scale[0]
            )
            rout_a = temp_a.repeat(n_step, 1).unsqueeze(-1)
            rout_b = temp_b.repeat(n_step, 1).unsqueeze(-1)
            uh_from_gamma = uh_gamma(rout_a, rout_b, len_uh=self.kernel_size)
            rf = torch.unsqueeze(q_sim, -1)
            qs = uh_conv(rf, uh_from_gamma)

        else:
            qs = torch.unsqueeze(q_sim, -1)  # add a dimension
        return (qs, snowpack, meltwater, sm, suz, slz) if out_state is True else qs

__init__(self, warmup_length, kernel_size=15) special

Initiate an HBV instance

Parameters

warmup_length : int
    the length of the warmup period; the model needs a warmup period to generate reasonable initial state values
kernel_size : int, optional
    conv kernel size for the unit hydrograph, by default 15

Source code in torchhydro/models/dpl4hbv.py
def __init__(self, warmup_length, kernel_size=15):
    """Initiate a HBV instance

    Parameters
    ----------
    warmup_length : _type_
        _description_
    kernel_size : int, optional
        conv kernel for unit hydrograph, by default 15
    """
    super(Hbv4Dpl, self).__init__()
    self.name = "HBV"
    self.params_names = MODEL_PARAM_DICT["hbv"]["param_name"]
    parasca_lst = MODEL_PARAM_DICT["hbv"]["param_range"]
    self.beta_scale = parasca_lst["BETA"]
    self.fc_scale = parasca_lst["FC"]
    self.k0_scale = parasca_lst["K0"]
    self.k1_scale = parasca_lst["K1"]
    self.k2_scale = parasca_lst["K2"]
    self.lp_scale = parasca_lst["LP"]
    self.perc_scale = parasca_lst["PERC"]
    self.uzl_scale = parasca_lst["UZL"]
    self.tt_scale = parasca_lst["TT"]
    self.cfmax_scale = parasca_lst["CFMAX"]
    self.cfr_scale = parasca_lst["CFR"]
    self.cwh_scale = parasca_lst["CWH"]
    self.a_scale = parasca_lst["A"]
    self.theta_scale = parasca_lst["THETA"]
    self.warmup_length = warmup_length
    self.kernel_size = kernel_size
    # there are 3 input vars in HBV: P, PET and TEMPERATURE
    self.feature_size = 3

forward(self, x, parameters, out_state=False, rout_opt=True)

Runs the HBV-light hydrological model (Seibert, 2005).

The code comes from mhpi/hydro-dev. NaN values have to be removed from the inputs.

Parameters

x
    p_all = array with daily values of precipitation (mm/d)
    pet_all = array with daily values of potential evapotranspiration (mm/d)
    t_all = array with daily values of air temperature (deg C)
parameters
    array with parameter values having the following structure and scales:
    BETA: parameter in soil routine
    FC: maximum soil moisture content
    K0: recession coefficient
    K1: recession coefficient
    K2: recession coefficient
    LP: limit for potential evapotranspiration
    PERC: percolation from upper to lower response box
    UZL: upper zone limit
    TT: temperature limit for snow/rain; distinguish rainfall from snowfall
    CFMAX: degree day factor; used for melting calculation
    CFR: refreezing factor
    CWH: liquid water holding capacity of the snowpack
    A: parameter of mizuRoute
    THETA: parameter of mizuRoute
out_state
    if True, the state variables' values will be output
rout_opt
    if True, the routing module will be performed

Returns

Union[tuple, torch.Tensor]
    q_sim = daily values of simulated streamflow (mm)
    sm = soil storage (mm)
    suz = upper zone storage (mm)
    slz = lower zone storage (mm)
    snowpack = snow depth (mm)
    et_act = actual evaporation (mm)

Source code in torchhydro/models/dpl4hbv.py
def forward(
    self, x, parameters, out_state=False, rout_opt=True
) -> Union[tuple, torch.Tensor]:
    """
    Runs the HBV-light hydrological model (Seibert, 2005).

    The code comes from mhpi/hydro-dev.
    NaN values have to be removed from the inputs.

    Parameters
    ----------
    x
        p_all = array with daily values of precipitation (mm/d)
        pet_all = array with daily values of potential evapotranspiration (mm/d)
        t_all = array with daily values of air temperature (deg C)
    parameters
        array with parameter values having the following structure and scales
        BETA: parameter in soil routine
        FC: maximum soil moisture content
        K0: recession coefficient
        K1: recession coefficient
        K2: recession coefficient
        LP: limit for potential evapotranspiration
        PERC: percolation from upper to lower response box
        UZL: upper zone limit
        TT: temperature limit for snow/rain; distinguish rainfall from snowfall
        CFMAX: degree day factor; used for melting calculation
        CFR: refreezing factor
        CWH: liquid water holding capacity of the snowpack
        A: parameter of mizuRoute
        THETA: parameter of mizuRoute
    out_state
        if True, the state variables' value will be output
    rout_opt
        if True, route module will be performed

    Returns
    -------
    Union[tuple, torch.Tensor]
        q_sim = daily values of simulated streamflow (mm)
        sm = soil storage (mm)
        suz = upper zone storage (mm)
        slz = lower zone storage (mm)
        snowpack = snow depth (mm)
        et_act = actual evaporation (mm)
    """
    hbv_device = x.device
    precision = 1e-5
    buffer_time = self.warmup_length
    # Initialization
    if buffer_time > 0:
        with torch.no_grad():
            x_init = x[0:buffer_time, :, :]
            warmup_length = 0
            init_model = Hbv4Dpl(warmup_length, kernel_size=self.kernel_size)
            if init_model.warmup_length > 0:
                raise RuntimeError(
                    "Please set warmup_length as 0 when initializing HBV model"
                )
            _, snowpack, meltwater, sm, suz, slz = init_model(
                x_init, parameters, out_state=True, rout_opt=False
            )
    else:

        # Without buff time, initialize state variables with zeros
        n_grid = x.shape[1]
        snowpack = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)
        meltwater = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(
            hbv_device
        )
        sm = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)
        suz = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)
        slz = (torch.zeros(n_grid, dtype=torch.float32) + 0.001).to(hbv_device)

    # the sequence must be p, pet and t
    p_all = x[buffer_time:, :, 0]
    pet_all = x[buffer_time:, :, 1]
    t_all = x[buffer_time:, :, 2]

    # scale the parameters
    par_beta = self.beta_scale[0] + parameters[:, 0] * (
        self.beta_scale[1] - self.beta_scale[0]
    )
    # parCET = parameters[:,1]
    par_fc = self.fc_scale[0] + parameters[:, 1] * (
        self.fc_scale[1] - self.fc_scale[0]
    )
    par_k0 = self.k0_scale[0] + parameters[:, 2] * (
        self.k0_scale[1] - self.k0_scale[0]
    )
    par_k1 = self.k1_scale[0] + parameters[:, 3] * (
        self.k1_scale[1] - self.k1_scale[0]
    )
    par_k2 = self.k2_scale[0] + parameters[:, 4] * (
        self.k2_scale[1] - self.k2_scale[0]
    )
    par_lp = self.lp_scale[0] + parameters[:, 5] * (
        self.lp_scale[1] - self.lp_scale[0]
    )
    # parMAXBAS = parameters[:,7]
    par_perc = self.perc_scale[0] + parameters[:, 6] * (
        self.perc_scale[1] - self.perc_scale[0]
    )
    par_uzl = self.uzl_scale[0] + parameters[:, 7] * (
        self.uzl_scale[1] - self.uzl_scale[0]
    )
    # parPCORR = parameters[:,10]
    par_tt = self.tt_scale[0] + parameters[:, 8] * (
        self.tt_scale[1] - self.tt_scale[0]
    )
    par_cfmax = self.cfmax_scale[0] + parameters[:, 9] * (
        self.cfmax_scale[1] - self.cfmax_scale[0]
    )
    # parSFCF = parameters[:,13]
    par_cfr = self.cfr_scale[0] + parameters[:, 10] * (
        self.cfr_scale[1] - self.cfr_scale[0]
    )
    par_cwh = self.cwh_scale[0] + parameters[:, 11] * (
        self.cwh_scale[1] - self.cwh_scale[0]
    )

    n_step, n_grid = p_all.size()
    # Apply correction factor to precipitation
    # p_all = parPCORR.repeat(n_step, 1) * p_all

    # Initialize time series of model variables
    q_sim = (torch.zeros(p_all.size(), dtype=torch.float32) + 0.001).to(hbv_device)
    # # Debug for the state variables
    # SMlog = np.zeros(p_all.size())
    log_sm = np.zeros(p_all.size())
    log_ps = np.zeros(p_all.size())
    log_swet = np.zeros(p_all.size())
    log_re = np.zeros(p_all.size())

    for i in range(n_step):
        # Separate precipitation into liquid and solid components
        precip = p_all[i, :]
        tempre = t_all[i, :]
        potent = pet_all[i, :]
        rain = torch.mul(precip, (tempre >= par_tt).type(torch.float32))
        snow = torch.mul(precip, (tempre < par_tt).type(torch.float32))
        # snow = snow * parSFCF

        # Snow
        snowpack = snowpack + snow
        melt = par_cfmax * (tempre - par_tt)
        # melt[melt < 0.0] = 0.0
        melt = torch.clamp(melt, min=0.0)
        # melt[melt > snowpack] = snowpack[melt > snowpack]
        melt = torch.min(melt, snowpack)
        meltwater = meltwater + melt
        snowpack = snowpack - melt
        refreezing = par_cfr * par_cfmax * (par_tt - tempre)
        # refreezing[refreezing < 0.0] = 0.0
        # refreezing[refreezing > meltwater] = meltwater[refreezing > meltwater]
        refreezing = torch.clamp(refreezing, min=0.0)
        refreezing = torch.min(refreezing, meltwater)
        snowpack = snowpack + refreezing
        meltwater = meltwater - refreezing
        to_soil = meltwater - (par_cwh * snowpack)
        # to_soil[to_soil < 0.0] = 0.0
        to_soil = torch.clamp(to_soil, min=0.0)
        meltwater = meltwater - to_soil

        # Soil and evaporation
        soil_wetness = (sm / par_fc) ** par_beta
        # soil_wetness[soil_wetness < 0.0] = 0.0
        # soil_wetness[soil_wetness > 1.0] = 1.0
        soil_wetness = torch.clamp(soil_wetness, min=0.0, max=1.0)
        recharge = (rain + to_soil) * soil_wetness

        # log for displaying
        log_sm[i, :] = sm.detach().cpu().numpy()
        log_ps[i, :] = (rain + to_soil).detach().cpu().numpy()
        log_swet[i, :] = (sm / par_fc).detach().cpu().numpy()
        log_re[i, :] = recharge.detach().cpu().numpy()

        sm = sm + rain + to_soil - recharge
        excess = sm - par_fc
        # excess[excess < 0.0] = 0.0
        excess = torch.clamp(excess, min=0.0)
        sm = sm - excess
        evap_factor = sm / (par_lp * par_fc)
        # evap_factor[evap_factor < 0.0] = 0.0
        # evap_factor[evap_factor > 1.0] = 1.0
        evap_factor = torch.clamp(evap_factor, min=0.0, max=1.0)
        et_act = potent * evap_factor
        et_act = torch.min(sm, et_act)
        sm = torch.clamp(
            sm - et_act, min=precision
        )  # sm can not be zero for gradient tracking

        # Groundwater boxes
        suz = suz + recharge + excess
        perc = torch.min(suz, par_perc)
        suz = suz - perc
        q0 = par_k0 * torch.clamp(suz - par_uzl, min=0.0)
        suz = suz - q0
        q1 = par_k1 * suz
        suz = suz - q1
        slz = slz + perc
        q2 = par_k2 * slz
        slz = slz - q2
        q_sim[i, :] = q0 + q1 + q2

        # # for debug state variables
        # SMlog[t,:] = sm.detach().cpu().numpy()

    if rout_opt is True:  # routing
        temp_a = self.a_scale[0] + parameters[:, -2] * (
            self.a_scale[1] - self.a_scale[0]
        )
        temp_b = self.theta_scale[0] + parameters[:, -1] * (
            self.theta_scale[1] - self.theta_scale[0]
        )
        rout_a = temp_a.repeat(n_step, 1).unsqueeze(-1)
        rout_b = temp_b.repeat(n_step, 1).unsqueeze(-1)
        uh_from_gamma = uh_gamma(rout_a, rout_b, len_uh=self.kernel_size)
        rf = torch.unsqueeze(q_sim, -1)
        qs = uh_conv(rf, uh_from_gamma)

    else:
        qs = torch.unsqueeze(q_sim, -1)  # add a dimension
    return (qs, snowpack, meltwater, sm, suz, slz) if out_state is True else qs
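
Hbv4Dpl can also be run directly with parameters already produced, for example by a DL model. A minimal sketch, assuming 14 parameter columns (indices 0..11 plus the last two routing parameters used in forward) and made-up forcing data; forward() rescales each column from [0, 1] to its physical range as par = lo + p * (hi - lo) using the *_scale bounds set in __init__.

import torch
from torchhydro.models.dpl4hbv import Hbv4Dpl

hbv = Hbv4Dpl(warmup_length=30, kernel_size=15)

seq_len, n_basin = 365 + 30, 4
x = torch.rand(seq_len, n_basin, 3)  # columns must be P, PET, T in that order
params = torch.rand(n_basin, 14)     # normalized parameters in [0, 1]; 14 columns is an assumption

q_sim = hbv(x, params)               # routed streamflow only (rout_opt=True by default)
q_sim, snowpack, meltwater, sm, suz, slz = hbv(
    x, params, out_state=True, rout_opt=False
)                                    # raw flow plus state variables, without routing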

dpl4xaj

Author: Wenyu Ouyang Date: 2023-09-19 09:36:25 LastEditTime: 2025-06-25 15:50:57 LastEditors: Wenyu Ouyang Description: The method comes from this paper: https://doi.org/10.1038/s41467-021-26107-z. It uses Deep Learning (DL) methods to learn the parameters of physics-based models (PBM), which is called "differentiable parameter learning" (dPL). FilePath: \torchhydro\torchhydro\models\dpl4xaj.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

DplAnnXaj (Module)

Source code in torchhydro/models/dpl4xaj.py
class DplAnnXaj(nn.Module):
    def __init__(
        self,
        n_input_features: int,
        n_output_features: int,
        n_hidden_states: Union[int, tuple, list],
        kernel_size: int,
        warmup_length: int,
        dr: Union[int, tuple, list] = 0.1,
        param_limit_func="sigmoid",
        param_test_way="final",
        source_book="HF",
        source_type="sources",
        return_et=True,
    ):
        """
        Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of ANN
        n_output_features
            the number of output features of ANN, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of ANN; it could be Union[int, tuple, list]
        kernel_size
            the time length of unit hydrograph
        warmup_length
            the length of warmup periods;
            hydrologic models need a warmup period to generate reasonable initial state values
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        return_et
            if True, return evapotranspiration
        """
        super(DplAnnXaj, self).__init__()
        self.dl_model = SimpleAnn(
            n_input_features, n_output_features, n_hidden_states, dr
        )
        self.pb_model = Xaj4Dpl(
            kernel_size, warmup_length, source_book=source_book, source_type=source_type
        )
        self.param_func = param_limit_func
        self.param_test_way = param_test_way
        self.return_et = return_et

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q
        Parameters will be denormalized in gr4j model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a 2-dim tensor. [batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        q, e = ann_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
        return torch.cat([q, e], dim=-1) if self.return_et else q

__init__(self, n_input_features, n_output_features, n_hidden_states, kernel_size, warmup_length, dr=0.1, param_limit_func='sigmoid', param_test_way='final', source_book='HF', source_type='sources', return_et=True) special

Differential Parameter learning model with only attributes as the DL model's input: ANN -> Param -> XAJ

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of the ANN
n_output_features
    the number of output features of the ANN; it should be equal to the number of learned parameters in XAJ
n_hidden_states
    the number of hidden features of the ANN; it could be Union[int, tuple, list]
kernel_size
    the time length of the unit hydrograph
warmup_length
    the length of the warmup period;
    hydrologic models need a warmup period to generate reasonable initial state values
dr
    dropout rate(s) for the ANN
param_limit_func
    function used to limit the range of params; currently the sigmoid or clamp function
param_test_way
    how parameters from the DL model are used when testing; there are three ways:
    1. "final" -- the final period's parameters are used for every period
    2. "mean_time" -- the mean of all periods' parameters is used
    3. "mean_basin" -- the mean of all basins' final-period parameters is used
return_et
    if True, return evapotranspiration

Source code in torchhydro/models/dpl4xaj.py
def __init__(
    self,
    n_input_features: int,
    n_output_features: int,
    n_hidden_states: Union[int, tuple, list],
    kernel_size: int,
    warmup_length: int,
    dr: Union[int, tuple, list] = 0.1,
    param_limit_func="sigmoid",
    param_test_way="final",
    source_book="HF",
    source_type="sources",
    return_et=True,
):
    """
    Differential Parameter learning model only with attributes as DL model's input: ANN -> Param -> Gr4j

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of ANN
    n_output_features
        the number of output features of ANN, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of ANN; it could be Union[int, tuple, list]
    kernel_size
        the time length of unit hydrograph
    warmup_length
        the length of warmup periods;
        hydrologic models need a warmup period to generate reasonable initial state values
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
    return_et
        if True, return evapotranspiration
    """
    super(DplAnnXaj, self).__init__()
    self.dl_model = SimpleAnn(
        n_input_features, n_output_features, n_hidden_states, dr
    )
    self.pb_model = Xaj4Dpl(
        kernel_size, warmup_length, source_book=source_book, source_type=source_type
    )
    self.param_func = param_limit_func
    self.param_test_way = param_test_way
    self.return_et = return_et

forward(self, x, z)

Differential parameter learning

z (normalized input) -> ANN -> param -> + x (not normalized) -> XAJ -> q. Parameters will be denormalized in the XAJ model.

Parameters

x
    not normalized data used for the physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for the DL model; a 2-dim tensor. [batch, feature]

Returns

torch.Tensor
    one-time forward result

Source code in torchhydro/models/dpl4xaj.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> ANN -> param -> + x (not normalized) -> gr4j -> q
    Parameters will be denormalized in gr4j model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a 2-dim tensor. [batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    q, e = ann_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
    return torch.cat([q, e], dim=-1) if self.return_et else q
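
A minimal sketch of DplAnnXaj. The attribute count and hidden sizes are illustrative, and n_output_features=15 assumes the xaj_mz parameter set matches the 15 columns indexed in Xaj4Dpl.forward. With return_et=True the last dimension of the output stacks streamflow and evapotranspiration.

import torch
from torchhydro.models.dpl4xaj import DplAnnXaj

model = DplAnnXaj(
    n_input_features=10, n_output_features=15, n_hidden_states=(64, 32),
    kernel_size=15, warmup_length=30,
)
seq_len, n_basin = 365 + 30, 8
x = torch.rand(seq_len, n_basin, 2)  # not normalized P and PET: [sequence, batch, feature]
z = torch.rand(n_basin, 10)          # normalized basin attributes: [batch, feature]
out = model(x, z)
q, et = out[..., 0:1], out[..., 1:2]  # [q, et] along the last dim when return_et=True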

DplLstmXaj (Module)

Source code in torchhydro/models/dpl4xaj.py
class DplLstmXaj(nn.Module):
    def __init__(
        self,
        n_input_features,
        n_output_features,
        n_hidden_states,
        kernel_size,
        warmup_length,
        param_limit_func="sigmoid",
        param_test_way="final",
        source_book="HF",
        source_type="sources",
        return_et=True,
    ):
        """
        Differential Parameter learning model: LSTM -> Param -> XAJ

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of LSTM
        n_output_features
            the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of LSTM
        kernel_size
            the time length of unit hydrograph
        warmup_length
            the length of warmup periods;
            hydrologic models need a warmup period to generate reasonable initial state values
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        return_et
            if True, return evapotranspiration
        """
        super(DplLstmXaj, self).__init__()
        self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
        self.pb_model = Xaj4Dpl(
            kernel_size, warmup_length, source_book=source_book, source_type=source_type
        )
        self.param_func = param_limit_func
        self.param_test_way = param_test_way
        self.return_et = return_et

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
        Parameters will be denormalized in xaj model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        q, e = lstm_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
        return torch.cat([q, e], dim=-1) if self.return_et else q

__init__(self, n_input_features, n_output_features, n_hidden_states, kernel_size, warmup_length, param_limit_func='sigmoid', param_test_way='final', source_book='HF', source_type='sources', return_et=True) special

Differential Parameter learning model: LSTM -> Param -> XAJ

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of LSTM
n_output_features
    the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
n_hidden_states
    the number of hidden features of LSTM
kernel_size
    the time length of unit hydrograph
warmup_length
    the length of warmup periods;
    hydrologic models need a warmup period to generate reasonable initial state values
param_limit_func
    function used to limit the range of params; now it is sigmoid or clamp function
param_test_way
    how we use parameters from dl model when testing;
    now we have three ways:
    1. "final" -- use the final period's parameter for each period
    2. "mean_time" -- Mean values of all periods' parameters is used
    3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
return_et
    if True, return evapotranspiration

Source code in torchhydro/models/dpl4xaj.py
def __init__(
    self,
    n_input_features,
    n_output_features,
    n_hidden_states,
    kernel_size,
    warmup_length,
    param_limit_func="sigmoid",
    param_test_way="final",
    source_book="HF",
    source_type="sources",
    return_et=True,
):
    """
    Differential Parameter learning model: LSTM -> Param -> XAJ

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of LSTM
    n_output_features
        the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of LSTM
    kernel_size
        the time length of unit hydrograph
    warmup_length
        the length of warmup periods;
        hydrologic models need a warmup period to generate reasonable initial state values
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
    return_et
        if True, return evapotranspiration
    """
    super(DplLstmXaj, self).__init__()
    self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
    self.pb_model = Xaj4Dpl(
        kernel_size, warmup_length, source_book=source_book, source_type=source_type
    )
    self.param_func = param_limit_func
    self.param_test_way = param_test_way
    self.return_et = return_et

forward(self, x, z)

Differential parameter learning

z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
Parameters will be denormalized in xaj model

Parameters

x
    not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

Returns

torch.Tensor
    one time forward result

Source code in torchhydro/models/dpl4xaj.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
    Parameters will be denormalized in xaj model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    q, e = lstm_pbm(self.dl_model, self.pb_model, self.param_func, x, z)
    return torch.cat([q, e], dim=-1) if self.return_et else q
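
The LSTM variant differs only in z being sequence-first; the sizes are the same illustrative assumptions as in the DplAnnXaj sketch.

import torch
from torchhydro.models.dpl4xaj import DplLstmXaj

model = DplLstmXaj(
    n_input_features=10, n_output_features=15, n_hidden_states=64,
    kernel_size=15, warmup_length=30,
)
seq_len, n_basin = 365 + 30, 8
x = torch.rand(seq_len, n_basin, 2)    # not normalized P and PET
z = torch.rand(seq_len, n_basin, 10)   # normalized forcings for the LSTM
out = model(x, z)                      # last dim holds [q, et] when return_et=True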

Xaj4Dpl (Module)

XAJ model for Differential Parameter learning

Source code in torchhydro/models/dpl4xaj.py
class Xaj4Dpl(nn.Module):
    """
    XAJ model for Differential Parameter learning
    """

    def __init__(
        self,
        kernel_size: int,
        warmup_length: int,
        source_book="HF",
        source_type="sources",
    ):
        """
        Parameters
        ----------
        kernel_size
            the time length of unit hydrograph
        warmup_length
            the length of warmup periods;
            XAJ needs a warmup period to generate reasonable initial state values
        """
        super(Xaj4Dpl, self).__init__()
        self.params_names = MODEL_PARAM_DICT["xaj_mz"]["param_name"]
        param_range = MODEL_PARAM_DICT["xaj_mz"]["param_range"]
        self.k_scale = param_range["K"]
        self.b_scale = param_range["B"]
        self.im_sacle = param_range["IM"]
        self.um_scale = param_range["UM"]
        self.lm_scale = param_range["LM"]
        self.dm_scale = param_range["DM"]
        self.c_scale = param_range["C"]
        self.sm_scale = param_range["SM"]
        self.ex_scale = param_range["EX"]
        self.ki_scale = param_range["KI"]
        self.kg_scale = param_range["KG"]
        self.a_scale = param_range["A"]
        self.theta_scale = param_range["THETA"]
        self.ci_scale = param_range["CI"]
        self.cg_scale = param_range["CG"]
        self.kernel_size = kernel_size
        self.warmup_length = warmup_length
        # there are 2 input variables in XAJ: P and PET
        self.feature_size = 2
        self.source_book = source_book
        self.source_type = source_type

    def forward(self, p_and_e, parameters, return_state=False):
        """
        run XAJ model

        Parameters
        ----------
        p_and_e
            precipitation and potential evapotranspiration
        parameters
            parameters of XAJ model
        return_state
            if True, return state values, mainly for warmup periods

        Returns
        -------
        torch.Tensor
            streamflow got by XAJ
        """
        xaj_device = p_and_e.device
        # denormalize the parameters to general range
        k = self.k_scale[0] + parameters[:, 0] * (self.k_scale[1] - self.k_scale[0])
        b = self.b_scale[0] + parameters[:, 1] * (self.b_scale[1] - self.b_scale[0])
        im = self.im_sacle[0] + parameters[:, 2] * (self.im_sacle[1] - self.im_sacle[0])
        um = self.um_scale[0] + parameters[:, 3] * (self.um_scale[1] - self.um_scale[0])
        lm = self.lm_scale[0] + parameters[:, 4] * (self.lm_scale[1] - self.lm_scale[0])
        dm = self.dm_scale[0] + parameters[:, 5] * (self.dm_scale[1] - self.dm_scale[0])
        c = self.c_scale[0] + parameters[:, 6] * (self.c_scale[1] - self.c_scale[0])
        sm = self.sm_scale[0] + parameters[:, 7] * (self.sm_scale[1] - self.sm_scale[0])
        ex = self.ex_scale[0] + parameters[:, 8] * (self.ex_scale[1] - self.ex_scale[0])
        ki_ = self.ki_scale[0] + parameters[:, 9] * (
            self.ki_scale[1] - self.ki_scale[0]
        )
        kg_ = self.kg_scale[0] + parameters[:, 10] * (
            self.kg_scale[1] - self.kg_scale[0]
        )
        # ki+kg should be smaller than 1; if not, we scale them, but note float only contain 4 digits, so we need 0.999
        ki = torch.where(
            ki_ + kg_ < 1.0,
            ki_,
            (1 - PRECISION) / (ki_ + kg_) * ki_,
        )
        kg = torch.where(
            ki_ + kg_ < 1.0,
            kg_,
            (1 - PRECISION) / (ki_ + kg_) * kg_,
        )
        a = self.a_scale[0] + parameters[:, 11] * (self.a_scale[1] - self.a_scale[0])
        theta = self.theta_scale[0] + parameters[:, 12] * (
            self.theta_scale[1] - self.theta_scale[0]
        )
        ci = self.ci_scale[0] + parameters[:, 13] * (
            self.ci_scale[1] - self.ci_scale[0]
        )
        cg = self.cg_scale[0] + parameters[:, 14] * (
            self.cg_scale[1] - self.cg_scale[0]
        )

        # initialize state values
        warmup_length = self.warmup_length
        if warmup_length > 0:
            # set no_grad for warmup periods
            with torch.no_grad():
                p_and_e_warmup = p_and_e[0:warmup_length, :, :]
                cal_init_xaj4dpl = Xaj4Dpl(
                    self.kernel_size, 0, self.source_book, self.source_type
                )
                if cal_init_xaj4dpl.warmup_length > 0:
                    raise RuntimeError("Please set init model's warmup length to 0!!!")
                _, _, *w0, s0, fr0, qi0, qg0 = cal_init_xaj4dpl(
                    p_and_e_warmup, parameters, return_state=True
                )
        else:
            # use detach func to make wu0 no_grad as it is an initial value
            w0 = (0.5 * (um.detach()), 0.5 * (lm.detach()), 0.5 * (dm.detach()))
            s0 = 0.5 * (sm.detach())
            fr0 = torch.full(ci.size(), 0.1).to(xaj_device)
            qi0 = torch.full(ci.size(), 0.1).to(xaj_device)
            qg0 = torch.full(cg.size(), 0.1).to(xaj_device)

        inputs = p_and_e[warmup_length:, :, :]
        runoff_ims_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        rss_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        ris_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        rgs_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        es_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        for i in range(inputs.shape[0]):
            if i == 0:
                (r, rim, e, pe), w = xaj_generation(
                    inputs[i, :, :], k, b, im, um, lm, dm, c, *w0
                )
                if self.source_type == "sources":
                    (rs, ri, rg), (s, fr) = xaj_sources(
                        pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                    )
                elif self.source_type == "sources5mm":
                    (rs, ri, rg), (s, fr) = xaj_sources5mm(
                        pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                    )
                else:
                    raise NotImplementedError("No such divide-sources method")
            else:
                (r, rim, e, pe), w = xaj_generation(
                    inputs[i, :, :], k, b, im, um, lm, dm, c, *w
                )
                if self.source_type == "sources":
                    (rs, ri, rg), (s, fr) = xaj_sources(
                        pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                    )
                elif self.source_type == "sources5mm":
                    (rs, ri, rg), (s, fr) = xaj_sources5mm(
                        pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                    )
                else:
                    raise NotImplementedError("No such divide-sources method")
            # the impervious part contributes pe * im
            runoff_ims_[i, :] = rim
            # so the results for the non-impervious part are corrected by (1 - im)
            rss_[i, :] = rs * (1 - im)
            ris_[i, :] = ri * (1 - im)
            rgs_[i, :] = rg * (1 - im)
            es_[i, :] = e
            # rss_[i, :] = 0.7 * r
            # ris_[i, :] = 0.2 * r
            # rgs_[i, :] = 0.1 * r
        # seq, batch, feature
        runoff_im = torch.unsqueeze(runoff_ims_, dim=2)
        rss = torch.unsqueeze(rss_, dim=2)
        es = torch.unsqueeze(es_, dim=2)

        conv_uh = KernelConv(a, theta, self.kernel_size)
        qs_ = conv_uh(runoff_im + rss)
        qs = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        for i in range(inputs.shape[0]):
            if i == 0:
                qi = linear_reservoir(ris_[i], ci, qi0)
                qg = linear_reservoir(rgs_[i], cg, qg0)
            else:
                qi = linear_reservoir(ris_[i], ci, qi)
                qg = linear_reservoir(rgs_[i], cg, qg)
            qs[i, :] = qs_[i, :, 0] + qi + qg
        # seq, batch, feature
        q_sim = torch.unsqueeze(qs, dim=2)
        if return_state:
            return q_sim, es, *w, s, fr, qi, qg
        return q_sim, es

__init__(self, kernel_size, warmup_length, source_book='HF', source_type='sources') special

Parameters

kernel_size
    the time length of unit hydrograph
warmup_length
    the length of warmup periods; XAJ needs a warmup period to generate reasonable initial state values

Source code in torchhydro/models/dpl4xaj.py
def __init__(
    self,
    kernel_size: int,
    warmup_length: int,
    source_book="HF",
    source_type="sources",
):
    """
    Parameters
    ----------
    kernel_size
        the time length of unit hydrograph
    warmup_length
        the length of warmup periods;
        XAJ needs a warmup period to generate reasonable initial state values
    """
    super(Xaj4Dpl, self).__init__()
    self.params_names = MODEL_PARAM_DICT["xaj_mz"]["param_name"]
    param_range = MODEL_PARAM_DICT["xaj_mz"]["param_range"]
    self.k_scale = param_range["K"]
    self.b_scale = param_range["B"]
    self.im_sacle = param_range["IM"]
    self.um_scale = param_range["UM"]
    self.lm_scale = param_range["LM"]
    self.dm_scale = param_range["DM"]
    self.c_scale = param_range["C"]
    self.sm_scale = param_range["SM"]
    self.ex_scale = param_range["EX"]
    self.ki_scale = param_range["KI"]
    self.kg_scale = param_range["KG"]
    self.a_scale = param_range["A"]
    self.theta_scale = param_range["THETA"]
    self.ci_scale = param_range["CI"]
    self.cg_scale = param_range["CG"]
    self.kernel_size = kernel_size
    self.warmup_length = warmup_length
    # there are 2 input variables in XAJ: P and PET
    self.feature_size = 2
    self.source_book = source_book
    self.source_type = source_type
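
The constructor only stores each parameter's [low, high] range; the mapping from the DL model's normalized output in [0, 1] to physical values happens in forward. A minimal stand-alone sketch of that mapping (plain torch; the range and PRECISION values here are illustrative, not the ones defined in torchhydro):

import torch

# one normalized parameter column in [0, 1], one value per basin, as produced by the DL model
param_norm = torch.tensor([0.2, 0.8])
k_scale = (0.1, 1.0)  # illustrative [low, high] range, analogous to param_range["K"]
k = k_scale[0] + param_norm * (k_scale[1] - k_scale[0])

# the KI/KG columns are additionally rescaled so that ki + kg stays below 1 (see forward above)
PRECISION = 1e-5  # assumed value; the real constant is defined elsewhere in torchhydro
ki_ = torch.tensor([0.6, 0.3])
kg_ = torch.tensor([0.5, 0.2])
ki = torch.where(ki_ + kg_ < 1.0, ki_, (1 - PRECISION) / (ki_ + kg_) * ki_)
kg = torch.where(ki_ + kg_ < 1.0, kg_, (1 - PRECISION) / (ki_ + kg_) * kg_)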

forward(self, p_and_e, parameters, return_state=False)

run XAJ model

Parameters

p_and_e
    precipitation and potential evapotranspiration
parameters
    parameters of XAJ model
return_state
    if True, return state values, mainly for warmup periods

Returns

torch.Tensor
    streamflow got by XAJ

Source code in torchhydro/models/dpl4xaj.py
def forward(self, p_and_e, parameters, return_state=False):
    """
    run XAJ model

    Parameters
    ----------
    p_and_e
        precipitation and potential evapotranspiration
    parameters
        parameters of XAJ model
    return_state
        if True, return state values, mainly for warmup periods

    Returns
    -------
    torch.Tensor
        streamflow got by XAJ
    """
    xaj_device = p_and_e.device
    # denormalize the parameters to general range
    k = self.k_scale[0] + parameters[:, 0] * (self.k_scale[1] - self.k_scale[0])
    b = self.b_scale[0] + parameters[:, 1] * (self.b_scale[1] - self.b_scale[0])
    im = self.im_sacle[0] + parameters[:, 2] * (self.im_sacle[1] - self.im_sacle[0])
    um = self.um_scale[0] + parameters[:, 3] * (self.um_scale[1] - self.um_scale[0])
    lm = self.lm_scale[0] + parameters[:, 4] * (self.lm_scale[1] - self.lm_scale[0])
    dm = self.dm_scale[0] + parameters[:, 5] * (self.dm_scale[1] - self.dm_scale[0])
    c = self.c_scale[0] + parameters[:, 6] * (self.c_scale[1] - self.c_scale[0])
    sm = self.sm_scale[0] + parameters[:, 7] * (self.sm_scale[1] - self.sm_scale[0])
    ex = self.ex_scale[0] + parameters[:, 8] * (self.ex_scale[1] - self.ex_scale[0])
    ki_ = self.ki_scale[0] + parameters[:, 9] * (
        self.ki_scale[1] - self.ki_scale[0]
    )
    kg_ = self.kg_scale[0] + parameters[:, 10] * (
        self.kg_scale[1] - self.kg_scale[0]
    )
    # ki + kg should be smaller than 1; if not, we scale them down. Because of limited float precision, we scale to (1 - PRECISION) rather than exactly 1.
    ki = torch.where(
        ki_ + kg_ < 1.0,
        ki_,
        (1 - PRECISION) / (ki_ + kg_) * ki_,
    )
    kg = torch.where(
        ki_ + kg_ < 1.0,
        kg_,
        (1 - PRECISION) / (ki_ + kg_) * kg_,
    )
    a = self.a_scale[0] + parameters[:, 11] * (self.a_scale[1] - self.a_scale[0])
    theta = self.theta_scale[0] + parameters[:, 12] * (
        self.theta_scale[1] - self.theta_scale[0]
    )
    ci = self.ci_scale[0] + parameters[:, 13] * (
        self.ci_scale[1] - self.ci_scale[0]
    )
    cg = self.cg_scale[0] + parameters[:, 14] * (
        self.cg_scale[1] - self.cg_scale[0]
    )

    # initialize state values
    warmup_length = self.warmup_length
    if warmup_length > 0:
        # set no_grad for warmup periods
        with torch.no_grad():
            p_and_e_warmup = p_and_e[0:warmup_length, :, :]
            cal_init_xaj4dpl = Xaj4Dpl(
                self.kernel_size, 0, self.source_book, self.source_type
            )
            if cal_init_xaj4dpl.warmup_length > 0:
                raise RuntimeError("Please set init model's warmup length to 0!!!")
            _, _, *w0, s0, fr0, qi0, qg0 = cal_init_xaj4dpl(
                p_and_e_warmup, parameters, return_state=True
            )
    else:
        # detach so that the initial state values carry no gradient
        w0 = (0.5 * (um.detach()), 0.5 * (lm.detach()), 0.5 * (dm.detach()))
        s0 = 0.5 * (sm.detach())
        fr0 = torch.full(ci.size(), 0.1).to(xaj_device)
        qi0 = torch.full(ci.size(), 0.1).to(xaj_device)
        qg0 = torch.full(cg.size(), 0.1).to(xaj_device)

    inputs = p_and_e[warmup_length:, :, :]
    runoff_ims_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    rss_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    ris_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    rgs_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    es_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    for i in range(inputs.shape[0]):
        if i == 0:
            (r, rim, e, pe), w = xaj_generation(
                inputs[i, :, :], k, b, im, um, lm, dm, c, *w0
            )
            if self.source_type == "sources":
                (rs, ri, rg), (s, fr) = xaj_sources(
                    pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                )
            elif self.source_type == "sources5mm":
                (rs, ri, rg), (s, fr) = xaj_sources5mm(
                    pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                )
            else:
                raise NotImplementedError("No such divide-sources method")
        else:
            (r, rim, e, pe), w = xaj_generation(
                inputs[i, :, :], k, b, im, um, lm, dm, c, *w
            )
            if self.source_type == "sources":
                (rs, ri, rg), (s, fr) = xaj_sources(
                    pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                )
            elif self.source_type == "sources5mm":
                (rs, ri, rg), (s, fr) = xaj_sources5mm(
                    pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                )
            else:
                raise NotImplementedError("No such divide-sources method")
        # the impervious part contributes pe * im
        runoff_ims_[i, :] = rim
        # so the results for the non-impervious part are corrected by (1 - im)
        rss_[i, :] = rs * (1 - im)
        ris_[i, :] = ri * (1 - im)
        rgs_[i, :] = rg * (1 - im)
        es_[i, :] = e
        # rss_[i, :] = 0.7 * r
        # ris_[i, :] = 0.2 * r
        # rgs_[i, :] = 0.1 * r
    # seq, batch, feature
    runoff_im = torch.unsqueeze(runoff_ims_, dim=2)
    rss = torch.unsqueeze(rss_, dim=2)
    es = torch.unsqueeze(es_, dim=2)

    conv_uh = KernelConv(a, theta, self.kernel_size)
    qs_ = conv_uh(runoff_im + rss)
    qs = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    for i in range(inputs.shape[0]):
        if i == 0:
            qi = linear_reservoir(ris_[i], ci, qi0)
            qg = linear_reservoir(rgs_[i], cg, qg0)
        else:
            qi = linear_reservoir(ris_[i], ci, qi)
            qg = linear_reservoir(rgs_[i], cg, qg)
        qs[i, :] = qs_[i, :, 0] + qi + qg
    # seq, batch, feature
    q_sim = torch.unsqueeze(qs, dim=2)
    if return_state:
        return q_sim, es, *w, s, fr, qi, qg
    return q_sim, es

ann_pbm(dl_model, pb_model, param_func, x, z)

Differential parameter learning

z (normalized input) -> ann -> param -> + x (not normalized) -> pbm -> q
Parameters will be denormalized in the pbm model.

Parameters

dl_model
    ann model
pb_model
    physics-based model
param_func
    function used to limit the range of params; now it is sigmoid or clamp function
x
    not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for DL model; a 2-dim tensor. [batch, feature]

Returns

torch.Tensor
    one time forward result

Source code in torchhydro/models/dpl4xaj.py
def ann_pbm(dl_model, pb_model, param_func, x, z):
    """
    Differential parameter learning

    z (normalized input) -> ann -> param -> + x (not normalized) -> pbm -> q
    Parameters will be denormalized in pbm model

    Parameters
    ----------
    dl_model
        ann model
    pb_model
        physics-based model
    param_func
        function used to limit the range of params; now it is sigmoid or clamp function
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a 2-dim tensor. [batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    gen = dl_model(z)
    if torch.isnan(gen).any():
        raise ValueError("Error: NaN values detected. Check your data firstly!!!")
    # we set all params' values in [0, 1] and will scale them when forwarding
    if param_func == "sigmoid":
        params = F.sigmoid(gen)
    elif param_func == "clamp":
        params = torch.clamp(gen, min=0.0, max=1.0)
    else:
        raise NotImplementedError(
            "We don't provide this way to limit parameters' range!! Please choose sigmoid or clamp"
        )
    return pb_model(x[:, :, : pb_model.feature_size], params)
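
A minimal usage sketch of ann_pbm, assuming the import paths implied by the "Source code in ..." notes above; the layer sizes are illustrative choices, and the ANN must output the 15 learnable XAJ parameters per basin:

import torch
from torchhydro.models.ann import SimpleAnn  # assumed import path
from torchhydro.models.dpl4xaj import Xaj4Dpl, ann_pbm  # assumed import path

seq_len, n_basin, n_attr = 30, 4, 10
dl_model = SimpleAnn(n_attr, 15)  # 15 learnable XAJ parameters per basin
pb_model = Xaj4Dpl(kernel_size=15, warmup_length=0)
x = torch.rand(seq_len, n_basin, 2)  # not normalized P and PET, [sequence, batch, feature]
z = torch.rand(n_basin, n_attr)  # normalized basin attributes, [batch, feature]
q_sim, es = ann_pbm(dl_model, pb_model, "sigmoid", x, z)  # q_sim: [seq_len, n_basin, 1]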

calculate_evap(lm, c, wu0, wl0, prcp, pet)

Three-layer evaporation model from "Watershed Hydrologic Simulation" written by Prof. RenJun Zhao.

The book is in Chinese and its original title is 《流域水文模拟》; the three-layer evaporation model is described on Page 76. The method is the same as that on Pages 22-23 of "Hydrologic Forecasting (5th edition)" written by Prof. Weimin Bao, whose Chinese title is 《水文预报》.

Parameters

lm
    average soil moisture storage capacity of lower layer (mm)
c
    coefficient of deep layer
wu0
    initial soil moisture of upper layer; update in each time step (mm)
wl0
    initial soil moisture of lower layer; update in each time step (mm)
prcp
    basin mean precipitation (mm/day)
pet
    potential evapotranspiration (mm/day)

Returns

torch.Tensor
    eu/el/ed are evaporation from upper/lower/deeper layer, respectively

Source code in torchhydro/models/dpl4xaj.py
def calculate_evap(
    lm, c, wu0, wl0, prcp, pet
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Three-layer evaporation model from "Watershed Hydrologic Simulation" written by Prof. RenJun Zhao.

    The book is in Chinese and its original title is 《流域水文模拟》;
    the three-layer evaporation model is described on Page 76.
    The method is the same as that on Pages 22-23 of "Hydrologic Forecasting (5th edition)"
    written by Prof. Weimin Bao, whose Chinese title is 《水文预报》.

    Parameters
    ----------
    lm
        average soil moisture storage capacity of lower layer (mm)
    c
        coefficient of deep layer
    wu0
        initial soil moisture of upper layer; update in each time step (mm)
    wl0
        initial soil moisture of lower layer; update in each time step (mm)
    prcp
        basin mean precipitation (mm/day)
    pet
        potential evapotranspiration (mm/day)

    Returns
    -------
    torch.Tensor
        eu/el/ed are evaporation from upper/lower/deeper layer, respectively
    """
    tensor_min = torch.full(wu0.size(), 0.0).to(prcp.device)
    # when using torch.where, please see here: https://github.com/pytorch/pytorch/issues/9190
    # it's element-wise operation, no problem here. For example:
    # In: torch.where(torch.Tensor([2,1])>torch.Tensor([1,1]),torch.Tensor([1,2]),torch.Tensor([3,4]))
    # Out: tensor([1., 4.])
    eu = torch.where(wu0 + prcp >= pet, pet, wu0 + prcp)
    ed = torch.where(
        (wl0 < c * lm) & (wl0 < c * (pet - eu)), c * (pet - eu) - wl0, tensor_min
    )
    el = torch.where(
        wu0 + prcp >= pet,
        tensor_min,
        torch.where(
            wl0 >= c * lm,
            (pet - eu) * wl0 / lm,
            torch.where(wl0 >= c * (pet - eu), c * (pet - eu), wl0),
        ),
    )
    return eu, el, ed
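
A small numeric check of the branch logic above, assuming calculate_evap is imported from the module shown in the "Source code" note; the forcing values are illustrative:

import torch
from torchhydro.models.dpl4xaj import calculate_evap  # assumed import path

lm, c = torch.tensor([60.0]), torch.tensor([0.15])
wu0, wl0 = torch.tensor([5.0]), torch.tensor([20.0])
prcp, pet = torch.tensor([0.0]), torch.tensor([8.0])
# wu0 + prcp < pet and wl0 >= c * lm, so the upper layer is exhausted first:
# eu = wu0 + prcp = 5, el = (pet - eu) * wl0 / lm = 3 * 20 / 60 = 1, ed = 0
eu, el, ed = calculate_evap(lm, c, wu0, wl0, prcp, pet)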

calculate_prcp_runoff(b, im, wm, w0, pe)

Calculates the amount of runoff generated from rainfall after entering the underlying surface

Same in "Watershed Hydrologic Simulation" and "Hydrologic Forecasting (5-th version)"

Parameters

b
    B exponent coefficient
im
    IMP imperviousness coefficient
wm
    average soil moisture storage capacity
w0
    initial soil moisture
pe
    net precipitation

Returns

torch.Tensor
    r -- runoff; r_im -- runoff of impervious part

Source code in torchhydro/models/dpl4xaj.py
def calculate_prcp_runoff(b, im, wm, w0, pe):
    """
    Calculates the amount of runoff generated from rainfall after entering the underlying surface

    Same in "Watershed Hydrologic Simulation" and "Hydrologic Forecasting (5-th version)"

    Parameters
    ----------
    b
        B exponent coefficient
    im
        IMP imperviousness coefficient
    wm
        average soil moisture storage capacity
    w0
        initial soil moisture
    pe
        net precipitation

    Returns
    -------
    torch.Tensor
        r -- runoff; r_im -- runoff of impervious part
    """
    wmm = wm * (1 + b)
    a = wmm * (1 - (1 - w0 / wm) ** (1 / (1 + b)))
    if any(torch.isnan(a)):
        raise ValueError(
            "Error: NaN values detected. Try set clamp function or check your data!!!"
        )
    r_cal = torch.where(
        pe > 0.0,
        torch.where(
            pe + a < wmm,
            # torch.clamp is used for gradient not to be NaN, see more in xaj_sources function
            pe - (wm - w0) + wm * (1 - torch.clamp(a + pe, max=wmm) / wmm) ** (1 + b),
            pe - (wm - w0),
        ),
        torch.full(pe.size(), 0.0).to(pe.device),
    )
    if any(torch.isnan(r_cal)):
        raise ValueError(
            "Error: NaN values detected. Try set clamp function or check your data!!!"
        )
    r = torch.clamp(r_cal, min=0.0)
    r_im_cal = pe * im
    r_im = torch.clamp(r_im_cal, min=0.0)
    return r, r_im
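
The nested torch.where above implements the standard XAJ tension water capacity curve. For PE > 0 it evaluates (the code additionally clamps A + PE at WMM for gradient stability and clamps R and R_im to be non-negative):

WMM = WM\,(1 + B), \qquad A = WMM \left[ 1 - \left( 1 - \frac{W_0}{WM} \right)^{1/(1+B)} \right]

R = \begin{cases} PE - (WM - W_0) + WM \left( 1 - \frac{A + PE}{WMM} \right)^{1+B}, & PE + A < WMM \\ PE - (WM - W_0), & PE + A \ge WMM \end{cases}

R_{im} = PE \cdot IM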

calculate_w_storage(um, lm, dm, wu0, wl0, wd0, eu, el, ed, pe, r)

Update the soil moisture values of the three layers.

According to the runoff-generation equation 2.60 in the book "SHUIWENYUBAO", dW = dPE - dR

Parameters

um
    average soil moisture storage capacity of the upper layer (mm)
lm
    average soil moisture storage capacity of the lower layer (mm)
dm
    average soil moisture storage capacity of the deep layer (mm)
wu0
    initial values of soil moisture in upper layer
wl0
    initial values of soil moisture in lower layer
wd0
    initial values of soil moisture in deep layer
eu
    evaporation of the upper layer; it isn't used in this function
el
    evaporation of the lower layer
ed
    evaporation of the deep layer
pe
    net precipitation; it can be negative in this function
r
    runoff

Returns

torch.Tensor
    wu, wl, wd -- soil moisture in upper, lower and deep layer

Source code in torchhydro/models/dpl4xaj.py
def calculate_w_storage(um, lm, dm, wu0, wl0, wd0, eu, el, ed, pe, r):
    """
    Update the soil moisture values of the three layers.

    According to the runoff-generation equation 2.60 in the book "SHUIWENYUBAO", dW = dPE - dR

    Parameters
    ----------
    um
        average soil moisture storage capacity of the upper layer (mm)
    lm
        average soil moisture storage capacity of the lower layer (mm)
    dm
        average soil moisture storage capacity of the deep layer (mm)
    wu0
        initial values of soil moisture in upper layer
    wl0
        initial values of soil moisture in lower layer
    wd0
        initial values of soil moisture in deep layer
    eu
        evaporation of the upper layer; it isn't used in this function
    el
        evaporation of the lower layer
    ed
        evaporation of the deep layer
    pe
        net precipitation; it is able to be negative value in this function
    r
        runoff

    Returns
    -------
    torch.Tensor
        wu,wl,wd -- soil moisture in upper, lower and deep layer
    """
    xaj_device = pe.device
    tensor_zeros = torch.full(wu0.size(), 0.0).to(xaj_device)
    # pe>0: the upper soil moisture was added firstly, then lower layer, and the final is deep layer
    # pe<=0: no additional water, just remove evapotranspiration,
    # but note the case: e >= p > 0
    # (1) if wu0 + p > e, then e = eu (2) else, wu must be zero
    wu = torch.where(
        pe > 0.0,
        torch.where(wu0 + pe - r < um, wu0 + pe - r, um),
        torch.where(wu0 + pe > 0.0, wu0 + pe, tensor_zeros),
    )
    # calculate wd before wl because it is easier to cal using where statement
    wd = torch.where(
        pe > 0.0,
        torch.where(
            wu0 + wl0 + pe - r > um + lm, wu0 + wl0 + wd0 + pe - r - um - lm, wd0
        ),
        wd0 - ed,
    )
    # water balance (equation 2.2 in Page 13, also shown in Page 23)
    # if wu0 + p > e, then e = eu; else p must be used in upper layer,
    # so no matter what the case is, el didn't include p, neither ed
    wl = torch.where(pe > 0.0, wu0 + wl0 + wd0 + pe - r - wu - wd, wl0 - el)
    # the water storage should be in reasonable range
    tensor_mins = torch.full(um.size(), 0.0).to(xaj_device)
    wu_ = torch.clamp(wu, min=tensor_mins, max=um)
    wl_ = torch.clamp(wl, min=tensor_mins, max=lm)
    wd_ = torch.clamp(wd, min=tensor_mins, max=dm)
    return wu_, wl_, wd_

linear_reservoir(x, weight, last_y=None)

Linear reservoir's release function

Parameters

x
    the input to the linear reservoir
weight
    the coefficient of linear reservoir
last_y
    the output of last period

Returns

torch.Tensor
    one-step forward result

Source code in torchhydro/models/dpl4xaj.py
def linear_reservoir(x, weight, last_y: Optional[Tensor] = None):
    """
    Linear reservoir's release function

    Parameters
    ----------
    x
        the input to the linear reservoir
    weight
        the coefficient of linear reservoir
    last_y
        the output of last period

    Returns
    -------
    torch.Tensor
        one-step forward result
    """
    weight1 = 1 - weight
    if last_y is None:
        last_y = torch.full(weight.size(), 0.001).to(x.device)
    return weight * last_y + weight1 * x
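
The update is simply q_t = weight * q_(t-1) + (1 - weight) * x_t. A short stand-alone sketch of routing a series through such a reservoir (plain torch; names and values are illustrative):

import torch

ci = torch.tensor([0.7])  # recession coefficient, one value per basin
inflow = torch.tensor([[2.0], [1.0], [0.0], [0.0]])  # [time, basin]
q = torch.full(ci.size(), 0.001)  # the default initial outflow used by linear_reservoir
outflows = []
for t in range(inflow.shape[0]):
    q = ci * q + (1 - ci) * inflow[t]  # same update as linear_reservoir(inflow[t], ci, q)
    outflows.append(q)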

lstm_pbm(dl_model, pb_model, param_func, x, z)

Differential parameter learning

z (normalized input) -> lstm -> param -> + x (not normalized) -> pbm -> q
Parameters will be denormalized in the pbm model.

Parameters

dl_model
    lstm model
pb_model
    physics-based model
param_func
    function used to limit the range of params; now it is sigmoid or clamp function
x
    not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

Returns

torch.Tensor
    one time forward result

Source code in torchhydro/models/dpl4xaj.py
def lstm_pbm(dl_model, pb_model, param_func, x, z):
    """
    Differential parameter learning

    z (normalized input) -> lstm -> param -> + x (not normalized) -> pbm -> q
    Parameters will be denormalized in pbm model

    Parameters
    ----------
    dl_model
        lstm model
    pb_model
        physics-based model
    param_func
        function used to limit the range of params; now it is sigmoid or clamp function
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    gen = dl_model(z)
    if torch.isnan(gen).any():
        raise ValueError("Error: NaN values detected. Check your data firstly!!!")
    # we set all params' values in [0, 1] and will scale them when forwarding
    if param_func == "sigmoid":
        params_ = F.sigmoid(gen)
    elif param_func == "clamp":
        params_ = torch.clamp(gen, min=0.0, max=1.0)
    else:
        raise NotImplementedError(
            "We don't provide this way to limit parameters' range!! Please choose sigmoid or clamp"
        )
    # just get one-period values, here we use the final period's values
    params = params_[-1, :, :]
    return pb_model(x[:, :, : pb_model.feature_size], params)

xaj_generation(p_and_e, k, b, im, um, lm, dm, c, wu0=None, wl0=None, wd0=None)

Single-step runoff generation in XAJ.

Parameters

p_and_e
    precipitation and potential evapotranspiration (mm/day)
k
    ratio of potential evapotranspiration to reference crop evaporation
b
    exponent parameter
um
    average soil moisture storage capacity of the upper layer (mm)
lm
    average soil moisture storage capacity of the lower layer (mm)
dm
    average soil moisture storage capacity of the deep layer (mm)
im
    impermeability coefficient
c
    coefficient of deep layer
wu0
    initial values of soil moisture in upper layer (mm)
wl0
    initial values of soil moisture in lower layer (mm)
wd0
    initial values of soil moisture in deep layer (mm)

Returns

tuple[torch.Tensor]
    (r, rim, e, pe), (wu, wl, wd)

Source code in torchhydro/models/dpl4xaj.py
def xaj_generation(
    p_and_e: Tensor,
    k,
    b,
    im,
    um,
    lm,
    dm,
    c,
    wu0: Tensor = None,
    wl0: Tensor = None,
    wd0: Tensor = None,
) -> tuple:
    """
    Single-step runoff generation in XAJ.

    Parameters
    ----------
    p_and_e
        precipitation and potential evapotranspiration (mm/day)
    k
        ratio of potential evapotranspiration to reference crop evaporation
    b
        exponent parameter
    um
        average soil moisture storage capacity of the upper layer (mm)
    lm
        average soil moisture storage capacity of the lower layer (mm)
    dm
        average soil moisture storage capacity of the deep layer (mm)
    im
        impermeability coefficient
    c
        coefficient of deep layer
    wu0
        initial values of soil moisture in upper layer (mm)
    wl0
        initial values of soil moisture in lower layer (mm)
    wd0
        initial values of soil moisture in deep layer (mm)

    Returns
    -------
    tuple[torch.Tensor]
        (r, rim, e, pe), (wu, wl, wd)
    """
    # make sure physical variables' value ranges are correct
    prcp = torch.clamp(p_and_e[:, 0], min=0.0)
    pet = torch.clamp(p_and_e[:, 1] * k, min=0.0)
    # wm
    wm = um + lm + dm
    if wu0 is None:
        # use detach func to make wu0 no_grad as it is an initial value
        wu0 = 0.6 * (um.detach())
    if wl0 is None:
        wl0 = 0.6 * (lm.detach())
    if wd0 is None:
        wd0 = 0.6 * (dm.detach())
    w0_ = wu0 + wl0 + wd0
    # w0 need locate in correct range so that following calculation could be right
    # To make sure the gradient is also not NaN (see case in xaj_sources),
    # we'd better minus a precision (1e-5), although we've not met this situation (grad is NaN)
    w0 = torch.clamp(w0_, max=wm - 1e-5)

    # Calculate the amount of evaporation from storage
    eu, el, ed = calculate_evap(lm, c, wu0, wl0, prcp, pet)
    e = eu + el + ed

    # Calculate the runoff generated by net precipitation
    prcp_difference = prcp - e
    pe = torch.clamp(prcp_difference, min=0.0)
    r, rim = calculate_prcp_runoff(b, im, wm, w0, pe)
    # Update wu, wl, wd;
    # we use prcp_difference rather than pe, as when pe<0 but prcp>0, prcp should be considered
    wu, wl, wd = calculate_w_storage(
        um, lm, dm, wu0, wl0, wd0, eu, el, ed, prcp_difference, r
    )

    return (r, rim, e, pe), (wu, wl, wd)
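
A single-step call with one basin, every parameter given as a one-element tensor, and the import path assumed from the "Source code" note above (parameter values are illustrative):

import torch
from torchhydro.models.dpl4xaj import xaj_generation  # assumed import path

p_and_e = torch.tensor([[12.0, 3.0]])  # [basin, (P, PET)] for a single time step
k, b, im = torch.tensor([0.8]), torch.tensor([0.3]), torch.tensor([0.02])
um, lm, dm = torch.tensor([15.0]), torch.tensor([60.0]), torch.tensor([60.0])
c = torch.tensor([0.15])
(r, rim, e, pe), (wu, wl, wd) = xaj_generation(p_and_e, k, b, im, um, lm, dm, c)
# r: pervious-area runoff, rim: impervious-area runoff, e: actual evaporation,
# pe: net precipitation, (wu, wl, wd): updated soil moisture storages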

xaj_sources(pe, r, sm, ex, ki, kg, s0=None, fr0=None, book='HF')

Divide the runoff to different sources

We use the initial version from the paper by the inventor of the XAJ model, Prof. Renjun Zhao: "Analysis of parameters of the XinAnJiang model" (Chinese title <<新安江模型参数的分析>>, which can be found via "Baidu Xueshu"). The module's code can also be found in "Watershed Hydrologic Simulation" (WHS), Page 174, and it is nearly the same as that in "Hydrologic Forecasting" (HF), Pages 148-149. We use the period-average runoff as input, and since the unit period is one day, we do not need to difference it as the books do.

Parameters

pe
    net precipitation (mm/day)
r
    runoff from xaj_generation (mm/day)
sm
    areal mean free water capacity of the surface layer (mm)
ex
    exponent of the free water capacity curve
ki
    outflow coefficients of the free water storage to interflow relationships
kg
    outflow coefficients of the free water storage to groundwater relationships
s0
    initial free water capacity (mm)
fr0
    runoff area of last period

Return

torch.Tensor
    rs -- surface runoff; ri -- interflow runoff; rg -- groundwater runoff

Source code in torchhydro/models/dpl4xaj.py
def xaj_sources(
    pe,
    r,
    sm,
    ex,
    ki,
    kg,
    s0: Optional[Tensor] = None,
    fr0: Optional[Tensor] = None,
    book="HF",
) -> tuple:
    """
    Divide the runoff to different sources

    We use the initial version from the paper by the inventor of the XAJ model, Prof. Renjun Zhao:
    "Analysis of parameters of the XinAnJiang model" (Chinese title <<新安江模型参数的分析>>,
    which can be found via "Baidu Xueshu").
    The module's code can also be found in "Watershed Hydrologic Simulation" (WHS), Page 174,
    and it is nearly the same as that in "Hydrologic Forecasting" (HF), Pages 148-149.
    We use the period-average runoff as input, and since the unit period is one day, we do not need to difference it as the books do.


    Parameters
    ------------
    pe
        net precipitation (mm/day)
    r
        runoff from xaj_generation (mm/day)
    sm
        areal mean free water capacity of the surface layer (mm)
    ex
        exponent of the free water capacity curve
    ki
        outflow coefficients of the free water storage to interflow relationships
    kg
        outflow coefficients of the free water storage to groundwater relationships
    s0
        initial free water capacity (mm)
    fr0
        runoff area of last period

    Return
    ------------
    torch.Tensor
        rs -- surface runoff; ri-- interflow runoff; rg -- groundwater runoff

    """
    xaj_device = pe.device
    # maximum free water storage capacity in a basin
    ms = sm * (1 + ex)
    if fr0 is None:
        fr0 = torch.full(sm.shape[0], 0.1).to(xaj_device)
    if s0 is None:
        s0 = 0.5 * (sm.clone().detach())
    # For free water storage: s is related to fr, and s0 and fr0 are both values from the last period,
    # so we have to transfer the initial value of s from the last period to this one.
    # Both the sample code of WHS (流域水文模拟) and HF (水文预报) use s = fr0 * s0 / fr.
    # They both treat the free water reservoir as a tank whose height is s and whose bottom area is fr,
    # i.e. a tank with varying bottom and height, and either a fixed boundary (in HF, sm is fixed)
    # or a non-fixed boundary (in EH, smmf is not fixed).
    # Note a runoff series like [1, 0]: 1 is the 1st period's runoff and 0 is the 2nd period's.
    # After the 1st period s1 may be non-zero, but in the 2nd period fr = 0; we cannot set s = 0
    # because some water is still in the tank.
    # fr's formula can be found in Eq. 9 of "Analysis of parameters of the XinAnJiang model".
    # Here our r doesn't include rim, so there is no need to remove rim from r; this is also the method in HF.
    # Moreover, we make sure ss is not larger than sm, otherwise au would be NaN.
    # Note that we have to subtract a small precision (1e-5) here, otherwise the gradient becomes NaN;
    # probably because the gradient Δy/Δx runs into precision problems when an exponent function is involved.

    # NOTE: when r is 0, fr should be 0; however, s1 may not be zero and still holds some water.
    # Then fr cannot be 0, otherwise using fr as a denominator leads to an error,
    # so we deal with this case later. For example, when r = 0 we cannot use pe * fr to replace r,
    # because fr keeps the value of the last period, which is not 0.

    # cannot use torch.where, because it will cause some error when calculating gradient
    # fr = torch.where(r > 0.0, r / pe, fr0)
    # fr just use fr0, and it can be included in the computation graph, so we don't detach it
    fr = torch.clone(fr0)
    fr_mask = r > 0.0
    fr[fr_mask] = r[fr_mask] / pe[fr_mask]
    if any(torch.isnan(fr)):
        raise ValueError(
            "Error: NaN values detected. Try set clamp function or check your data!!!"
        )
    if any(fr == 0.0):
        raise ArithmeticError(
            "Please check fr's value, fr==0.0 will cause error in the next step!"
        )
    ss = torch.clone(s0)
    s = torch.clone(s0)

    ss[fr_mask] = fr0[fr_mask] * s0[fr_mask] / fr[fr_mask]

    if book == "HF":
        ss = torch.clamp(ss, max=sm - PRECISION)
        au = ms * (1.0 - (1.0 - ss / sm) ** (1.0 / (1.0 + ex)))
        if any(torch.isnan(au)):
            raise ValueError(
                "Error: NaN values detected. Try set clamp function or check your data!!!"
            )

        rs = torch.full_like(r, 0.0, device=xaj_device)
        rs[fr_mask] = torch.where(
            pe[fr_mask] + au[fr_mask] < ms[fr_mask],
            # equation 2-85 in HF
            # it's weird here, but we have to clamp so that the gradient could be not NaN;
            # otherwise, even the forward calculation is correct, the gradient is still NaN;
            # maybe when calculating gradient -- Δy/Δx, Δ brings some precision problem
            # if we need exponent function.
            fr[fr_mask]
            * (
                pe[fr_mask]
                - sm[fr_mask]
                + ss[fr_mask]
                + sm[fr_mask]
                * (
                    (
                        1
                        - torch.clamp(pe[fr_mask] + au[fr_mask], max=ms[fr_mask])
                        / ms[fr_mask]
                    )
                    ** (1 + ex[fr_mask])
                )
            ),
            # equation 2-86 in HF
            fr[fr_mask] * (pe[fr_mask] + ss[fr_mask] - sm[fr_mask]),
        )
        rs = torch.clamp(rs, max=r)
        # ri's mask is not same as rs's, because last period's s may not be 0
        # and in this time, ri and rg could be larger than 0
        # we need firstly calculate the updated s, s's mask is same as fr_mask,
        # when r==0, then s will be equal to last period's
        # equation 2-87 in HF, some free water leave or save, so we update free water storage
        s[fr_mask] = ss[fr_mask] + (r[fr_mask] - rs[fr_mask]) / fr[fr_mask]
        s = torch.clamp(s, max=sm)
    elif book == "EH":
        smmf = ms * (1 - (1 - fr) ** (1 / ex))
        smf = smmf / (1 + ex)
        ss = torch.clamp(ss, max=smf - PRECISION)
        au = smmf * (1 - (1 - ss / smf) ** (1 / (1 + ex)))
        if torch.isnan(au).any():
            raise ArithmeticError(
                "Error: NaN values detected. Try set clip function or check your data!!!"
            )
        rs = torch.full_like(r, 0.0, device=xaj_device)
        rs[fr_mask] = torch.where(
            pe[fr_mask] + au[fr_mask] < smmf[fr_mask],
            (
                pe[fr_mask]
                - smf[fr_mask]
                + ss[fr_mask]
                + smf[fr_mask]
                * (
                    1
                    - torch.clamp(
                        pe[fr_mask] + au[fr_mask],
                        max=smmf[fr_mask],
                    )
                    / smmf[fr_mask]
                )
                ** (ex[fr_mask] + 1)
            )
            * fr[fr_mask],
            (pe[fr_mask] + ss[fr_mask] - smf[fr_mask]) * fr[fr_mask],
        )
        rs = torch.clamp(rs, max=r)
        s[fr_mask] = ss[fr_mask] + (r[fr_mask] - rs[fr_mask]) / fr[fr_mask]
        s[fr_mask] = torch.clamp(s[fr_mask], max=smf[fr_mask])
        s = torch.clamp(s, max=smf)
    else:
        raise ValueError("Please set book as 'HF' or 'EH'!")
    # equation 2-88 in HF, next interflow and ground water will be released from the updated free water storage
    # We use the period average runoff as input and the unit period is day.
    # Hence, we directly use ki and kg rather than ki_{Δt} in books.
    ri = ki * s * fr
    rg = kg * s * fr
    # equation 2-89 in HF; although it looks different with that in WHS, they are actually same
    # Finally, calculate the final free water storage
    s1 = s * (1 - ki - kg)
    return (rs, ri, rg), (s1, fr)
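
In compact form, the "HF" branch above computes, per time step and on the runoff-producing area FR (with S' clamped below SM and the R = 0 case handled separately, as the comments explain):

FR = \frac{R}{PE}, \qquad S' = \frac{FR_0 \, S_0}{FR}, \qquad MS = SM\,(1 + EX), \qquad AU = MS \left[ 1 - \left( 1 - \frac{S'}{SM} \right)^{1/(1+EX)} \right]

RS = \begin{cases} FR \left[ PE - SM + S' + SM \left( 1 - \frac{PE + AU}{MS} \right)^{1+EX} \right], & PE + AU < MS \\ FR \left( PE + S' - SM \right), & PE + AU \ge MS \end{cases}

S = S' + \frac{R - RS}{FR}, \qquad RI = KI \cdot S \cdot FR, \qquad RG = KG \cdot S \cdot FR, \qquad S_1 = S\,(1 - KI - KG)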

xaj_sources5mm(pe, runoff, sm, ex, ki, kg, s0=None, fr0=None, book='HF')

Divide the runoff to different sources according to books -- 《水文预报》HF 5th edition and 《工程水文学》EH 3rd edition

Parameters

pe
    net precipitation
runoff
    runoff from xaj_generation
sm
    areal mean free water capacity of the surface layer
ex
    exponent of the free water capacity curve
ki
    outflow coefficients of the free water storage to interflow relationships
kg
    outflow coefficients of the free water storage to groundwater relationships
s0
    initial free water capacity
fr0
    initial area of generation
time_interval_hours
    since Ki, Kg, Ci and Cg are defined for a 24-hour period, they need to be converted according to the time-step length
book
    the methods in 《水文预报》(HF, 5th edition) and 《工程水文学》(EH, 3rd edition) are different, hence both are provided; the default is the former -- "HF", the other one is "EH"

Returns

tuple[tuple, tuple]
    rs_s -- surface runoff; rss_s -- interflow runoff; rg_s -- groundwater runoff;
    (s_ds[-1], fr_ds[-1]): state variables' final values;
    all variables are torch tensors

Source code in torchhydro/models/dpl4xaj.py
def xaj_sources5mm(
    pe,
    runoff,
    sm,
    ex,
    ki,
    kg,
    s0=None,
    fr0=None,
    book="HF",
):
    """
    Divide the runoff to different sources according to books -- 《水文预报》HF 5th edition and 《工程水文学》EH 3rd edition

    Parameters
    ----------
    pe
        net precipitation
    runoff
        runoff from xaj_generation
    sm
        areal mean free water capacity of the surface layer
    ex
        exponent of the free water capacity curve
    ki
        outflow coefficients of the free water storage to interflow relationships
    kg
        outflow coefficients of the free water storage to groundwater relationships
    s0
        initial free water capacity
    fr0
        initial area of generation
    time_interval_hours
        since Ki, Kg, Ci and Cg are defined for a 24-hour period, they need to be converted according to the time-step length
    book
        the methods in 《水文预报》(HF, 5th edition) and 《工程水文学》(EH, 3rd edition) are different,
        hence both are provided; the default is the former -- "HF",
        the other one is "EH"

    Returns
    -------
    tuple[tuple, tuple]
        rs_s -- surface runoff; rss_s -- interflow runoff; rg_s -- groundwater runoff;
        (s_ds[-1], fr_ds[-1]): state variables' final values;
        all variables are torch tensors
    """
    xaj_device = pe.device
    # maximum point free water storage capacity in the basin (流域最大点自由水蓄水容量深)
    smm = sm * (1 + ex)
    if fr0 is None:
        fr0 = torch.full_like(sm, 0.1, device=xaj_device)
    if s0 is None:
        s0 = 0.5 * (sm.clone().detach())
    fr = torch.clone(fr0)
    fr_mask = runoff > 0.0
    fr[fr_mask] = runoff[fr_mask] / pe[fr_mask]
    if torch.all(runoff < 5):
        n = 1
    else:
        r_max = torch.max(runoff).detach().cpu().numpy()
        residue_temp = r_max % 5
        if residue_temp != 0:
            residue_temp = 1
        n = int(r_max / 5) + residue_temp
    rn = runoff / n
    pen = pe / n
    kss_d = (1 - (1 - (ki + kg)) ** (1 / n)) / (1 + kg / ki)
    kg_d = kss_d * kg / ki
    if torch.isnan(kss_d).any() or torch.isnan(kg_d).any():
        raise ValueError("Error: NaN values detected. Check your parameters setting!!!")
    # kss_d = ki
    # kg_d = kg

    rs = torch.full_like(runoff, 0.0, device=xaj_device)
    rss = torch.full_like(runoff, 0.0, device=xaj_device)
    rg = torch.full_like(runoff, 0.0, device=xaj_device)

    s_ds = []
    fr_ds = []
    s_ds.append(s0)
    fr_ds.append(fr0)
    for j in range(n):
        fr0_d = fr_ds[j]
        s0_d = s_ds[j]
        # equation 5-32 in HF; note that rn/pen is the same for every sub-period, which looks odd but follows the book
        # fr_d = torch.full_like(fr0_d, PRECISION, device=xaj_device)
        # fr_d_mask = fr > PRECISION
        # fr_d[fr_d_mask] = 1 - (1 - fr[fr_d_mask]) ** (1 / n)
        fr_d = fr

        ss_d = torch.clone(s0_d)
        s_d = torch.clone(s0_d)

        ss_d[fr_mask] = fr0_d[fr_mask] * s0_d[fr_mask] / fr_d[fr_mask]

        if book == "HF":
            # ms = smm
            ss_d = torch.clamp(ss_d, max=sm - PRECISION)
            au = smm * (1.0 - (1.0 - ss_d / sm) ** (1.0 / (1.0 + ex)))
            if torch.isnan(au).any():
                raise ValueError(
                    "Error: NaN values detected. Try set clip function or check your data!!!"
                )
            rs_j = torch.full_like(rn, 0.0, device=xaj_device)
            rs_j[fr_mask] = torch.where(
                pen[fr_mask] + au[fr_mask] < smm[fr_mask],
                # equation 5-26 in HF
                fr_d[fr_mask]
                * (
                    pen[fr_mask]
                    - sm[fr_mask]
                    + ss_d[fr_mask]
                    + sm[fr_mask]
                    * (
                        (
                            1
                            - torch.clamp(pen[fr_mask] + au[fr_mask], max=smm[fr_mask])
                            / smm[fr_mask]
                        )
                        ** (1 + ex[fr_mask])
                    )
                ),
                # equation 5-27 in HF
                fr_d[fr_mask] * (pen[fr_mask] + ss_d[fr_mask] - sm[fr_mask]),
            )
            rs_j = torch.clamp(rs_j, max=rn)
            s_d[fr_mask] = ss_d[fr_mask] + (rn[fr_mask] - rs_j[fr_mask]) / fr_d[fr_mask]
            s_d = torch.clamp(s_d, max=sm)

        elif book == "EH":
            smmf = smm * (1 - (1 - fr_d) ** (1 / ex))
            smf = smmf / (1 + ex)
            ss_d = torch.clamp(ss_d, max=smf - PRECISION)
            au = smmf * (1 - (1 - ss_d / smf) ** (1 / (1 + ex)))
            if torch.isnan(au).any():
                raise ValueError(
                    "Error: NaN values detected. Try set clip function or check your data!!!"
                )
            rs_j = torch.full(rn.size(), 0.0).to(xaj_device)
            rs_j[fr_mask] = torch.where(
                pen[fr_mask] + au[fr_mask] < smmf[fr_mask],
                (
                    pen[fr_mask]
                    - smf[fr_mask]
                    + ss_d[fr_mask]
                    + smf[fr_mask]
                    * (
                        1
                        - torch.clamp(
                            pen[fr_mask] + au[fr_mask],
                            max=smmf[fr_mask],
                        )
                        / smmf[fr_mask]
                    )
                    ** (ex[fr_mask] + 1)
                )
                * fr_d[fr_mask],
                (pen[fr_mask] + ss_d[fr_mask] - smf[fr_mask]) * fr_d[fr_mask],
            )
            rs_j = torch.clamp(rs_j, max=rn)
            s_d[fr_mask] = ss_d[fr_mask] + (rn[fr_mask] - rs_j[fr_mask]) / fr_d[fr_mask]
            s_d = torch.clamp(s_d, max=smf)
        else:
            raise NotImplementedError(
                "We don't have this implementation! Please chose 'HF' or 'EH'!!"
            )

        rss_j = s_d * kss_d * fr_d
        rg_j = s_d * kg_d * fr_d
        s1_d = s_d * (1 - kss_d - kg_d)

        rs = rs + rs_j
        rss = rss + rss_j
        rg = rg + rg_j
        # store s_d and fr_d in the lists as initial values for the next sub-step
        s_ds.append(s1_d)
        fr_ds.append(fr_d)

    return (rs, rss, rg), (s_ds[-1], fr_ds[-1])
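
The key difference from xaj_sources is that the period runoff is split into n sub-steps of at most 5 mm each, and the daily outflow coefficients are converted to per-sub-step coefficients. A plain-torch restatement of that bookkeeping (values are illustrative):

import torch

runoff = torch.tensor([12.0, 3.0])  # period runoff per basin (mm)
ki, kg = torch.tensor([0.3]), torch.tensor([0.2])

# number of <= 5 mm sub-steps, driven by the largest basin value (as in the code above)
r_max = float(torch.max(runoff))
n = 1 if r_max < 5 else int(r_max / 5) + (1 if r_max % 5 != 0 else 0)

# convert the daily coefficients so that n sub-steps release the same total fraction
kss_d = (1 - (1 - (ki + kg)) ** (1 / n)) / (1 + kg / ki)
kg_d = kss_d * kg / ki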

dpl4xaj_nn4et

The method is similar to dpl4xaj.py. The difference between dpl4xaj and dpl4xaj_nn4et is that in the former the PBM parameters are a single output of a DL model, while in the latter the time-series output of a DL model can be used as PBM parameters and some modules can be replaced with neural networks.

DplLstmNnModuleXaj (Module)

Source code in torchhydro/models/dpl4xaj_nn4et.py
class DplLstmNnModuleXaj(nn.Module):
    def __init__(
        self,
        n_input_features,
        n_output_features,
        n_hidden_states,
        kernel_size,
        warmup_length,
        param_limit_func="clamp",
        param_test_way="final",
        param_var_index=None,
        source_book="HF",
        source_type="sources",
        nn_hidden_size=None,
        nn_dropout=0.2,
        et_output=3,
        return_et=True,
    ):
        """
        Differential Parameter learning model: LSTM -> Param -> XAJ

        The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

        Parameters
        ----------
        n_input_features
            the number of input features of LSTM
        n_output_features
            the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
        n_hidden_states
            the number of hidden features of LSTM
        kernel_size
            the time length of unit hydrograph
        warmup_length
            the length of warmup periods;
            hydrologic models need a warmup period to generate reasonable initial state values
        param_limit_func
            function used to limit the range of params; now it is sigmoid or clamp function
        param_test_way
            how we use parameters from dl model when testing;
            now we have three ways:
            1. "final" -- use the final period's parameter for each period
            2. "mean_time" -- Mean values of all periods' parameters is used
            3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
            but remember these ways are only for non-variable parameters
        param_var_index
            variable parameters' indices in all parameters
        return_et
            if True, return evapotranspiration
        """
        if param_var_index is None:
            param_var_index = [0, 1, 6]
        if nn_hidden_size is None:
            nn_hidden_size = [16, 8]
        super(DplLstmNnModuleXaj, self).__init__()
        self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
        self.pb_model = Xaj4DplWithNnModule(
            kernel_size,
            warmup_length,
            source_book=source_book,
            source_type=source_type,
            nn_hidden_size=nn_hidden_size,
            nn_dropout=nn_dropout,
            et_output=et_output,
            param_var_index=param_var_index,
            param_test_way=param_test_way,
        )
        self.param_func = param_limit_func
        self.param_test_way = param_test_way
        self.param_var_index = param_var_index
        self.return_et = return_et

    def forward(self, x, z):
        """
        Differential parameter learning

        z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
        Parameters will be denormalized in xaj model

        Parameters
        ----------
        x
            not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
        z
            normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

        Returns
        -------
        torch.Tensor
            one time forward result
        """
        gen = self.dl_model(z)
        if torch.isnan(gen).any():
            raise ValueError("Error: NaN values detected. Check your data firstly!!!")
        # we set all params' values in [0, 1] and will scale them when forwarding
        if self.param_func == "sigmoid":
            params = F.sigmoid(gen)
        elif self.param_func == "clamp":
            params = torch.clamp(gen, min=0.0, max=1.0)
        else:
            raise NotImplementedError(
                "We don't provide this way to limit parameters' range!! Please choose sigmoid or clamp"
            )
        # just get one period's values: when the MODEL_PARAM_TEST_WAY is not time_varying, we use the final period's values.
        if self.param_test_way != MODEL_PARAM_TEST_WAY["time_varying"]:
            params = params[-1, :, :]
        # Please put p in the first location and pet in the second
        q, e = self.pb_model(x[:, :, : self.pb_model.feature_size], params)
        return torch.cat([q, e], dim=-1) if self.return_et else q
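
A minimal construction-and-forward sketch, assuming the import path from the "Source code" note above; the hyperparameter values are illustrative and not validated, and shapes follow the docstring (sequence-first tensors):

import torch
from torchhydro.models.dpl4xaj_nn4et import DplLstmNnModuleXaj  # assumed import path

model = DplLstmNnModuleXaj(
    n_input_features=8,  # forcing and attribute features fed to the LSTM
    n_output_features=15,  # one output per learnable XAJ parameter
    n_hidden_states=64,
    kernel_size=15,
    warmup_length=10,
)
seq_len, n_basin = 100, 4
x = torch.rand(seq_len, n_basin, 2)  # not normalized P and PET
z = torch.rand(seq_len, n_basin, 8)  # normalized inputs for the LSTM
out = model(x, z)  # q concatenated with et along the last dim when return_et=True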

__init__(self, n_input_features, n_output_features, n_hidden_states, kernel_size, warmup_length, param_limit_func='clamp', param_test_way='final', param_var_index=None, source_book='HF', source_type='sources', nn_hidden_size=None, nn_dropout=0.2, et_output=3, return_et=True) special

Differential Parameter learning model: LSTM -> Param -> XAJ

The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

Parameters

n_input_features
    the number of input features of LSTM
n_output_features
    the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
n_hidden_states
    the number of hidden features of LSTM
kernel_size
    the time length of unit hydrograph
warmup_length
    the length of warmup periods; hydrologic models need a warmup period to generate reasonable initial state values
param_limit_func
    function used to limit the range of params; now it is sigmoid or clamp function
param_test_way
    how we use parameters from the dl model when testing; now we have three ways:
    1. "final" -- use the final period's parameter for each period
    2. "mean_time" -- mean values of all periods' parameters are used
    3. "mean_basin" -- mean values of all basins' final periods' parameters are used
    but remember these ways are only for non-variable parameters
param_var_index
    variable parameters' indices in all parameters
return_et
    if True, return evapotranspiration

Source code in torchhydro/models/dpl4xaj_nn4et.py
def __init__(
    self,
    n_input_features,
    n_output_features,
    n_hidden_states,
    kernel_size,
    warmup_length,
    param_limit_func="clamp",
    param_test_way="final",
    param_var_index=None,
    source_book="HF",
    source_type="sources",
    nn_hidden_size=None,
    nn_dropout=0.2,
    et_output=3,
    return_et=True,
):
    """
    Differential Parameter learning model: LSTM -> Param -> XAJ

    The principle can be seen here: https://doi.org/10.1038/s41467-021-26107-z

    Parameters
    ----------
    n_input_features
        the number of input features of LSTM
    n_output_features
        the number of output features of LSTM, and it should be equal to the number of learning parameters in XAJ
    n_hidden_states
        the number of hidden features of LSTM
    kernel_size
        the time length of unit hydrograph
    warmup_length
        the length of warmup periods;
        hydrologic models need a warmup period to generate reasonable initial state values
    param_limit_func
        function used to limit the range of params; now it is sigmoid or clamp function
    param_test_way
        how we use parameters from dl model when testing;
        now we have three ways:
        1. "final" -- use the final period's parameter for each period
        2. "mean_time" -- Mean values of all periods' parameters is used
        3. "mean_basin" -- Mean values of all basins' final periods' parameters is used
        but remember these ways are only for non-variable parameters
    param_var_index
        variable parameters' indices in all parameters
    return_et
        if True, return evapotranspiration
    """
    if param_var_index is None:
        param_var_index = [0, 1, 6]
    if nn_hidden_size is None:
        nn_hidden_size = [16, 8]
    super(DplLstmNnModuleXaj, self).__init__()
    self.dl_model = SimpleLSTM(n_input_features, n_output_features, n_hidden_states)
    self.pb_model = Xaj4DplWithNnModule(
        kernel_size,
        warmup_length,
        source_book=source_book,
        source_type=source_type,
        nn_hidden_size=nn_hidden_size,
        nn_dropout=nn_dropout,
        et_output=et_output,
        param_var_index=param_var_index,
        param_test_way=param_test_way,
    )
    self.param_func = param_limit_func
    self.param_test_way = param_test_way
    self.param_var_index = param_var_index
    self.return_et = return_et

forward(self, x, z)

Differential parameter learning

z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
Parameters will be denormalized in the xaj model.

Parameters

x
    not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
z
    normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

Returns

torch.Tensor
    one time forward result

Source code in torchhydro/models/dpl4xaj_nn4et.py
def forward(self, x, z):
    """
    Differential parameter learning

    z (normalized input) -> lstm -> param -> + x (not normalized) -> xaj -> q
    Parameters will be denormalized in xaj model

    Parameters
    ----------
    x
        not normalized data used for physical model; a sequence-first 3-dim tensor. [sequence, batch, feature]
    z
        normalized data used for DL model; a sequence-first 3-dim tensor. [sequence, batch, feature]

    Returns
    -------
    torch.Tensor
        one time forward result
    """
    gen = self.dl_model(z)
    if torch.isnan(gen).any():
        raise ValueError("Error: NaN values detected. Check your data firstly!!!")
    # we set all params' values in [0, 1] and will scale them when forwarding
    if self.param_func == "sigmoid":
        params = F.sigmoid(gen)
    elif self.param_func == "clamp":
        params = torch.clamp(gen, min=0.0, max=1.0)
    else:
        raise NotImplementedError(
            "We don't provide this way to limit parameters' range!! Please choose sigmoid or clamp"
        )
    # just get one period's values: when the MODEL_PARAM_TEST_WAY is not time_varying, we use the final period's values.
    if self.param_test_way != MODEL_PARAM_TEST_WAY["time_varying"]:
        params = params[-1, :, :]
    # Please put p in the first location and pet in the second
    q, e = self.pb_model(x[:, :, : self.pb_model.feature_size], params)
    return torch.cat([q, e], dim=-1) if self.return_et else q

NnModule4Hydro (Module)

An NN module for a hydrological model. Generally, the difference between it and a normal NN is that we need to constrain its output to a specific value range.

Parameters

nn : type description

Source code in torchhydro/models/dpl4xaj_nn4et.py
class NnModule4Hydro(nn.Module):
    """A NN module for Hydrological model.
    Generally, the difference between it and normal NN is:
    we need constrain its output to some specific value range

    Parameters
    ----------
    nn : _type_
        _description_
    """

    def __init__(
        self,
        nx: int,
        ny: int,
        hidden_size: Union[int, tuple, list] = None,
        dr: Union[float, tuple, list] = 0.0,
    ):
        """
        A simple multi-layer NN model with final linear layer

        Parameters
        ----------
        nx
            number of input neurons
        ny
            number of output neurons
        hidden_size
            a list/tuple which contains number of neurons in each hidden layer;
            if int, only one hidden layer except for hidden_size=0
        dr
            dropout rate of layers, default is 0.0 which means no dropout;
            here we set number of dropout layers to (number of nn layers - 1)
        """
        super(NnModule4Hydro, self).__init__()
        self.ann = SimpleAnn(nx, ny, hidden_size, dr)

    def forward(self, x, w0, prcp, pet, k):
        """the forward function of the NN ET module

        Parameters
        ----------
        x : _type_
            _description_
        w0 : _type_
            water storage
        prcp : _type_
            precipitation
        pet: tensor
            potential evapotranspiration, used as part of the upper limit of ET
        k: tensor
            coefficient of PET in XAJ model, used as part of the upper limit of ET

        Returns
        -------
        _type_
            _description_
        """
        zeros = torch.full_like(w0, 0.0, device=x.device)
        et = torch.full_like(w0, 0.0, device=x.device)
        w_mask = w0 + prcp > PRECISION
        y = self.ann(x)
        z = y.flatten()
        et[w_mask] = torch.clamp(
            z[w_mask],
            min=zeros[w_mask],
            # torch.minimum computes the element-wise minimum: https://pytorch.org/docs/stable/generated/torch.minimum.html
            # k * pet is real pet in XAJ model
            max=torch.minimum(w0[w_mask] + prcp[w_mask], k[w_mask] * pet[w_mask]),
        )
        return et

__init__(self, nx, ny, hidden_size=None, dr=0.0) special

A simple multi-layer NN model with final linear layer

Parameters

nx
    number of input neurons
ny
    number of output neurons
hidden_size
    a list/tuple which contains number of neurons in each hidden layer; if int, only one hidden layer except for hidden_size=0
dr
    dropout rate of layers, default is 0.0 which means no dropout; here we set number of dropout layers to (number of nn layers - 1)

Source code in torchhydro/models/dpl4xaj_nn4et.py
def __init__(
    self,
    nx: int,
    ny: int,
    hidden_size: Union[int, tuple, list] = None,
    dr: Union[float, tuple, list] = 0.0,
):
    """
    A simple multi-layer NN model with final linear layer

    Parameters
    ----------
    nx
        number of input neurons
    ny
        number of output neurons
    hidden_size
        a list/tuple which contains number of neurons in each hidden layer;
        if int, only one hidden layer except for hidden_size=0
    dr
        dropout rate of layers, default is 0.0 which means no dropout;
        here we set number of dropout layers to (number of nn layers - 1)
    """
    super(NnModule4Hydro, self).__init__()
    self.ann = SimpleAnn(nx, ny, hidden_size, dr)

forward(self, x, w0, prcp, pet, k)

the forward function of the NN ET module

Parameters

x : _type_
    _description_
w0 : _type_
    water storage
prcp : _type_
    precipitation
pet : tensor
    potential evapotranspiration, used as part of the upper limit of ET
k : tensor
    coefficient of PET in the XAJ model, used as part of the upper limit of ET

Returns

type description

Source code in torchhydro/models/dpl4xaj_nn4et.py
def forward(self, x, w0, prcp, pet, k):
    """the forward function of the NN ET module

    Parameters
    ----------
    x : _type_
        _description_
    w0 : _type_
        water storage
    prcp : _type_
        precipitation
    pet: tensor
        potential evapotranspiration, used as part of the upper limit of ET
    k: tensor
        coefficient of PET in the XAJ model, used as part of the upper limit of ET

    Returns
    -------
    _type_
        _description_
    """
    zeros = torch.full_like(w0, 0.0, device=x.device)
    et = torch.full_like(w0, 0.0, device=x.device)
    w_mask = w0 + prcp > PRECISION
    y = self.ann(x)
    z = y.flatten()
    et[w_mask] = torch.clamp(
        z[w_mask],
        min=zeros[w_mask],
        # torch.minimum computes the element-wise minimum: https://pytorch.org/docs/stable/generated/torch.minimum.html
        # k * pet is real pet in XAJ model
        max=torch.minimum(w0[w_mask] + prcp[w_mask], k[w_mask] * pet[w_mask]),
    )
    return et
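The clamp above keeps the learned ET non-negative and never larger than the available water (w0 + prcp) or the XAJ potential ET (k * pet). The following is a minimal sketch, not taken from the library docs, of how the module can be called; the tensor shapes and values are assumptions for illustration.

# A minimal sketch (shapes and values are assumptions, not from the library docs)
import torch
from torchhydro.models.dpl4xaj_nn4et import NnModule4Hydro

n_basin = 4
# 8 input features: k, um, lm, dm, c, w0, prcp, pet, as concatenated in Xaj4DplWithNnModule
et_nn = NnModule4Hydro(nx=8, ny=1, hidden_size=[16, 8], dr=0.2)

x = torch.rand(n_basin, 8)
w0 = torch.rand(n_basin) * 50.0    # water storage (mm)
prcp = torch.rand(n_basin) * 10.0  # precipitation (mm)
pet = torch.rand(n_basin) * 5.0    # potential evapotranspiration (mm)
k = torch.full((n_basin,), 0.8)    # PET coefficient of the XAJ model

et = et_nn(x, w0, prcp, pet, k)
# every ET value is clamped to [0, min(w0 + prcp, k * pet)]
assert torch.all(et >= 0)
assert torch.all(et <= torch.minimum(w0 + prcp, k * pet) + 1e-6)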

Xaj4DplWithNnModule (Module)

XAJ model for Differential Parameter learning with neural network as submodule

Source code in torchhydro/models/dpl4xaj_nn4et.py
class Xaj4DplWithNnModule(nn.Module):
    """
    XAJ model for Differential Parameter learning with neural network as submodule
    """

    def __init__(
        self,
        kernel_size: int,
        warmup_length: int,
        nn_module=None,
        param_var_index=None,
        source_book="HF",
        source_type="sources",
        et_output=1,
        nn_hidden_size: Union[int, tuple, list] = None,
        nn_dropout=0.2,
        param_test_way=MODEL_PARAM_TEST_WAY["time_varying"],
    ):
        """
        Parameters
        ----------
        kernel_size
            the time length of unit hydrograph
        warmup_length
            the length of warmup periods;
            XAJ needs a warmup period to generate reasonable initial state values
        nn_module
            We initialize the module when we first initialize Xaj4DplWithNnModule.
            Then we iteratively call the Xaj4DplWithNnModule module for warmup.
            Hence, in the warmup period, we don't need to initialize it again
        param_var_index
            the indices of parameters which will be time-varying
            NOTE: at most, we support k, b, and c to be time-varying
        et_output
            we only support one-layer ET now, because its water balance is not easy to handle
        """
        if param_var_index is None:
            param_var_index = [0, 6]
        if nn_hidden_size is None:
            nn_hidden_size = [16, 8]
        super(Xaj4DplWithNnModule, self).__init__()
        self.params_names = MODEL_PARAM_DICT["xaj_mz"]["param_name"]
        param_range = MODEL_PARAM_DICT["xaj_mz"]["param_range"]
        self.k_scale = param_range["K"]
        self.b_scale = param_range["B"]
        self.im_sacle = param_range["IM"]
        self.um_scale = param_range["UM"]
        self.lm_scale = param_range["LM"]
        self.dm_scale = param_range["DM"]
        self.c_scale = param_range["C"]
        self.sm_scale = param_range["SM"]
        self.ex_scale = param_range["EX"]
        self.ki_scale = param_range["KI"]
        self.kg_scale = param_range["KG"]
        self.a_scale = param_range["A"]
        self.theta_scale = param_range["THETA"]
        self.ci_scale = param_range["CI"]
        self.cg_scale = param_range["CG"]
        self.kernel_size = kernel_size
        self.warmup_length = warmup_length
        # there are 2 input variables in XAJ: P and PET
        self.feature_size = 2
        if nn_module is None:
            # 7: k, um, lm, dm, c, prcp, p_and_e[:, 1] + 1/3: w0 or wu0, wl0, wd0
            self.evap_nn_module = NnModule4Hydro(
                7 + et_output, et_output, nn_hidden_size, nn_dropout
            )
        else:
            self.evap_nn_module = nn_module
        self.source_book = source_book
        self.source_type = source_type
        self.et_output = et_output
        self.param_var_index = param_var_index
        self.nn_hidden_size = nn_hidden_size
        self.nn_dropout = nn_dropout
        self.param_test_way = param_test_way

    def xaj_generation_with_new_module(
        self,
        p_and_e: Tensor,
        k,
        b,
        im,
        um,
        lm,
        dm,
        c,
        *args,
        # wu0: Tensor = None,
        # wl0: Tensor = None,
        # wd0: Tensor = None,
    ) -> tuple:
        # make sure physical variables' value ranges are correct
        prcp = torch.clamp(p_and_e[:, 0], min=0.0)
        pet = torch.clamp(p_and_e[:, 1], min=0.0)
        # wm
        wm = um + lm + dm
        if self.et_output != 1:
            raise NotImplementedError("We only support one-layer evaporation now")
        w0_ = args[0]
        if w0_ is None:
            w0_ = 0.6 * (um.detach() + lm.detach() + dm.detach())
        w0 = torch.clamp(w0_, max=wm - PRECISION)
        concat_input = torch.stack([k, um, lm, dm, c, w0, prcp, pet], dim=1)
        e = self.evap_nn_module(concat_input, w0, prcp, pet, k)
        # Calculate the runoff generated by net precipitation
        prcp_difference = prcp - e
        pe = torch.clamp(prcp_difference, min=0.0)
        r, rim = calculate_prcp_runoff(b, im, wm, w0, pe)
        if self.et_output == 1:
            w = calculate_1layer_w_storage(
                um,
                lm,
                dm,
                w0,
                prcp_difference,
                r,
            )
            return (r, rim, e, pe), (w,)
        else:
            raise ValueError("et_output should be 1")

    def forward(self, p_and_e, parameters_ts, return_state=False):
        """
        run XAJ model

        Parameters
        ----------
        p_and_e
            precipitation and potential evapotranspiration
        parameters_ts
            time series parameters of XAJ model;
            some parameters may be time-varying specified by param_var_index
        return_state
            if True, return state values, mainly for warmup periods

        Returns
        -------
        torch.Tensor
            streamflow got by XAJ
        """
        xaj_device = p_and_e.device
        if self.param_test_way == MODEL_PARAM_TEST_WAY["time_varying"]:
            parameters = parameters_ts[-1, :, :]
        else:
            # parameters_ts must be a 2-d tensor: (basin, param)
            parameters = parameters_ts
        # denormalize the parameters to general range
        # TODO: now the specific parameters are hard coded; 0 is k, 1 is b, 6 is c, same as in model_config.py
        if 0 not in self.param_var_index or self.param_var_index is None:
            ks = self.k_scale[0] + parameters[:, 0] * (
                self.k_scale[1] - self.k_scale[0]
            )
        else:
            ks = self.k_scale[0] + parameters_ts[:, :, 0] * (
                self.k_scale[1] - self.k_scale[0]
            )
        if 1 not in self.param_var_index or self.param_var_index is None:
            bs = self.b_scale[0] + parameters[:, 1] * (
                self.b_scale[1] - self.b_scale[0]
            )
        else:
            bs = self.b_scale[0] + parameters_ts[:, :, 1] * (
                self.b_scale[1] - self.b_scale[0]
            )
        im = self.im_sacle[0] + parameters[:, 2] * (self.im_sacle[1] - self.im_sacle[0])
        um = self.um_scale[0] + parameters[:, 3] * (self.um_scale[1] - self.um_scale[0])
        lm = self.lm_scale[0] + parameters[:, 4] * (self.lm_scale[1] - self.lm_scale[0])
        dm = self.dm_scale[0] + parameters[:, 5] * (self.dm_scale[1] - self.dm_scale[0])
        if 6 not in self.param_var_index or self.param_var_index is None:
            cs = self.c_scale[0] + parameters[:, 6] * (
                self.c_scale[1] - self.c_scale[0]
            )
        else:
            cs = self.c_scale[0] + parameters_ts[:, :, 6] * (
                self.c_scale[1] - self.c_scale[0]
            )
        sm = self.sm_scale[0] + parameters[:, 7] * (self.sm_scale[1] - self.sm_scale[0])
        ex = self.ex_scale[0] + parameters[:, 8] * (self.ex_scale[1] - self.ex_scale[0])
        ki_ = self.ki_scale[0] + parameters[:, 9] * (
            self.ki_scale[1] - self.ki_scale[0]
        )
        kg_ = self.kg_scale[0] + parameters[:, 10] * (
            self.kg_scale[1] - self.kg_scale[0]
        )
        # ki+kg should be smaller than 1; if not, we scale them
        ki = torch.where(
            ki_ + kg_ < 1.0,
            ki_,
            (1 - PRECISION) / (ki_ + kg_) * ki_,
        )
        kg = torch.where(
            ki_ + kg_ < 1.0,
            kg_,
            (1 - PRECISION) / (ki_ + kg_) * kg_,
        )
        a = self.a_scale[0] + parameters[:, 11] * (self.a_scale[1] - self.a_scale[0])
        theta = self.theta_scale[0] + parameters[:, 12] * (
            self.theta_scale[1] - self.theta_scale[0]
        )
        ci = self.ci_scale[0] + parameters[:, 13] * (
            self.ci_scale[1] - self.ci_scale[0]
        )
        cg = self.cg_scale[0] + parameters[:, 14] * (
            self.cg_scale[1] - self.cg_scale[0]
        )

        # initialize state values
        warmup_length = self.warmup_length
        if warmup_length > 0:
            # set no_grad for warmup periods
            with torch.no_grad():
                p_and_e_warmup = p_and_e[0:warmup_length, :, :]
                if self.param_test_way == MODEL_PARAM_TEST_WAY["time_varying"]:
                    parameters_ts_warmup = parameters_ts[0:warmup_length, :, :]
                else:
                    parameters_ts_warmup = parameters_ts
                cal_init_xaj4dpl = Xaj4DplWithNnModule(
                    kernel_size=self.kernel_size,
                    # warmup_length must be 0 here
                    warmup_length=0,
                    nn_module=self.evap_nn_module,
                    param_var_index=self.param_var_index,
                    source_book=self.source_book,
                    source_type=self.source_type,
                    et_output=self.et_output,
                    nn_hidden_size=self.nn_hidden_size,
                    nn_dropout=self.nn_dropout,
                    param_test_way=self.param_test_way,
                )
                if cal_init_xaj4dpl.warmup_length > 0:
                    raise RuntimeError("Please set init model's warmup length to 0!!!")
                _, _, *w0, s0, fr0, qi0, qg0 = cal_init_xaj4dpl(
                    p_and_e_warmup, parameters_ts_warmup, return_state=True
                )
        else:
            # use detach func to make wu0 no_grad as it is an initial value
            if self.et_output == 1:
                # () and , must be added, otherwise, w0 will be a tensor, not a tuple
                w0 = (0.5 * (um.detach() + lm.detach() + dm.detach()),)
            else:
                raise ValueError("et_output should be 1 or 3")
            s0 = 0.5 * (sm.detach())
            fr0 = torch.full(ci.size(), 0.1).to(xaj_device)
            qi0 = torch.full(ci.size(), 0.1).to(xaj_device)
            qg0 = torch.full(cg.size(), 0.1).to(xaj_device)

        inputs = p_and_e[warmup_length:, :, :]
        runoff_ims_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        rss_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        ris_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        rgs_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        es_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        for i in range(inputs.shape[0]):
            if 0 in self.param_var_index or self.param_var_index is None:
                k = ks[i]
            else:
                k = ks
            if 1 in self.param_var_index or self.param_var_index is None:
                b = bs[i]
            else:
                b = bs
            if 6 in self.param_var_index or self.param_var_index is None:
                c = cs[i]
            else:
                c = cs
            if i == 0:
                (r, rim, e, pe), w = self.xaj_generation_with_new_module(
                    inputs[i, :, :], k, b, im, um, lm, dm, c, *w0
                )
                if self.source_type == "sources":
                    (rs, ri, rg), (s, fr) = xaj_sources(
                        pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                    )
                elif self.source_type == "sources5mm":
                    (rs, ri, rg), (s, fr) = xaj_sources5mm(
                        pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                    )
                else:
                    raise NotImplementedError("No such divide-sources method")
            else:
                (r, rim, e, pe), w = self.xaj_generation_with_new_module(
                    inputs[i, :, :], k, b, im, um, lm, dm, c, *w
                )
                if self.source_type == "sources":
                    (rs, ri, rg), (s, fr) = xaj_sources(
                        pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                    )
                elif self.source_type == "sources5mm":
                    (rs, ri, rg), (s, fr) = xaj_sources5mm(
                        pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                    )
                else:
                    raise NotImplementedError("No such divide-sources method")
            # impervious part is pe * im
            runoff_ims_[i, :] = rim
            # so for the non-impervious part, the result should be corrected
            rss_[i, :] = rs * (1 - im)
            ris_[i, :] = ri * (1 - im)
            rgs_[i, :] = rg * (1 - im)
            es_[i, :] = e
        # seq, batch, feature
        runoff_im = torch.unsqueeze(runoff_ims_, dim=2)
        rss = torch.unsqueeze(rss_, dim=2)
        es = torch.unsqueeze(es_, dim=2)

        conv_uh = KernelConv(a, theta, self.kernel_size)
        qs_ = conv_uh(runoff_im + rss)

        qs = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
        for i in range(inputs.shape[0]):
            if i == 0:
                qi = linear_reservoir(ris_[i], ci, qi0)
                qg = linear_reservoir(rgs_[i], cg, qg0)
            else:
                qi = linear_reservoir(ris_[i], ci, qi)
                qg = linear_reservoir(rgs_[i], cg, qg)
            qs[i, :] = qs_[i, :, 0] + qi + qg
        # seq, batch, feature
        q_sim = torch.unsqueeze(qs, dim=2)
        if return_state:
            return q_sim, es, *w, s, fr, qi, qg
        return q_sim, es

__init__(self, kernel_size, warmup_length, nn_module=None, param_var_index=None, source_book='HF', source_type='sources', et_output=1, nn_hidden_size=None, nn_dropout=0.2, param_test_way='var') special

Parameters

kernel_size
    the time length of the unit hydrograph
warmup_length
    the length of warmup periods;
    XAJ needs a warmup period to generate reasonable initial state values
nn_module
    We initialize the module when we first initialize Xaj4DplWithNnModule.
    Then we iteratively call the Xaj4DplWithNnModule module for warmup.
    Hence, in the warmup period, we don't need to initialize it again
param_var_index
    the indices of parameters which will be time-varying
    NOTE: at most, we support k, b, and c to be time-varying
et_output
    we only support one-layer ET now, because its water balance is not easy to handle

Source code in torchhydro/models/dpl4xaj_nn4et.py
def __init__(
    self,
    kernel_size: int,
    warmup_length: int,
    nn_module=None,
    param_var_index=None,
    source_book="HF",
    source_type="sources",
    et_output=1,
    nn_hidden_size: Union[int, tuple, list] = None,
    nn_dropout=0.2,
    param_test_way=MODEL_PARAM_TEST_WAY["time_varying"],
):
    """
    Parameters
    ----------
    kernel_size
        the time length of unit hydrograph
    warmup_length
        the length of warmup periods;
        XAJ needs a warmup period to generate reasonable initial state values
    nn_module
        We initialize the module when we first initialize Xaj4DplWithNnModule.
        Then we iteratively call the Xaj4DplWithNnModule module for warmup.
        Hence, in the warmup period, we don't need to initialize it again
    param_var_index
        the indices of parameters which will be time-varying
        NOTE: at most, we support k, b, and c to be time-varying
    et_output
        we only support one-layer ET now, because its water balance is not easy to handle
    """
    if param_var_index is None:
        param_var_index = [0, 6]
    if nn_hidden_size is None:
        nn_hidden_size = [16, 8]
    super(Xaj4DplWithNnModule, self).__init__()
    self.params_names = MODEL_PARAM_DICT["xaj_mz"]["param_name"]
    param_range = MODEL_PARAM_DICT["xaj_mz"]["param_range"]
    self.k_scale = param_range["K"]
    self.b_scale = param_range["B"]
    self.im_sacle = param_range["IM"]
    self.um_scale = param_range["UM"]
    self.lm_scale = param_range["LM"]
    self.dm_scale = param_range["DM"]
    self.c_scale = param_range["C"]
    self.sm_scale = param_range["SM"]
    self.ex_scale = param_range["EX"]
    self.ki_scale = param_range["KI"]
    self.kg_scale = param_range["KG"]
    self.a_scale = param_range["A"]
    self.theta_scale = param_range["THETA"]
    self.ci_scale = param_range["CI"]
    self.cg_scale = param_range["CG"]
    self.kernel_size = kernel_size
    self.warmup_length = warmup_length
    # there are 2 input variables in XAJ: P and PET
    self.feature_size = 2
    if nn_module is None:
        # 7: k, um, lm, dm, c, prcp, p_and_e[:, 1] + 1/3: w0 or wu0, wl0, wd0
        self.evap_nn_module = NnModule4Hydro(
            7 + et_output, et_output, nn_hidden_size, nn_dropout
        )
    else:
        self.evap_nn_module = nn_module
    self.source_book = source_book
    self.source_type = source_type
    self.et_output = et_output
    self.param_var_index = param_var_index
    self.nn_hidden_size = nn_hidden_size
    self.nn_dropout = nn_dropout
    self.param_test_way = param_test_way

forward(self, p_and_e, parameters_ts, return_state=False)

run XAJ model

Parameters

p_and_e
    precipitation and potential evapotranspiration
parameters_ts
    time series parameters of XAJ model;
    some parameters may be time-varying specified by param_var_index
return_state
    if True, return state values, mainly for warmup periods

Returns

torch.Tensor streamflow got by XAJ

Source code in torchhydro/models/dpl4xaj_nn4et.py
def forward(self, p_and_e, parameters_ts, return_state=False):
    """
    run XAJ model

    Parameters
    ----------
    p_and_e
        precipitation and potential evapotranspiration
    parameters_ts
        time series parameters of XAJ model;
        some parameters may be time-varying specified by param_var_index
    return_state
        if True, return state values, mainly for warmup periods

    Returns
    -------
    torch.Tensor
        streamflow got by XAJ
    """
    xaj_device = p_and_e.device
    if self.param_test_way == MODEL_PARAM_TEST_WAY["time_varying"]:
        parameters = parameters_ts[-1, :, :]
    else:
        # parameters_ts must be a 2-d tensor: (basin, param)
        parameters = parameters_ts
    # denormalize the parameters to general range
    # TODO: now the specific parameters are hard coded; 0 is k, 1 is b, 6 is c, same as in model_config.py
    if 0 not in self.param_var_index or self.param_var_index is None:
        ks = self.k_scale[0] + parameters[:, 0] * (
            self.k_scale[1] - self.k_scale[0]
        )
    else:
        ks = self.k_scale[0] + parameters_ts[:, :, 0] * (
            self.k_scale[1] - self.k_scale[0]
        )
    if 1 not in self.param_var_index or self.param_var_index is None:
        bs = self.b_scale[0] + parameters[:, 1] * (
            self.b_scale[1] - self.b_scale[0]
        )
    else:
        bs = self.b_scale[0] + parameters_ts[:, :, 1] * (
            self.b_scale[1] - self.b_scale[0]
        )
    im = self.im_sacle[0] + parameters[:, 2] * (self.im_sacle[1] - self.im_sacle[0])
    um = self.um_scale[0] + parameters[:, 3] * (self.um_scale[1] - self.um_scale[0])
    lm = self.lm_scale[0] + parameters[:, 4] * (self.lm_scale[1] - self.lm_scale[0])
    dm = self.dm_scale[0] + parameters[:, 5] * (self.dm_scale[1] - self.dm_scale[0])
    if 6 not in self.param_var_index or self.param_var_index is None:
        cs = self.c_scale[0] + parameters[:, 6] * (
            self.c_scale[1] - self.c_scale[0]
        )
    else:
        cs = self.c_scale[0] + parameters_ts[:, :, 6] * (
            self.c_scale[1] - self.c_scale[0]
        )
    sm = self.sm_scale[0] + parameters[:, 7] * (self.sm_scale[1] - self.sm_scale[0])
    ex = self.ex_scale[0] + parameters[:, 8] * (self.ex_scale[1] - self.ex_scale[0])
    ki_ = self.ki_scale[0] + parameters[:, 9] * (
        self.ki_scale[1] - self.ki_scale[0]
    )
    kg_ = self.kg_scale[0] + parameters[:, 10] * (
        self.kg_scale[1] - self.kg_scale[0]
    )
    # ki+kg should be smaller than 1; if not, we scale them
    ki = torch.where(
        ki_ + kg_ < 1.0,
        ki_,
        (1 - PRECISION) / (ki_ + kg_) * ki_,
    )
    kg = torch.where(
        ki_ + kg_ < 1.0,
        kg_,
        (1 - PRECISION) / (ki_ + kg_) * kg_,
    )
    a = self.a_scale[0] + parameters[:, 11] * (self.a_scale[1] - self.a_scale[0])
    theta = self.theta_scale[0] + parameters[:, 12] * (
        self.theta_scale[1] - self.theta_scale[0]
    )
    ci = self.ci_scale[0] + parameters[:, 13] * (
        self.ci_scale[1] - self.ci_scale[0]
    )
    cg = self.cg_scale[0] + parameters[:, 14] * (
        self.cg_scale[1] - self.cg_scale[0]
    )

    # initialize state values
    warmup_length = self.warmup_length
    if warmup_length > 0:
        # set no_grad for warmup periods
        with torch.no_grad():
            p_and_e_warmup = p_and_e[0:warmup_length, :, :]
            if self.param_test_way == MODEL_PARAM_TEST_WAY["time_varying"]:
                parameters_ts_warmup = parameters_ts[0:warmup_length, :, :]
            else:
                parameters_ts_warmup = parameters_ts
            cal_init_xaj4dpl = Xaj4DplWithNnModule(
                kernel_size=self.kernel_size,
                # warmup_length must be 0 here
                warmup_length=0,
                nn_module=self.evap_nn_module,
                param_var_index=self.param_var_index,
                source_book=self.source_book,
                source_type=self.source_type,
                et_output=self.et_output,
                nn_hidden_size=self.nn_hidden_size,
                nn_dropout=self.nn_dropout,
                param_test_way=self.param_test_way,
            )
            if cal_init_xaj4dpl.warmup_length > 0:
                raise RuntimeError("Please set init model's warmup length to 0!!!")
            _, _, *w0, s0, fr0, qi0, qg0 = cal_init_xaj4dpl(
                p_and_e_warmup, parameters_ts_warmup, return_state=True
            )
    else:
        # use detach func to make wu0 no_grad as it is an initial value
        if self.et_output == 1:
            # () and , must be added, otherwise, w0 will be a tensor, not a tuple
            w0 = (0.5 * (um.detach() + lm.detach() + dm.detach()),)
        else:
            raise ValueError("et_output should be 1 or 3")
        s0 = 0.5 * (sm.detach())
        fr0 = torch.full(ci.size(), 0.1).to(xaj_device)
        qi0 = torch.full(ci.size(), 0.1).to(xaj_device)
        qg0 = torch.full(cg.size(), 0.1).to(xaj_device)

    inputs = p_and_e[warmup_length:, :, :]
    runoff_ims_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    rss_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    ris_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    rgs_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    es_ = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    for i in range(inputs.shape[0]):
        if 0 in self.param_var_index or self.param_var_index is None:
            k = ks[i]
        else:
            k = ks
        if 1 in self.param_var_index or self.param_var_index is None:
            b = bs[i]
        else:
            b = bs
        if 6 in self.param_var_index or self.param_var_index is None:
            c = cs[i]
        else:
            c = cs
        if i == 0:
            (r, rim, e, pe), w = self.xaj_generation_with_new_module(
                inputs[i, :, :], k, b, im, um, lm, dm, c, *w0
            )
            if self.source_type == "sources":
                (rs, ri, rg), (s, fr) = xaj_sources(
                    pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                )
            elif self.source_type == "sources5mm":
                (rs, ri, rg), (s, fr) = xaj_sources5mm(
                    pe, r, sm, ex, ki, kg, s0, fr0, book=self.source_book
                )
            else:
                raise NotImplementedError("No such divide-sources method")
        else:
            (r, rim, e, pe), w = self.xaj_generation_with_new_module(
                inputs[i, :, :], k, b, im, um, lm, dm, c, *w
            )
            if self.source_type == "sources":
                (rs, ri, rg), (s, fr) = xaj_sources(
                    pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                )
            elif self.source_type == "sources5mm":
                (rs, ri, rg), (s, fr) = xaj_sources5mm(
                    pe, r, sm, ex, ki, kg, s, fr, book=self.source_book
                )
            else:
                raise NotImplementedError("No such divide-sources method")
        # impervious part is pe * im
        runoff_ims_[i, :] = rim
        # so for the non-impervious part, the result should be corrected
        rss_[i, :] = rs * (1 - im)
        ris_[i, :] = ri * (1 - im)
        rgs_[i, :] = rg * (1 - im)
        es_[i, :] = e
    # seq, batch, feature
    runoff_im = torch.unsqueeze(runoff_ims_, dim=2)
    rss = torch.unsqueeze(rss_, dim=2)
    es = torch.unsqueeze(es_, dim=2)

    conv_uh = KernelConv(a, theta, self.kernel_size)
    qs_ = conv_uh(runoff_im + rss)

    qs = torch.full(inputs.shape[:2], 0.0).to(xaj_device)
    for i in range(inputs.shape[0]):
        if i == 0:
            qi = linear_reservoir(ris_[i], ci, qi0)
            qg = linear_reservoir(rgs_[i], cg, qg0)
        else:
            qi = linear_reservoir(ris_[i], ci, qi)
            qg = linear_reservoir(rgs_[i], cg, qg)
        qs[i, :] = qs_[i, :, 0] + qi + qg
    # seq, batch, feature
    q_sim = torch.unsqueeze(qs, dim=2)
    if return_state:
        return q_sim, es, *w, s, fr, qi, qg
    return q_sim, es
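Putting the pieces together, forward denormalizes the 15 xaj_mz parameters, runs an optional no-grad warmup to obtain initial states, and then routes runoff through the unit-hydrograph convolution and two linear reservoirs. Below is a rough usage sketch with assumed shapes (sequence-first forcing, normalized parameters in [0, 1]); it is not an official example.

# A rough sketch (assumed shapes, not an official example): sequence-first
# forcing [time, basin, 2] and normalized (0-1) xaj_mz parameters [time, basin, 15]
import torch
from torchhydro.models.dpl4xaj_nn4et import Xaj4DplWithNnModule

n_time, n_basin, n_param = 30, 4, 15
warmup = 10

model = Xaj4DplWithNnModule(kernel_size=15, warmup_length=warmup)
p_and_e = torch.rand(n_time, n_basin, 2)             # precipitation and PET
parameters_ts = torch.rand(n_time, n_basin, n_param)

q_sim, es = model(p_and_e, parameters_ts)
# outputs cover only the post-warmup period: [n_time - warmup, basin, 1]
print(q_sim.shape, es.shape)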

calculate_1layer_w_storage(um, lm, dm, w0, pe, r)

Update the soil moisture value.

According to the runoff-generation equation 2.60 in the book "SHUIWENYUBAO", dW = dPE - dR

Parameters

um
    average soil moisture storage capacity of the upper layer (mm)
lm
    average soil moisture storage capacity of the lower layer (mm)
dm
    average soil moisture storage capacity of the deep layer (mm)
w0
    initial values of soil moisture
pe
    net precipitation; it can be negative in this function
r
    runoff

Returns

torch.Tensor w -- soil moisture

Source code in torchhydro/models/dpl4xaj_nn4et.py
def calculate_1layer_w_storage(um, lm, dm, w0, pe, r):
    """
    Update the soil moisture value.

    According to the runoff-generation equation 2.60 in the book "SHUIWENYUBAO", dW = dPE - dR

    Parameters
    ----------
    um
        average soil moisture storage capacity of the upper layer (mm)
    lm
        average soil moisture storage capacity of the lower layer (mm)
    dm
        average soil moisture storage capacity of the deep layer (mm)
    w0
        initial values of soil moisture
    pe
        net precipitation; it can be negative in this function
    r
        runoff

    Returns
    -------
    torch.Tensor
        w -- soil moisture
    """
    xaj_device = pe.device
    tensor_zeros = torch.full_like(w0, 0.0, device=xaj_device)
    # water balance (equation 2.2 in Page 13, also shown in Page 23)
    w = w0 + pe - r
    return torch.clamp(w, min=tensor_zeros, max=(um + lm + dm) - PRECISION)
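A small worked example of this water balance, assuming the import path below: storage is updated by w = w0 + pe - r and clipped to the interval [0, (um + lm + dm) - PRECISION].

# A small sketch (the import path is assumed): w = w0 + pe - r, clipped to
# [0, (um + lm + dm) - PRECISION]
import torch
from torchhydro.models.dpl4xaj_nn4et import calculate_1layer_w_storage

um = torch.tensor([20.0])
lm = torch.tensor([60.0])
dm = torch.tensor([40.0])    # total capacity wm = 120 mm
w0 = torch.tensor([100.0])   # current soil moisture storage
pe = torch.tensor([35.0])    # net precipitation
r = torch.tensor([5.0])      # runoff

w = calculate_1layer_w_storage(um, lm, dm, w0, pe, r)
# 100 + 35 - 5 = 130 mm exceeds wm = 120 mm, so the result is clipped to just below 120
print(w)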

dropout

Author: Wenyu Ouyang Date: 2023-07-11 17:39:09 LastEditTime: 2024-10-09 15:27:49 LastEditors: Wenyu Ouyang Description: Functions for dropout. Code is from Kuai Fang's repo: hydroDL FilePath: \torchhydro\torchhydro\models\dropout.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

DropMask (InplaceFunction)

Source code in torchhydro/models/dropout.py
class DropMask(torch.autograd.function.InplaceFunction):
    @classmethod
    def forward(cls, ctx, input, mask, train=False, inplace=False):
        """_summary_

        Parameters
        ----------
        ctx : autograd.Function
            ctx is a context object that can be used to store information for backward computation
        input : _type_
            _description_
        mask : _type_
            _description_
        train : bool, optional
            if the model is in training mode, by default False
        inplace : bool, optional
            inplace operation, by default False

        Returns
        -------
        _type_
            _description_
        """
        ctx.master_train = train
        ctx.inplace = inplace
        ctx.mask = mask

        if not ctx.master_train:
            # if not in training mode, just return the input
            return input
        if ctx.inplace:
            # mark_dirty() is used to mark the input as dirty, meaning inplace operation is performed
            # make it dirty so that the gradient is calculated correctly during backward
            ctx.mark_dirty(input)
            output = input
        else:
            # clone the input tensor so that avoid changing the input tensor
            output = input.clone()
        # inplace multiplication with the mask
        output.mul_(ctx.mask)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        backward method for DropMask
        staticmethod means that the method belongs to the class itself and not to the object of the class

        Parameters
        ----------
        ctx : _type_
            store information for backward computation
        grad_output : _type_
            gradient of the downstream layer

        Returns
        -------
        _type_
            _description_
        """
        if ctx.master_train:
            # if in training mode, return the gradient multiplied by the mask
            return grad_output * ctx.mask, None, None, None
        else:
            # if not in training mode, return the gradient directly
            return grad_output, None, None, None

backward(ctx, grad_output) staticmethod

backward method for DropMask staticmethod means that the method belongs to the class itself and not to the object of the class

Parameters

ctx : _type_
    store information for backward computation
grad_output : _type_
    gradient of the downstream layer

Returns

type description

Source code in torchhydro/models/dropout.py
@staticmethod
def backward(ctx, grad_output):
    """
    backward method for DropMask
    staticmethod means that the method belongs to the class itself and not to the object of the class

    Parameters
    ----------
    ctx : _type_
        store information for backward computation
    grad_output : _type_
        gradient of the downstream layer

    Returns
    -------
    _type_
        _description_
    """
    if ctx.master_train:
        # if in training mode, return the gradient multiplied by the mask
        return grad_output * ctx.mask, None, None, None
    else:
        # if not in training mode, return the gradient directly
        return grad_output, None, None, None

forward(ctx, input, mask, train=False, inplace=False) classmethod

summary

Parameters

ctx : autograd.Function
    ctx is a context object that can be used to store information for backward computation
input : _type_
    _description_
mask : _type_
    _description_
train : bool, optional
    if the model is in training mode, by default False
inplace : bool, optional
    inplace operation, by default False

Returns

type description

Source code in torchhydro/models/dropout.py
@classmethod
def forward(cls, ctx, input, mask, train=False, inplace=False):
    """_summary_

    Parameters
    ----------
    ctx : autograd.Function
        ctx is a context object that can be used to store information for backward computation
    input : _type_
        _description_
    mask : _type_
        _description_
    train : bool, optional
        if the model is in training mode, by default False
    inplace : bool, optional
        inplace operation, by default False

    Returns
    -------
    _type_
        _description_
    """
    ctx.master_train = train
    ctx.inplace = inplace
    ctx.mask = mask

    if not ctx.master_train:
        # if not in training mode, just return the input
        return input
    if ctx.inplace:
        # mark_dirty() is used to mark the input as dirty, meaning inplace operation is performed
        # make it dirty so that the gradient is calculated correctly during backward
        ctx.mark_dirty(input)
        output = input
    else:
        # clone the input tensor so that avoid changing the input tensor
        output = input.clone()
    # inplace multiplication with the mask
    output.mul_(ctx.mask)

    return output

create_mask(x, dr)

Dropout method in Gal & Ghahramani: A Theoretically Grounded Application of Dropout in RNNs. http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks.pdf

Parameters

!!! x "torch.Tensor" input tensor !!! dr "float" dropout rate

Returns

torch.Tensor mask tensor

Source code in torchhydro/models/dropout.py
def create_mask(x, dr):
    """
    Dropout method in Gal & Ghahramani: A Theoretically Grounded Application of Dropout in RNNs.
    http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks.pdf

    Parameters
    ----------
    x: torch.Tensor
        input tensor
    dr: float
        dropout rate

    Returns
    -------
    torch.Tensor
        mask tensor
    """
    # x.new() creates a new tensor with the same data type as x
    # bernoulli_(1-dr) creates a tensor with the same shape as x, filled with 0 or 1, where 1 has a probability of 1-dr
    # div_(1-dr) divides the mask by 1-dr so that the expected value of the masked tensor equals x; for example, if dr=0.5, the surviving elements are scaled by 2 while half are zeroed, so the expectation stays equal to x
    # detach_() can be used to detach the tensor from the computation graph so that the gradient is not calculated
    # if dr=1, then the tensor is all zeros, the results are all NaNs if using the code with bernoulli_(1 - dr).div_(1 - dr).detach_()
    # so we need to add a special case for dr=1
    if dr == 1:
        # add a warning message
        print("Warning: dropout rate is 1, directly set 0.")
        return x.new().resize_as_(x).zero_().detach_()
    return x.new().resize_as_(x).bernoulli_(1 - dr).div_(1 - dr).detach_()
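create_mask and DropMask are meant to be used together: one mask is sampled per sequence (or per forward pass) and then reused across steps, which is the variational dropout scheme of Gal & Ghahramani. The sketch below is illustrative only; shapes and the dropout rate are assumptions.

# An illustrative sketch (shapes and dropout rate are assumptions)
import torch
from torchhydro.models.dropout import DropMask, create_mask

h = torch.randn(8, 32, requires_grad=True)  # e.g. a hidden state [batch, hidden]
mask = create_mask(h, 0.3)                  # fixed 0 / (1 / 0.7) pattern, detached from the graph

# reuse the same mask at every time step while training
h_train = DropMask.apply(h, mask, True)     # train=True: elements are zeroed and rescaled
h_eval = DropMask.apply(h, mask, False)     # train=False: the input passes through unchanged

h_train.sum().backward()                    # gradients are masked with the same pattern
print(torch.equal(h_eval, h))               # True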

gnn

GNNBaseModel (Module, ABC)

Source code in torchhydro/models/gnn.py
class GNNBaseModel(Module, ABC):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_hidden: int,
        param_sharing: bool,
        layerfun: Callable[[], Module],
        edge_orientation: Optional[str],
        edge_weights: Optional[torch.Tensor],
        output_size: int = 1,
        root_gauge_idx: Optional[int] = None,
    ) -> None:
        super().__init__()
        # modification: support multi-horizon output
        self.output_size = output_size
        self.root_gauge_idx = root_gauge_idx

        self.encoder = Linear(
            in_channels, hidden_channels, weight_initializer="kaiming_uniform"
        )
        if param_sharing:
            self.layers = ModuleList(num_hidden * [layerfun()])
        else:
            self.layers = ModuleList([layerfun() for _ in range(num_hidden)])
        # conventional decoder (used for outputs at all nodes)
        self.decoder = Linear(
            hidden_channels, output_size, weight_initializer="kaiming_uniform"
        )
        # aggregation layer: aggregates information from all nodes to the root gauge node
        if root_gauge_idx is not None:
            # this layer is created dynamically in forward because it needs the exact number of nodes
            self.aggregation_layer: Optional[Linear] = None

        self.edge_weights = edge_weights
        self.edge_orientation = edge_orientation
        # set the self-loop fill value; prefer the preset edge_weights, otherwise set it dynamically in forward
        if self.edge_weights is not None:
            self.loop_fill_value: Union[float, str] = (
                1.0 if (self.edge_weights == 0).all() else "mean"
            )
        else:
            # no preset weights: use a default value, which may be adjusted in forward according to the given edge_weight
            self.loop_fill_value: Union[float, str] = "mean"

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        batch_vector: Optional[torch.Tensor] = None,
        evo_tracking: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Generic GNN forward pass, compatible with two input modes:
        1. batch_vector mode (variable node count per batch, PyG style, suited to large heterogeneous graphs / concatenated basins)
        2. conventional batch-dimension mode ([batch, num_nodes, ...], suited to a fixed number of nodes)

        Parameters:
            x: node feature tensor, shapes as described above
            edge_index: edge indices in PyG format
            edge_weight: edge weights; if None, ones are filled in automatically
            batch_vector: mapping from nodes to batch entries (if any)
            evo_tracking: whether to record the output of every layer
        Returns:
            predictions, or (predictions, evolution sequence)
        """
        # safety: force all input tensors onto the encoder's device to avoid device mismatches
        device = self.encoder.weight.device
        x = x.to(device)
        edge_index = edge_index.to(device)
        edge_weight = edge_weight.to(device)
        if batch_vector is not None:
            batch_vector = batch_vector.to(device)
        if batch_vector is not None:
            # support x as [batch, num_nodes, window_size, num_features] or [total_nodes, window_size, num_features]
            if x.dim() == 4:
                x = x.view(-1, x.size(2), x.size(3))
            elif x.dim() != 3:
                raise ValueError(f"Unsupported x shape for GNN batch_vector mode: {x.shape}")
            x = x.view(x.size(0), -1)
            x_0 = self.encoder(x)
            evolution = [x_0.detach()] if evo_tracking else None
            x = x_0
            for layer in self.layers:
                x = self.apply_layer(layer, x, x_0, edge_index, edge_weight)
                if evo_tracking:
                    evolution.append(x.detach())
            x = self.decoder(x)
            if self.root_gauge_idx is not None:
                if evo_tracking:
                    evolution.append(x.detach())
                batch_size = batch_vector.max().item() + 1
                # batch-wise mean aggregation with native PyTorch ops
                out_sum = torch.zeros(batch_size, x.size(-1), device=x.device)
                out_sum = out_sum.index_add(0, batch_vector, x)
                count = torch.zeros(batch_size, device=x.device)
                count = count.index_add(0, batch_vector, torch.ones_like(batch_vector, dtype=x.dtype))
                count = count.clamp_min(1).unsqueeze(-1)
                x = out_sum / count
                # ensure the output shape is [batch, time, feature], i.e. [batch, output_size, 1] (if output_size is the number of time steps and the feature dim is 1)
                if x.dim() == 2:
                    x = x.unsqueeze(-1)
                if evo_tracking:
                    return x, evolution
            else:
                # ensure the output shape is [node, time, feature], i.e. [N, output_size, 1]
                if x.dim() == 2:
                    x = x.unsqueeze(-1)
            return (x, evolution) if evo_tracking else x
        else:
            # standard batch mode; edge_weight must be provided and be consistent with edge_index
            batch_size, num_nodes, window_size, num_features = x.shape
            x = x.view(batch_size * num_nodes, window_size * num_features)
            # edge_index: [2, num_edges] or [batch, 2, num_edges]
            if edge_index.dim() == 3:
                node_offsets = torch.arange(batch_size, device=edge_index.device) * num_nodes
                node_offsets = node_offsets.view(-1, 1, 1)
                edge_index_offset = edge_index + node_offsets
                edge_index = edge_index_offset.transpose(0, 1).contiguous().view(2, -1)
            else:
                if batch_size > 1:
                    edge_indices = []
                    for b in range(batch_size):
                        offset = b * num_nodes
                        edge_indices.append(edge_index + offset)
                    edge_index = torch.cat(edge_indices, dim=1)
            # add self-loops (if needed; can also be done in dataset preprocessing)
            # edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=batch_size * num_nodes)
            x_0 = self.encoder(x)
            evolution: Optional[List[torch.Tensor]] = [x_0.detach()] if evo_tracking else None
            x = x_0
            for layer in self.layers:
                x = self.apply_layer(layer, x, x_0, edge_index, edge_weight)
                if evo_tracking:
                    evolution.append(x.detach())
            x = self.decoder(x)
            if self.root_gauge_idx is not None:
                if evo_tracking:
                    evolution.append(x.detach())
                x = x.view(batch_size, num_nodes, self.output_size)
                if self.aggregation_layer is None:
                    input_dim = num_nodes * self.output_size
                    self.aggregation_layer = Linear(
                        input_dim, self.output_size, weight_initializer="kaiming_uniform"
                    ).to(x.device)
                x_flat = x.view(batch_size, -1)
                x = self.aggregation_layer(x_flat)
                if evo_tracking:
                    return x, evolution
            else:
                x = x.view(batch_size, num_nodes, self.output_size)
            return (x, evolution) if evo_tracking else x

    @abstractmethod
    def apply_layer(
        self,
        layer: Module,
        x: torch.Tensor,
        x_0: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weights: torch.Tensor,
    ) -> torch.Tensor:
        pass

forward(self, x, edge_index, edge_weight, batch_vector=None, evo_tracking=False, **kwargs)

Generic GNN forward pass, compatible with two input modes: 1. batch_vector mode (variable node count per batch, PyG style, suited to large heterogeneous graphs / concatenated basins) 2. conventional batch-dimension mode ([batch, num_nodes, ...], suited to a fixed number of nodes)

Parameters
    x: node feature tensor, shapes as described above
    edge_index: edge indices in PyG format
    edge_weight: edge weights; if None, ones are filled in automatically
    batch_vector: mapping from nodes to batch entries (if any)
    evo_tracking: whether to record the output of every layer

Returns
    predictions, or (predictions, evolution sequence)

Source code in torchhydro/models/gnn.py
def forward(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_weight: torch.Tensor,
    batch_vector: Optional[torch.Tensor] = None,
    evo_tracking: bool = False,
    **kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """
    Generic GNN forward pass, compatible with two input modes:
    1. batch_vector mode (variable node count per batch, PyG style, suited to large heterogeneous graphs / concatenated basins)
    2. conventional batch-dimension mode ([batch, num_nodes, ...], suited to a fixed number of nodes)

    Parameters:
        x: node feature tensor, shapes as described above
        edge_index: edge indices in PyG format
        edge_weight: edge weights; if None, ones are filled in automatically
        batch_vector: mapping from nodes to batch entries (if any)
        evo_tracking: whether to record the output of every layer
    Returns:
        predictions, or (predictions, evolution sequence)
    """
    # safety: force all input tensors onto the encoder's device to avoid device mismatches
    device = self.encoder.weight.device
    x = x.to(device)
    edge_index = edge_index.to(device)
    edge_weight = edge_weight.to(device)
    if batch_vector is not None:
        batch_vector = batch_vector.to(device)
    if batch_vector is not None:
        # support x as [batch, num_nodes, window_size, num_features] or [total_nodes, window_size, num_features]
        if x.dim() == 4:
            x = x.view(-1, x.size(2), x.size(3))
        elif x.dim() != 3:
            raise ValueError(f"Unsupported x shape for GNN batch_vector mode: {x.shape}")
        x = x.view(x.size(0), -1)
        x_0 = self.encoder(x)
        evolution = [x_0.detach()] if evo_tracking else None
        x = x_0
        for layer in self.layers:
            x = self.apply_layer(layer, x, x_0, edge_index, edge_weight)
            if evo_tracking:
                evolution.append(x.detach())
        x = self.decoder(x)
        if self.root_gauge_idx is not None:
            if evo_tracking:
                evolution.append(x.detach())
            batch_size = batch_vector.max().item() + 1
            # batch-wise mean aggregation with native PyTorch ops
            out_sum = torch.zeros(batch_size, x.size(-1), device=x.device)
            out_sum = out_sum.index_add(0, batch_vector, x)
            count = torch.zeros(batch_size, device=x.device)
            count = count.index_add(0, batch_vector, torch.ones_like(batch_vector, dtype=x.dtype))
            count = count.clamp_min(1).unsqueeze(-1)
            x = out_sum / count
            # ensure the output shape is [batch, time, feature], i.e. [batch, output_size, 1] (if output_size is the number of time steps and the feature dim is 1)
            if x.dim() == 2:
                x = x.unsqueeze(-1)
            if evo_tracking:
                return x, evolution
        else:
            # ensure the output shape is [node, time, feature], i.e. [N, output_size, 1]
            if x.dim() == 2:
                x = x.unsqueeze(-1)
        return (x, evolution) if evo_tracking else x
    else:
        # standard batch mode; edge_weight must be provided and be consistent with edge_index
        batch_size, num_nodes, window_size, num_features = x.shape
        x = x.view(batch_size * num_nodes, window_size * num_features)
        # edge_index: [2, num_edges] or [batch, 2, num_edges]
        if edge_index.dim() == 3:
            node_offsets = torch.arange(batch_size, device=edge_index.device) * num_nodes
            node_offsets = node_offsets.view(-1, 1, 1)
            edge_index_offset = edge_index + node_offsets
            edge_index = edge_index_offset.transpose(0, 1).contiguous().view(2, -1)
        else:
            if batch_size > 1:
                edge_indices = []
                for b in range(batch_size):
                    offset = b * num_nodes
                    edge_indices.append(edge_index + offset)
                edge_index = torch.cat(edge_indices, dim=1)
        # add self-loops (if needed; can also be done in dataset preprocessing)
        # edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=batch_size * num_nodes)
        x_0 = self.encoder(x)
        evolution: Optional[List[torch.Tensor]] = [x_0.detach()] if evo_tracking else None
        x = x_0
        for layer in self.layers:
            x = self.apply_layer(layer, x, x_0, edge_index, edge_weight)
            if evo_tracking:
                evolution.append(x.detach())
        x = self.decoder(x)
        if self.root_gauge_idx is not None:
            if evo_tracking:
                evolution.append(x.detach())
            x = x.view(batch_size, num_nodes, self.output_size)
            if self.aggregation_layer is None:
                input_dim = num_nodes * self.output_size
                self.aggregation_layer = Linear(
                    input_dim, self.output_size, weight_initializer="kaiming_uniform"
                ).to(x.device)
            x_flat = x.view(batch_size, -1)
            x = self.aggregation_layer(x_flat)
            if evo_tracking:
                return x, evolution
        else:
            x = x.view(batch_size, num_nodes, self.output_size)
        return (x, evolution) if evo_tracking else x
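GNNBaseModel is abstract, so a concrete subclass must implement apply_layer. The following hypothetical sketch (not part of torchhydro) shows the conventional batch-dimension calling mode with a plain GCN layer from torch_geometric; the shapes, the GCNConv choice, and the toy edge list are all assumptions.

# A hypothetical sketch (not part of torchhydro): a concrete subclass using a plain
# GCN layer, called in the conventional batch-dimension mode
import torch
from torch_geometric.nn import GCNConv
from torchhydro.models.gnn import GNNBaseModel

class GCNModel(GNNBaseModel):
    def apply_layer(self, layer, x, x_0, edge_index, edge_weights):
        # a plain GCN layer ignores the initial embedding x_0
        return torch.relu(layer(x, edge_index, edge_weights))

window_size, num_features, hidden = 8, 3, 32
model = GCNModel(
    in_channels=window_size * num_features,
    hidden_channels=hidden,
    num_hidden=2,
    param_sharing=False,
    layerfun=lambda: GCNConv(hidden, hidden),
    edge_orientation=None,
    edge_weights=None,
    output_size=1,
)

batch_size, num_nodes = 4, 5
x = torch.rand(batch_size, num_nodes, window_size, num_features)
edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 4]])    # all gauges drain to node 4
edge_weight = torch.ones(edge_index.size(1) * batch_size)  # one weight per replicated edge
y = model(x, edge_index, edge_weight)
print(y.shape)  # (batch_size, num_nodes, output_size) = (4, 5, 1)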

kernel_conv

KernelConv (Module)

Source code in torchhydro/models/kernel_conv.py
class KernelConv(nn.Module):
    def __init__(self, a, theta, kernel_size):
        """
        The convolution kernel for the convolution operation in routing module

        We use two-parameter gamma distribution to determine the unit hydrograph,
        which comes from [mizuRoute](http://www.geosci-model-dev.net/9/2223/2016/)

        Parameters
        ----------
        a
            shape parameter
        theta
            timescale parameter
        kernel_size
            the size of conv kernel
        """
        super(KernelConv, self).__init__()
        self.a = a
        self.theta = theta
        routa = self.a.repeat(kernel_size, 1).unsqueeze(-1)
        routb = self.theta.repeat(kernel_size, 1).unsqueeze(-1)
        self.uh_gamma = uh_gamma(routa, routb, len_uh=kernel_size)

    def forward(self, x):
        """
        1d-convolution calculation

        Parameters
        ----------
        x
            x is a sequence-first variable, so the dim of x is [seq, batch, feature]

        Returns
        -------
        torch.Tensor
            convolution
        """
        # dim: permute from [len_uh, batch, feature] to [batch, feature, len_uh]
        uh = self.uh_gamma.permute(1, 2, 0)
        # the dim of conv kernel in F.conv1d is out_channels, in_channels (feature)/groups, width (seq)
        # the dim of inputs in F.conv1d are batch, in_channels (feature) and width (seq),
        # each element in a batch should have its own conv kernel,
        # hence set groups = batch_size and permute input's batch-dim to channel-dim to make "groups" work
        inputs = x.permute(2, 1, 0)
        batch_size = x.shape[1]
        # conv1d in NN is different from the general convolution: it lacks a flip, so the kernel is flipped below
        outputs = F.conv1d(
            inputs, torch.flip(uh, [2]), groups=batch_size, padding=uh.shape[-1] - 1
        )
        # permute from [feature, batch, seq] to [seq, batch, feature]
        return outputs[:, :, : -(uh.shape[-1] - 1)].permute(2, 1, 0)

__init__(self, a, theta, kernel_size) special

The convolution kernel for the convolution operation in routing module

We use two-parameter gamma distribution to determine the unit hydrograph, which comes from mizuRoute

Parameters

a
    shape parameter
theta
    timescale parameter
kernel_size
    the size of conv kernel

Source code in torchhydro/models/kernel_conv.py
def __init__(self, a, theta, kernel_size):
    """
    The convolution kernel for the convolution operation in routing module

    We use two-parameter gamma distribution to determine the unit hydrograph,
    which comes from [mizuRoute](http://www.geosci-model-dev.net/9/2223/2016/)

    Parameters
    ----------
    a
        shape parameter
    theta
        timescale parameter
    kernel_size
        the size of conv kernel
    """
    super(KernelConv, self).__init__()
    self.a = a
    self.theta = theta
    routa = self.a.repeat(kernel_size, 1).unsqueeze(-1)
    routb = self.theta.repeat(kernel_size, 1).unsqueeze(-1)
    self.uh_gamma = uh_gamma(routa, routb, len_uh=kernel_size)

forward(self, x)

1d-convolution calculation

Parameters

x
    x is a sequence-first variable, so the dim of x is [seq, batch, feature]

Returns

torch.Tensor convolution

Source code in torchhydro/models/kernel_conv.py
def forward(self, x):
    """
    1d-convolution calculation

    Parameters
    ----------
    x
        x is a sequence-first variable, so the dim of x is [seq, batch, feature]

    Returns
    -------
    torch.Tensor
        convolution
    """
    # dim: permute from [len_uh, batch, feature] to [batch, feature, len_uh]
    uh = self.uh_gamma.permute(1, 2, 0)
    # the dim of conv kernel in F.conv1d is out_channels, in_channels (feature)/groups, width (seq)
    # the dim of inputs in F.conv1d are batch, in_channels (feature) and width (seq),
    # each element in a batch should have its own conv kernel,
    # hence set groups = batch_size and permute input's batch-dim to channel-dim to make "groups" work
    inputs = x.permute(2, 1, 0)
    batch_size = x.shape[1]
    # conv1d in NN is different from the general convolution: it lacks a flip, so the kernel is flipped below
    outputs = F.conv1d(
        inputs, torch.flip(uh, [2]), groups=batch_size, padding=uh.shape[-1] - 1
    )
    # permute from [feature, batch, seq] to [seq, batch, feature]
    return outputs[:, :, : -(uh.shape[-1] - 1)].permute(2, 1, 0)

uh_conv(x, uh_made)

Function for 1d-convolution calculation

Parameters

x
    x is a sequence-first variable, so the dim of x is [seq, batch, feature]
uh_made
    unit hydrograph from uh_gamma or other unit-hydrograph method

Returns

torch.Tensor convolution, [seq, batch, feature]; the length of seq is same as x's

Source code in torchhydro/models/kernel_conv.py
def uh_conv(x, uh_made) -> torch.Tensor:
    """
    Function for 1d-convolution calculation

    Parameters
    ----------
    x
        x is a sequence-first variable, so the dim of x is [seq, batch, feature]
    uh_made
        unit hydrograph from uh_gamma or other unit-hydrograph method

    Returns
    -------
    torch.Tensor
        convolution, [seq, batch, feature]; the length of seq is same as x's
    """
    uh = uh_made.permute(1, 2, 0)
    # the dim of conv kernel in F.conv1d is out_channels, in_channels (feature)/groups, width (seq)
    # the dim of inputs in F.conv1d are batch, in_channels (feature) and width (seq),
    # each element in a batch should have its own conv kernel,
    # hence set groups = batch_size and permute input's batch-dim to channel-dim to make "groups" work
    inputs = x.permute(2, 1, 0)
    batch_size = x.shape[1]
    # conv1d in NN is different from the general convolution: it lacks a flip, so the kernel is flipped below
    outputs = F.conv1d(
        inputs, torch.flip(uh, [2]), groups=batch_size, padding=uh.shape[-1] - 1
    )
    # cut to same shape with x and permute from [feature, batch, seq] to [seq, batch, feature]
    return outputs[:, :, : x.shape[0]].permute(2, 1, 0)

uh_gamma(a, theta, len_uh=10)

A simple two-parameter Gamma distribution as a unit-hydrograph to route instantaneous runoff from a hydrologic model

The method comes from mizuRoute -- http://www.geosci-model-dev.net/9/2223/2016/

Parameters

a
    shape parameter
theta
    timescale parameter
len_uh
    the time length of the unit hydrograph

Returns

torch.Tensor the unit hydrograph, dim: [seq, batch, feature]

Source code in torchhydro/models/kernel_conv.py
def uh_gamma(a, theta, len_uh=10):
    """
    A simple two-parameter Gamma distribution as a unit-hydrograph to route instantaneous runoff from a hydrologic model

    The method comes from mizuRoute -- http://www.geosci-model-dev.net/9/2223/2016/

    Parameters
    ----------
    a
        shape parameter
    theta
        timescale parameter
    len_uh
        the time length of the unit hydrograph

    Returns
    -------
    torch.Tensor
        the unit hydrograph, dim: [seq, batch, feature]

    """
    # dims of a: time_seq (same all time steps), batch, feature=1
    m = a.shape
    assert len_uh <= m[0]
    # aa > 0, here we set a minimum of 0.1 (min of a is 0, set when calling this func); the first dimension of a is repeated
    aa = F.relu(a[0:len_uh, :, :]) + 0.1
    # theta > 0, here set minimum 0.5
    theta = F.relu(theta[0:len_uh, :, :]) + 0.5
    # len_f, batch, feature
    t = (
        torch.arange(0.5, len_uh * 1.0)
        .view([len_uh, 1, 1])
        .repeat([1, m[1], m[2]])
        .to(aa.device)
    )
    denominator = (aa.lgamma().exp()) * (theta**aa)
    # [len_f, m[1], m[2]]
    w = 1 / denominator * (t ** (aa - 1)) * (torch.exp(-t / theta))
    w = w / w.sum(0)  # scale to 1 for each UH
    return w
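In short, uh_gamma builds a per-basin gamma-shaped unit hydrograph normalized to sum to 1, and KernelConv applies it to a sequence-first runoff series with a grouped 1-D convolution. A short sketch with assumed shapes:

# A short sketch (assumed shapes): per-basin gamma parameters and a
# sequence-first runoff series
import torch
from torchhydro.models.kernel_conv import KernelConv, uh_gamma

n_time, n_basin, kernel_size = 50, 3, 15
a = torch.rand(n_basin)        # shape parameter, one per basin
theta = torch.rand(n_basin)    # timescale parameter, one per basin
runoff = torch.rand(n_time, n_basin, 1)

conv_uh = KernelConv(a, theta, kernel_size)
routed = conv_uh(runoff)
print(routed.shape)            # [seq, batch, feature] = (50, 3, 1), same length as the input

# the underlying unit hydrograph sums to 1 for each basin
uh = uh_gamma(
    a.repeat(kernel_size, 1).unsqueeze(-1),
    theta.repeat(kernel_size, 1).unsqueeze(-1),
    len_uh=kernel_size,
)
print(uh.sum(dim=0))           # approximately ones, shape (n_basin, 1)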

model_dict_function

Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 LastEditTime: 2025-07-13 18:17:48 LastEditors: Wenyu Ouyang Description: Dicts including models (which are seq-first), losses, and optims FilePath: \torchhydro\torchhydro\models\model_dict_function.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

model_utils

Author: Wenyu Ouyang Date: 2021-08-09 10:19:13 LastEditTime: 2025-04-15 12:59:35 LastEditors: Wenyu Ouyang Description: Some util functions for modeling FilePath: /torchhydro/torchhydro/models/model_utils.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

get_the_device(device_num)

Get device for torch according to its name

Parameters

device_num : Union[list, int] number of the device -- -1 means "cpu" or 0, 1, ... means "cuda:x" or "mps:x"

Source code in torchhydro/models/model_utils.py
def get_the_device(device_num: Union[list, int]):
    """
    Get device for torch according to its name

    Parameters
    ----------
    device_num : Union[list, int]
        number of the device -- -1 means "cpu" or 0, 1, ... means "cuda:x" or "mps:x"
    """
    if device_num in [[-1], -1, ["-1"]]:
        return torch.device("cpu")
    if torch.cuda.is_available():
        return (
            torch.device(f"cuda:{str(device_num)}")
            if type(device_num) is not list
            else torch.device(f"cuda:{str(device_num[0])}")
        )
    # Check for MPS (MacOS)
    mps_available = False
    with contextlib.suppress(AttributeError):
        mps_available = torch.backends.mps.is_available()
    if mps_available:
        if device_num != 0:
            warnings.warn(
                f"MPS only supports device 0. Using 'mps:0' instead of {device_num}."
            )
        return torch.device("mps:0")
    if device_num not in [[-1], -1, ["-1"]]:
        warnings.warn("No GPU is available, so the CPU will be used for models")
    return torch.device("cpu")
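
A minimal usage sketch (device index 0 is an assumption; the call falls back to MPS or CPU depending on the machine):

    import torch

    device = get_the_device(0)            # "cuda:0" if available, otherwise "mps:0" or "cpu"
    model = torch.nn.Linear(4, 1).to(device)
    x = torch.zeros(2, 4, device=device)
    y = model(x)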

mtslstm

MTSLSTM (Module)

Multi-Temporal-Scale LSTM (MTS-LSTM).

This model processes multi-frequency time-series data (hour/day/week) by aggregating high-frequency (hourly) inputs into lower-frequency branches. It supports per-feature down-aggregation, optional state transfer between frequency branches, and loading pretrained weights for the daily branch.

Example usage:

    # Unified hourly input (recommended)
    model = MTSLSTM(
        hidden_sizes=[64, 64, 64],
        output_size=1,
        feature_buckets=[2, 2, 1, 0, ...],
        frequency_factors=[7, 24],   # week->day ×7, day->hour ×24
        seq_lengths=[T_week, T_day, T_hour],
        slice_transfer=True,
    )
    y = model(x_hour)  # x_hour: (T_hour, B, D)

    # Legacy: explicitly provide each frequency branch
    y = model(x_week, x_day, x_hour)
Source code in torchhydro/models/mtslstm.py
class MTSLSTM(nn.Module):
    """Multi-Temporal-Scale LSTM (MTS-LSTM).

    This model processes multi-frequency time-series data (hour/day/week) by
    aggregating high-frequency (hourly) inputs into lower-frequency branches.
    It supports per-feature down-aggregation, optional state transfer between
    frequency branches, and loading pretrained weights for the daily branch.

    Example usage:
        # Unified hourly input (recommended)
        model = MTSLSTM(
            hidden_sizes=[64, 64, 64],
            output_size=1,
            feature_buckets=[2, 2, 1, 0, ...],
            frequency_factors=[7, 24],   # week->day ×7, day->hour ×24
            seq_lengths=[T_week, T_day, T_hour],
            slice_transfer=True,
        )
        y = model(x_hour)  # x_hour: (T_hour, B, D)

        # Legacy: explicitly provide each frequency branch
        y = model(x_week, x_day, x_hour)
    """

    def __init__(
        self,
        input_sizes: Union[int, List[int], None] = None,
        hidden_sizes: Union[int, List[int]] = 128,
        output_size: int = 1,
        shared_mtslstm: bool = False,
        transfer: Union[
            None, str, Dict[str, Optional[Literal["identity", "linear"]]]
        ] = "linear",
        dropout: float = 0.0,
        return_all: bool = False,
        add_freq_one_hot_if_shared: bool = True,
        auto_build_lowfreq: bool = False,
        build_factor: int = 7,
        agg_reduce: Literal["mean", "sum"] = "mean",
        per_feature_aggs: Optional[List[Literal["mean", "sum"]]] = None,
        truncate_incomplete: bool = True,
        slice_transfer: bool = True,
        slice_use_ceil: bool = True,
        seq_lengths: Optional[List[int]] = None,
        frequency_factors: Optional[List[int]] = None,
        feature_buckets: Optional[List[int]] = None,
        per_feature_aggs_map: Optional[List[Literal["mean", "sum"]]] = None,
        down_aggregate_all_to_each_branch: bool = True,
        pretrained_day_path: Optional[str] = None,
        pretrained_lstm_prefix: Optional[str] = None,
        pretrained_head_prefix: Optional[str] = None,
        pretrained_flag: bool = False,
        linear1_size: Optional[int] = None,
        linear2_size: Optional[int] = None
    ):
        """Initializes an MTSLSTM model.

        Args:
            input_sizes: Input feature dimension(s). Can be:
                * int: shared across all frequency branches
                * list: per-frequency input sizes
                * None: inferred from `feature_buckets`
            hidden_sizes: Hidden dimension(s) for each LSTM branch.
            output_size: Output dimension per timestep.
            shared_mtslstm: If True, all frequency branches share one LSTM.
            transfer: Hidden state transfer mode between frequencies.
                * None: no transfer
                * "identity": copy states directly (same dim required)
                * "linear": learn linear projection between dims
            dropout: Dropout probability applied before heads.
            return_all: If True, return all branch outputs (dict f0,f1,...).
                If False, return only the highest-frequency output.
            add_freq_one_hot_if_shared: If True and `shared_mtslstm=True`,
                append frequency one-hot encoding to inputs.
            auto_build_lowfreq: Legacy 2-frequency path (high->low).
            build_factor: Aggregation factor for auto low-frequency.
            agg_reduce: Aggregation method for downsampling ("mean" or "sum").
            per_feature_aggs: Optional list of per-feature aggregation methods.
            truncate_incomplete: Whether to drop remainder timesteps when
                aggregating (vs. zero-padding).
            slice_transfer: If True, transfer LSTM states at slice boundaries
                computed by seq_lengths × frequency_factors.
            slice_use_ceil: If True, use ceil for slice length calculation.
            seq_lengths: Per-frequency sequence lengths [low, ..., high].
            frequency_factors: Multipliers between adjacent frequencies.
                Example: [7,24] means week->day ×7, day->hour ×24.
            feature_buckets: Per-feature frequency assignment (len = D).
                0 = lowest (week), nf-1 = highest (hour).
            per_feature_aggs_map: Per-feature aggregation method ("mean"/"sum").
            down_aggregate_all_to_each_branch: If True, branch f includes all
                features with bucket >= f (down-aggregate); else only == f.
            pretrained_day_path: Optional path to pretrained checkpoint. If set,
                loads weights for the daily (f1) LSTM and head.

        Raises:
            AssertionError: If configuration is inconsistent.
        """
        super().__init__()

        # Store configuration parameters
        self.pretrained_day_path = pretrained_day_path
        self.pretrained_flag = pretrained_flag
        self.linear1_size = linear1_size
        self.linear2_size = linear2_size
        self.output_size = output_size
        self.shared = shared_mtslstm
        self.return_all_default = return_all
        self.feature_buckets = list(feature_buckets) if feature_buckets is not None else None
        self.per_feature_aggs_map = list(per_feature_aggs_map) if per_feature_aggs_map is not None else None
        self.down_agg_all = down_aggregate_all_to_each_branch
        self.auto_build_lowfreq = auto_build_lowfreq

        # Aggregation and slicing parameters
        assert build_factor >= 2, "build_factor must be >=2"
        self.build_factor = int(build_factor)
        self.agg_reduce = agg_reduce
        self.per_feature_aggs = per_feature_aggs
        self.truncate_incomplete = truncate_incomplete
        self.slice_transfer = slice_transfer
        self.slice_use_ceil = slice_use_ceil
        self.seq_lengths = list(seq_lengths) if seq_lengths is not None else None
        self._warned_slice_fallback = False

        # Setup frequency configuration
        self._setup_frequency_config(feature_buckets, input_sizes, auto_build_lowfreq)

        # Validate seq_lengths
        if self.seq_lengths is not None:
            assert len(self.seq_lengths) == self.nf, "seq_lengths length must match nf"

        # Setup input sizes for each branch
        self.base_input_sizes = self._setup_input_sizes(
            self.feature_buckets, input_sizes, self.down_agg_all
        )

        # Setup hidden layer sizes
        if isinstance(hidden_sizes, int):
            self.hidden_sizes = [hidden_sizes] * self.nf
        else:
            assert len(hidden_sizes) == self.nf, "hidden_sizes length mismatch"
            self.hidden_sizes = list(hidden_sizes)

        # Setup frequency one-hot encoding
        self.add_freq1hot = add_freq_one_hot_if_shared and self.shared

        # Setup transfer configuration
        self._setup_transfer_config(transfer)

        # Setup frequency factors and slice timesteps
        self._setup_frequency_factors(frequency_factors, auto_build_lowfreq, self.build_factor)

        # Calculate effective input sizes (including one-hot if needed)
        eff_input_sizes = self.base_input_sizes[:]
        if self.add_freq1hot:
            eff_input_sizes = [d + self.nf for d in eff_input_sizes]

        if self.shared and len(set(eff_input_sizes)) != 1:
            raise ValueError("shared_mtslstm=True requires equal input sizes.")

        # Create model layers
        self._create_model_layers(eff_input_sizes)

        # Create transfer layers
        self._create_transfer_layers()

        # Setup dropout
        self.dropout = nn.Dropout(p=dropout)

        # Set unified hourly aggregation flag
        self.use_hourly_unified = self.feature_buckets is not None

        # Load pretrained weights if specified
        self._load_pretrained_weights(pretrained_lstm_prefix, pretrained_head_prefix)

    def _setup_frequency_config(self, feature_buckets, input_sizes, auto_build_lowfreq):
        """Setup frequency configuration and compute number of frequencies."""
        if feature_buckets is not None:
            assert len(feature_buckets) > 0, "feature_buckets cannot be empty"
            self.nf = max(feature_buckets) + 1
        else:
            if isinstance(input_sizes, int) or input_sizes is None:
                self.nf = 2  # default to two frequencies when input_sizes is an int or None
            else:
                assert len(input_sizes) >= 2, "At least 2 frequencies required"
                self.nf = len(input_sizes)

    def _setup_input_sizes(self, feature_buckets, input_sizes, down_aggregate_all_to_each_branch):
        """Setup input sizes for each frequency branch."""
        if feature_buckets is not None:
            D = len(feature_buckets)
            if down_aggregate_all_to_each_branch:
                base_input_sizes = [
                    sum(1 for k in range(D) if feature_buckets[k] >= f)
                    for f in range(self.nf)
                ]
            else:
                base_input_sizes = [
                    sum(1 for k in range(D) if feature_buckets[k] == f)
                    for f in range(self.nf)
                ]
        else:
            if isinstance(input_sizes, int):
                base_input_sizes = [input_sizes] * self.nf
            else:
                base_input_sizes = list(input_sizes)

        assert len(base_input_sizes) == self.nf, "input_sizes mismatch with nf"
        return base_input_sizes

    def _setup_transfer_config(self, transfer):
        """Setup transfer configuration for hidden and cell states."""
        if transfer is None or isinstance(transfer, str):
            transfer = {"h": transfer, "c": transfer}
        self.transfer_mode: Dict[str, Optional[str]] = {
            "h": transfer.get("h", None),
            "c": transfer.get("c", None),
        }
        for k in ("h", "c"):
            assert self.transfer_mode[k] in (
                None, "identity", "linear"
            ), "transfer must be None/'identity'/'linear'"

    def _setup_frequency_factors(self, frequency_factors, auto_build_lowfreq, build_factor):
        """Setup frequency factors and slice timesteps."""
        if frequency_factors is not None:
            assert (
                len(frequency_factors) == self.nf - 1
            ), "frequency_factors length must be nf-1"
            self.frequency_factors = list(map(int, frequency_factors))
        elif self.nf == 2 and auto_build_lowfreq:
            self.frequency_factors = [int(build_factor)]
        else:
            self.frequency_factors = None

        # Pre-compute slice positions if seq_lengths and frequency_factors are provided
        self.slice_timesteps: Optional[List[int]] = None
        if self.seq_lengths is not None and self.frequency_factors is not None:
            self.slice_timesteps = []
            for i in range(self.nf - 1):
                fac = int(self.frequency_factors[i])
                next_len = int(self.seq_lengths[i + 1])
                st = int(next_len / fac)  # floor
                self.slice_timesteps.append(max(0, st))

    def _create_model_layers(self, eff_input_sizes):
        """Create LSTM and linear layers based on configuration."""
        if self.pretrained_flag:
            # Use pretrained model specific layers
            self.linear1 = nn.ModuleList([
                nn.Linear(eff_input_sizes[i], self.linear1_size) for i in range(self.nf)
            ])
            self.linear2 = nn.ModuleList([
                nn.Linear(self.linear1_size, self.linear2_size) for i in range(self.nf)
            ])
            self.lstms = nn.ModuleList()
            if self.shared:
                self.lstms.append(nn.LSTM(self.hidden_sizes[0], self.hidden_sizes[0]))
            else:
                for i in range(self.nf):
                    self.lstms.append(nn.LSTM(self.linear2_size, self.hidden_sizes[i]))
        else:
            # Use default layers when no pretrained model is loaded
            self.input_linears = nn.ModuleList([
                nn.Linear(eff_input_sizes[i], self.hidden_sizes[i]) for i in range(self.nf)
            ])
            self.lstms = nn.ModuleList()
            if self.shared:
                self.lstms.append(nn.LSTM(self.hidden_sizes[0], self.hidden_sizes[0]))
            else:
                for i in range(self.nf):
                    self.lstms.append(nn.LSTM(self.hidden_sizes[i], self.hidden_sizes[i]))

        # Create head layers
        self.heads = nn.ModuleList()
        if self.shared:
            self.heads.append(nn.Linear(self.hidden_sizes[0], self.output_size))
        else:
            for i in range(self.nf):
                self.heads.append(nn.Linear(self.hidden_sizes[i], self.output_size))

    def _create_transfer_layers(self):
        """Create transfer projection layers between frequency branches."""
        self.transfer_h = nn.ModuleList()
        self.transfer_c = nn.ModuleList()
        for i in range(self.nf - 1):
            hs_i = self.hidden_sizes[i]
            hs_j = self.hidden_sizes[i + 1]

            if self.transfer_mode["h"] == "linear":
                self.transfer_h.append(nn.Linear(hs_i, hs_j))
            elif self.transfer_mode["h"] == "identity":
                assert hs_i == hs_j, "identity requires same hidden size"
                self.transfer_h.append(nn.Identity())
            else:
                self.transfer_h.append(None)

            if self.transfer_mode["c"] == "linear":
                self.transfer_c.append(nn.Linear(hs_i, hs_j))
            elif self.transfer_mode["c"] == "identity":
                assert hs_i == hs_j, "identity requires same hidden size"
                self.transfer_c.append(nn.Identity())
            else:
                self.transfer_c.append(None)

    def _load_pretrained_weights(self, pretrained_lstm_prefix, pretrained_head_prefix):
        """Load pretrained weights for the daily branch if specified."""
        if self.pretrained_day_path is None:
            return

        if not os.path.isfile(self.pretrained_day_path):
            warnings.warn(
                f"[MTSLSTM] Pretrained file not found: {self.pretrained_day_path}"
            )
            return

        if self.shared:
            warnings.warn(
                "[MTSLSTM] shared_mtslstm=True: skip daily-only pretrained load."
            )
            return

        if self.nf < 2:
            warnings.warn("[MTSLSTM] nf<2: no daily branch, skip pretrained load.")
            return

        try:
            state = torch.load(self.pretrained_day_path, map_location="cpu")
            if isinstance(state, dict):
                if "state_dict" in state:
                    state = state["state_dict"]
                elif "model" in state:
                    state = state["model"]

            self._load_daily_branch_weights(state, pretrained_lstm_prefix, pretrained_head_prefix)

        except Exception as e:
            warnings.warn(f"[MTSLSTM] Failed to load daily pretrained: {e}")

    def _load_daily_branch_weights(self, state, pretrained_lstm_prefix, pretrained_head_prefix):
        """Load weights for the daily LSTM and head from pretrained state."""
        day_lstm = self.lstms[1]
        day_head = self.heads[1]
        lstm_state = day_lstm.state_dict()
        head_state = day_head.state_dict()

        matched = skipped = shape_mismatch = 0

        def try_load(prefix: str, target_state: Dict[str, torch.Tensor]) -> None:
            nonlocal matched, skipped, shape_mismatch
            for k_pre, v in state.items():
                if not k_pre.startswith(prefix):
                    continue
                k = k_pre[len(prefix):]
                if k in target_state:
                    if target_state[k].shape == v.shape:
                        target_state[k].copy_(v)
                        matched += 1
                    else:
                        shape_mismatch += 1
                else:
                    skipped += 1

        try_load(pretrained_lstm_prefix, lstm_state)
        try_load(pretrained_head_prefix, head_state)

        # Load pretrained linearIn into linear2[1]
        self._load_linear_weights(state)

        # Fallback: try raw state_dict without prefix
        if matched == 0:
            for k, v in state.items():
                if k in lstm_state and lstm_state[k].shape == v.shape:
                    lstm_state[k].copy_(v)
                    matched += 1
                elif k in head_state and head_state[k].shape == v.shape:
                    head_state[k].copy_(v)
                    matched += 1
                else:
                    shape_mismatch += 1

        self._print_loading_debug_info(state, day_lstm, day_head)

        day_lstm.load_state_dict(lstm_state)
        day_head.load_state_dict(head_state)
        print(
            f"[MTSLSTM] Daily pretrained loaded: matched={matched}, "
            f"shape_mismatch={shape_mismatch}, skipped={skipped}"
        )

    def _load_linear_weights(self, state):
        """Load pretrained linear layer weights."""
        if "linearIn.weight" in state and "linearIn.bias" in state:
            linear_weight = state["linearIn.weight"]
            linear_bias = state["linearIn.bias"]

            target_linear = self.linear2[1]
            if (
                target_linear.weight.shape == linear_weight.shape
                and target_linear.bias.shape == linear_bias.shape
            ):
                target_linear.weight.data.copy_(linear_weight)
                target_linear.bias.data.copy_(linear_bias)
                print("[MTSLSTM] Pretrained linearIn loaded into linear2[1]")
            else:
                warnings.warn(
                    f"[MTSLSTM] linearIn shape mismatch: pretrained {linear_weight.shape}, "
                    f"current {target_linear.weight.shape}"
                )
        else:
            warnings.warn("[MTSLSTM] linearIn keys not found in pretrained state")

    def _print_loading_debug_info(self, state, day_lstm, day_head):
        """Print debug information about pretrained weight loading."""
        print("=== Pretrained keys ===")
        for k in list(state.keys())[:10]:
            print(k)

        print("\n=== Current LSTM keys ===")
        for k in list(day_lstm.state_dict().keys())[:10]:
            print(k)

        print("\n=== Current Head keys ===")
        for k in list(day_head.state_dict().keys())[:10]:
            print(k)

        print("Head.bias pretrained:", state["linearOut.bias"].shape)
        print("Head.bias current:", day_head.state_dict()["bias"].shape)

    def _append_one_hot(self, x: torch.Tensor, freq_idx: int) -> torch.Tensor:
        """Appends a one-hot frequency indicator to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (T, B, D).
            freq_idx (int): Frequency index to mark as 1 in the one-hot vector.

        Returns:
            torch.Tensor: Tensor of shape (T, B, D + nf), where `nf` is the
            number of frequency branches. The appended one-hot encodes the
            branch identity.
        """
        T, B, _ = x.shape
        oh = x.new_zeros((T, B, self.nf))
        oh[:, :, freq_idx] = 1
        return torch.cat([x, oh], dim=-1)

    def _run_lstm(
        self,
        x: torch.Tensor,
        lstm: nn.LSTM,
        head: nn.Linear,
        h0: Optional[torch.Tensor],
        c0: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Runs an LSTM followed by a linear head with optional initial states.

        Args:
            x (torch.Tensor): Input tensor of shape (T, B, D).
            lstm (nn.LSTM): LSTM module for this branch.
            head (nn.Linear): Linear output layer.
            h0 (torch.Tensor): Optional initial hidden state (1, B, H).
            c0 (torch.Tensor): Optional initial cell state (1, B, H).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - y (torch.Tensor): Output sequence (T, B, O).
                - h_n (torch.Tensor): Final hidden state (1, B, H).
                - c_n (torch.Tensor): Final cell state (1, B, H).
        """
        out, (h_n, c_n) = (
            lstm(x, (h0, c0)) if (h0 is not None and c0 is not None) else lstm(x)
        )
        y = head(self.dropout(out))  # Project to output size
        return y, h_n, c_n

    def _aggregate_lowfreq(
        self,
        x_high: torch.Tensor,
        factor: int,
        agg_reduce: Literal["mean", "sum"],
        per_feature_aggs: Optional[List[Literal["mean", "sum"]]],
        truncate_incomplete: bool,
    ) -> torch.Tensor:
        """Aggregates high-frequency input into lower-frequency sequences.

        Args:
            x_high (torch.Tensor): Input tensor (T, B, D) at high frequency.
            factor (int): Aggregation factor (e.g., 24 for daily from hourly).
            agg_reduce (str): Default aggregation method, "mean" or "sum".
            per_feature_aggs (List[str] | None):
                Optional per-feature aggregation strategies ("mean"/"sum").
            truncate_incomplete (bool): If True, drop incomplete groups;
                if False, pad to make groups complete.

        Returns:
            torch.Tensor: Aggregated tensor of shape (T_low, B, D),
            where T_low = floor(T / factor) if truncate_incomplete,
            else ceil(T / factor).

        Raises:
            ValueError: If `agg_reduce` is not "mean" or "sum".
            AssertionError: If `per_feature_aggs` length mismatches feature dim.
        """
        T, B, D = x_high.shape
        if factor <= 1:
            return x_high
        if truncate_incomplete:
            T_trim = (T // factor) * factor
            xh = x_high[:T_trim]
            groups = xh.view(T_trim // factor, factor, B, D)
        else:
            pad = (factor - (T % factor)) % factor
            if pad > 0:
                pad_tensor = x_high.new_zeros((pad, B, D))
                xh = torch.cat([x_high, pad_tensor], dim=0)
            else:
                xh = x_high
            groups = xh.view(xh.shape[0] // factor, factor, B, D)

        if per_feature_aggs is None:
            if agg_reduce == "mean":
                return groups.mean(dim=1)
            elif agg_reduce == "sum":
                return groups.sum(dim=1)
            else:
                raise ValueError("agg_reduce must be 'mean' or 'sum'")

        assert len(per_feature_aggs) == D, "per_feature_aggs length must match D"
        mean_agg = groups.mean(dim=1)
        sum_agg = groups.sum(dim=1)
        mask_sum = x_high.new_tensor(
            [1.0 if a == "sum" else 0.0 for a in per_feature_aggs]
        ).view(1, 1, D)
        mask_mean = 1.0 - mask_sum
        return mean_agg * mask_mean + sum_agg * mask_sum

    def _multi_factor_to_high(self, f: int) -> int:
        """Computes cumulative factor from branch f to the highest frequency.

        Args:
            f (int): Branch index (0 = lowest frequency).

        Returns:
            int: Product of frequency factors from branch f to highest branch.

        Raises:
            RuntimeError: If `frequency_factors` is not defined.
        """
        if self.frequency_factors is None:
            raise RuntimeError("frequency_factors must be provided.")
        fac = 1
        for k in range(f, self.nf - 1):
            fac *= int(self.frequency_factors[k])
        return fac

    def _build_from_hourly(self, x_hour: torch.Tensor) -> List[torch.Tensor]:
        """Builds multi-frequency inputs from raw hourly features.

        Each branch selects features based on bucket assignment and aggregates
        them down to its frequency using frequency factors.

        Args:
            x_hour (torch.Tensor): Hourly input tensor (T_h, B, D).
                - T_h: number of hourly timesteps
                - B: batch size
                - D: feature dimension (must match `feature_buckets`)

        Returns:
            List[torch.Tensor]: List of tensors, one per branch,
            with shapes (T_f, B, D_f).

        Raises:
            AssertionError: If feature dimension mismatches `feature_buckets`.
        """
        assert self.feature_buckets is not None, "feature_buckets must be provided"
        T_h, B, D = x_hour.shape
        assert D == len(self.feature_buckets), "Mismatch between features and buckets"

        xs = []
        for f in range(self.nf):
            if self.down_agg_all:
                cols = [i for i in range(D) if self.feature_buckets[i] >= f]
            else:
                cols = [i for i in range(D) if self.feature_buckets[i] == f]

            if len(cols) == 0:
                # Insert placeholder if branch has no features
                x_sub = x_hour.new_zeros((T_h, B, 1))
                per_aggs = ["mean"]
            else:
                x_sub = x_hour[:, :, cols]
                per_aggs = None
                if self.per_feature_aggs_map is not None:
                    per_aggs = [self.per_feature_aggs_map[i] for i in cols]

            factor = self._multi_factor_to_high(f) if (self.nf >= 2) else 1
            x_f = self._aggregate_lowfreq(
                x_high=x_sub,
                factor=factor,
                agg_reduce=self.agg_reduce,
                per_feature_aggs=per_aggs,
                truncate_incomplete=self.truncate_incomplete,
            )
            xs.append(x_f)
        return xs

    def _get_slice_len_low(self, i: int, T_low: int, T_high: int) -> int:
        """Computes the slice length for a low-frequency branch.

        This method determines how many timesteps from the low-frequency input
        should be aligned with the high-frequency branch during state transfer.

        Priority:
            1. If `slice_timesteps` is precomputed, return the clamped value.
            2. Otherwise, fall back to heuristic estimation using ceil/floor.

        Args:
            i (int): Index of the low-frequency branch.
            T_low (int): Sequence length of the low-frequency input.
            T_high (int): Sequence length of the high-frequency input.

        Returns:
            int: Number of timesteps to slice from the low-frequency input,
            clamped between [0, T_low].

        Raises:
            RuntimeWarning: If `seq_lengths` and `frequency_factors` are not
            provided, a warning is issued since the fallback may not perfectly
            match NeuralHydrology's fixed slicing definition.
        """
        if self.slice_timesteps is not None:
            return max(0, min(self.slice_timesteps[i], T_low))
        if (not self._warned_slice_fallback) and self.slice_transfer:
            warnings.warn(
                "[MTSLSTM] seq_lengths/frequency_factors not provided; slice positions use a "
                "heuristic estimate (ceil/floor) that does not exactly match NeuralHydrology's "
                "fixed slicing definition. Provide both hyperparameters to align exactly.",
                RuntimeWarning,
            )
            self._warned_slice_fallback = True
        factor = (
            self.build_factor
            if (
                self.nf == 2
                and self.feature_buckets is None
                and self.auto_build_lowfreq
            )
            else max(int(round(T_high / max(1, T_low))), 1)
        )
        if self.slice_use_ceil:
            slice_len_low = int((T_high + factor - 1) // factor)
        else:
            slice_len_low = int(T_high // factor)
        return max(0, min(slice_len_low, T_low))

    def _prepare_inputs(self, xs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Prepare and validate input tensors for multi-frequency processing.

        Args:
            xs: Input tensors from forward method

        Returns:
            Tuple of validated input tensors for each frequency branch
        """
        # New path: a single hourly input + feature_buckets
        if len(xs) == 1 and self.use_hourly_unified:
            return tuple(self._build_from_hourly(xs[0]))

        # Legacy path: a single high-frequency input + auto-built low frequency (two frequencies only)
        elif (
            len(xs) == 1
            and self.auto_build_lowfreq
            and self.nf == 2
            and self.feature_buckets is None
        ):
            x_high = xs[0]
            assert x_high.dim() == 3, "input must be (time, batch, features)"
            x_low = self._aggregate_lowfreq(
                x_high=x_high,
                factor=self.build_factor,
                agg_reduce=self.agg_reduce,
                per_feature_aggs=self.per_feature_aggs,
                truncate_incomplete=self.truncate_incomplete,
            )
            return (x_low, x_high)

        return xs

    def _validate_inputs(self, xs: Tuple[torch.Tensor, ...]) -> None:
        """Validate input tensor dimensions and shapes.

        Args:
            xs: Input tensors to validate

        Raises:
            AssertionError: If inputs don't match expected configuration
        """
        assert len(xs) == self.nf, f"received {len(xs)} frequencies, but the model expects {self.nf}"
        for i, x in enumerate(xs):
            assert x.dim() == 3, f"input {i} must be (time, batch, features)"
            exp_d = self.base_input_sizes[i]
            assert (
                x.shape[-1] == exp_d
            ), f"input {i} has feature dim {x.shape[-1]}, expected {exp_d}"

    def _preprocess_branch_input(self, x: torch.Tensor, branch_idx: int) -> torch.Tensor:
        """Preprocess input for a specific frequency branch.

        Args:
            x: Input tensor for the branch
            branch_idx: Index of the frequency branch

        Returns:
            Preprocessed tensor ready for LSTM processing
        """
        # Add frequency one-hot encoding first if needed
        if self.add_freq1hot:
            x = self._append_one_hot(x, branch_idx)

        if self.pretrained_flag:
            x = self.linear1[branch_idx](x)
            x = F.relu(x)
            x = self.linear2[branch_idx](x)
            x = F.relu(x)
        else:
            # For non-pretrained mode, apply input linear layer
            x = self.input_linears[branch_idx](x)
            x = F.relu(x)

        return x

    def _get_branch_modules(self, branch_idx: int) -> Tuple[nn.LSTM, nn.Linear]:
        """Get LSTM and head modules for a specific branch.

        Args:
            branch_idx: Index of the frequency branch

        Returns:
            Tuple of (lstm_module, head_module) for the branch
        """
        lstm = self.lstms[0] if self.shared else self.lstms[branch_idx]
        head = self.heads[0] if self.shared else self.heads[branch_idx]
        return lstm, head

    def _initialize_transfer_states(self, device: torch.device, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize hidden and cell states for transfer between branches.

        Args:
            device: Device to create tensors on
            batch_size: Batch size for state tensors

        Returns:
            Tuple of (h_transfer, c_transfer) initial states
        """
        H0 = self.hidden_sizes[0]
        h_transfer = torch.zeros(1, batch_size, H0, device=device)
        c_transfer = torch.zeros(1, batch_size, H0, device=device)
        return h_transfer, c_transfer

    def _update_transfer_states(
        self, 
        branch_idx: int, 
        h_state: torch.Tensor, 
        c_state: torch.Tensor,
        next_batch_size: int,
        device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update transfer states for the next branch.

        Args:
            branch_idx: Current branch index
            h_state: Current hidden state
            c_state: Current cell state  
            next_batch_size: Batch size for next branch
            device: Device for tensor creation

        Returns:
            Updated (h_transfer, c_transfer) for next branch
        """
        if branch_idx >= self.nf - 1:
            return h_state, c_state

        Hn = self.hidden_sizes[branch_idx + 1]
        h_transfer = torch.zeros(1, next_batch_size, Hn, device=device)
        c_transfer = torch.zeros(1, next_batch_size, Hn, device=device)

        if self.transfer_h[branch_idx] is not None:
            h_transfer = self.transfer_h[branch_idx](h_state[0]).unsqueeze(0)
        if self.transfer_c[branch_idx] is not None:
            c_transfer = self.transfer_c[branch_idx](c_state[0]).unsqueeze(0)

        return h_transfer, c_transfer

    def _process_branch_with_slice_transfer(
        self,
        x_i: torch.Tensor,
        branch_idx: int,
        xs: Tuple[torch.Tensor, ...],
        lstm: nn.LSTM,
        head: nn.Linear,
        h_transfer: torch.Tensor,
        c_transfer: torch.Tensor,
        device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Process a branch with slice transfer enabled.

        Args:
            x_i: Input tensor for current branch
            branch_idx: Index of current branch
            xs: All input tensors
            lstm: LSTM module for current branch
            head: Head module for current branch
            h_transfer: Transfer hidden state
            c_transfer: Transfer cell state
            device: Device for tensor operations

        Returns:
            Tuple of (output, final_h_state, final_c_state)
        """
        T_low = x_i.shape[0]
        T_high = xs[branch_idx + 1].shape[0]
        slice_len_low = self._get_slice_len_low(branch_idx, T_low=T_low, T_high=T_high)

        if slice_len_low == 0:
            return self._run_lstm(x_i, lstm, head, h_transfer, c_transfer)
        else:
            # Process first part
            x_part1 = (
                x_i[:-slice_len_low] if slice_len_low < x_i.shape[0] else x_i[:0]
            )
            if x_part1.shape[0] > 0:
                y1, h1, c1 = self._run_lstm(x_part1, lstm, head, h_transfer, c_transfer)
            else:
                y1 = x_i.new_zeros((0, x_i.shape[1], self.output_size))
                h1, c1 = h_transfer, c_transfer

            # Process second part
            x_part2 = x_i[-slice_len_low:] if slice_len_low > 0 else x_i[:0]
            if x_part2.shape[0] > 0:
                y2, h_final, c_final = self._run_lstm(x_part2, lstm, head, h1, c1)
                y_all = torch.cat([y1, y2], dim=0)
            else:
                y_all = y1
                h_final, c_final = h1, c1

            return y_all, h_final, c_final

    def _process_single_branch(
        self,
        branch_idx: int,
        xs: Tuple[torch.Tensor, ...],
        h_transfer: torch.Tensor,
        c_transfer: torch.Tensor,
        device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Process a single frequency branch.

        Args:
            branch_idx: Index of the branch to process
            xs: All input tensors
            h_transfer: Transfer hidden state
            c_transfer: Transfer cell state
            device: Device for tensor operations

        Returns:
            Tuple of (branch_output, final_h_state, final_c_state)
        """
        # Preprocess input
        x_i = self._preprocess_branch_input(xs[branch_idx], branch_idx)

        # Get branch modules
        lstm_i, head_i = self._get_branch_modules(branch_idx)

        # Process with or without slice transfer
        if (branch_idx < self.nf - 1) and self.slice_transfer:
            return self._process_branch_with_slice_transfer(
                x_i, branch_idx, xs, lstm_i, head_i, h_transfer, c_transfer, device
            )
        else:
            return self._run_lstm(x_i, lstm_i, head_i, h_transfer, c_transfer)

    def forward(
        self, *xs: torch.Tensor, return_all: Optional[bool] = None, **kwargs: Any
    ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
        """Forward pass of the Multi-Time-Scale LSTM (MTSLSTM).

        This method supports three types of input pipelines:

        1. New unified hourly input (recommended):
            - Pass a single hourly tensor (T, B, D).
            - Requires `feature_buckets` set in the constructor.
            - Internally calls `_build_from_hourly()` to build multi-scale inputs.

        2. Legacy path A (two-frequency auto build):
            - Pass a single high-frequency tensor (daily).
            - Requires `auto_build_lowfreq=True` and `nf=2`.
            - Automatically constructs the low-frequency branch.

        3. Legacy path B (manual multi-frequency input):
            - Pass a tuple of tensors: (x_f0, x_f1, ..., x_f{nf-1}).

        Args:
            *xs (torch.Tensor): Input tensors. Can be:
                - One hourly tensor of shape (T, B, D).
                - One daily tensor (legacy auto-build).
                - A tuple of nf tensors, each shaped (T_f, B, D_f).
            return_all (Optional[bool], default=None): Whether to return outputs
                from all frequency branches. If None, uses the class default.
            **kwargs: Additional unused keyword arguments.

        Returns:
            Dict[str, torch.Tensor] | torch.Tensor:
                - If `return_all=True`: Dictionary mapping branch names to outputs,
                  e.g. {"f0": y_low, "f1": y_mid, "f2": y_high}.
                - If `return_all=False`: Only returns the highest-frequency output
                  tensor of shape (T_high, B, output_size).

        Raises:
            AssertionError: If the number of provided inputs does not match `nf`,
                or if input feature dimensions do not match the expected
                configuration.
        """
        if return_all is None:
            return_all = self.return_all_default

        # Prepare and validate inputs
        xs = self._prepare_inputs(xs)
        self._validate_inputs(xs)

        # Initialize processing state
        device = xs[0].device
        batch_size = xs[0].shape[1]
        h_transfer, c_transfer = self._initialize_transfer_states(device, batch_size)

        outputs: Dict[str, torch.Tensor] = {}

        # Process each frequency branch
        for i in range(self.nf):
            y_i, h_i, c_i = self._process_single_branch(i, xs, h_transfer, c_transfer, device)
            outputs[f"f{i}"] = y_i

            # Update transfer states for next branch
            if i < self.nf - 1:
                next_batch_size = xs[i + 1].shape[1]
                h_transfer, c_transfer = self._update_transfer_states(
                    i, h_i, c_i, next_batch_size, device
                )

        return outputs if return_all else outputs[f"f{self.nf - 1}"]
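
To make the down-aggregation concrete, here is a small sketch (values are illustrative, not from the source) that mirrors the mean aggregation performed by _aggregate_lowfreq with truncate_incomplete=True, turning 50 hourly steps into 2 daily steps with factor 24:

    import torch

    T, B, D, factor = 50, 4, 3, 24
    x_hour = torch.rand(T, B, D)
    T_trim = (T // factor) * factor                       # drop the incomplete trailing group
    x_day = x_hour[:T_trim].view(T_trim // factor, factor, B, D).mean(dim=1)
    print(x_day.shape)                                    # torch.Size([2, 4, 3])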

__init__(self, input_sizes=None, hidden_sizes=128, output_size=1, shared_mtslstm=False, transfer='linear', dropout=0.0, return_all=False, add_freq_one_hot_if_shared=True, auto_build_lowfreq=False, build_factor=7, agg_reduce='mean', per_feature_aggs=None, truncate_incomplete=True, slice_transfer=True, slice_use_ceil=True, seq_lengths=None, frequency_factors=None, feature_buckets=None, per_feature_aggs_map=None, down_aggregate_all_to_each_branch=True, pretrained_day_path=None, pretrained_lstm_prefix=None, pretrained_head_prefix=None, pretrained_flag=False, linear1_size=None, linear2_size=None) special

Initializes an MTSLSTM model.

Parameters:

input_sizes (Union[int, List[int]], default None)
    Input feature dimension(s): an int shared across all frequency branches, a list of per-frequency input sizes, or None to infer from feature_buckets.
hidden_sizes (Union[int, List[int]], default 128)
    Hidden dimension(s) for each LSTM branch.
output_size (int, default 1)
    Output dimension per timestep.
shared_mtslstm (bool, default False)
    If True, all frequency branches share one LSTM.
transfer (Union[None, str, Dict[str, Optional[Literal['identity', 'linear']]]], default 'linear')
    Hidden state transfer mode between frequencies: None for no transfer, "identity" to copy states directly (same dim required), or "linear" to learn a linear projection between dims.
dropout (float, default 0.0)
    Dropout probability applied before heads.
return_all (bool, default False)
    If True, return all branch outputs (dict f0, f1, ...). If False, return only the highest-frequency output.
add_freq_one_hot_if_shared (bool, default True)
    If True and shared_mtslstm=True, append a frequency one-hot encoding to inputs.
auto_build_lowfreq (bool, default False)
    Legacy 2-frequency path (high->low).
build_factor (int, default 7)
    Aggregation factor for the auto-built low frequency.
agg_reduce (Literal['mean', 'sum'], default 'mean')
    Aggregation method for downsampling ("mean" or "sum").
per_feature_aggs (Optional[List[Literal['mean', 'sum']]], default None)
    Optional list of per-feature aggregation methods.
truncate_incomplete (bool, default True)
    Whether to drop remainder timesteps when aggregating (vs. zero-padding).
slice_transfer (bool, default True)
    If True, transfer LSTM states at slice boundaries computed from seq_lengths × frequency_factors.
slice_use_ceil (bool, default True)
    If True, use ceil for the slice length calculation.
seq_lengths (Optional[List[int]], default None)
    Per-frequency sequence lengths [low, ..., high].
frequency_factors (Optional[List[int]], default None)
    Multipliers between adjacent frequencies. Example: [7, 24] means week->day ×7, day->hour ×24.
feature_buckets (Optional[List[int]], default None)
    Per-feature frequency assignment (len = D). 0 = lowest (week), nf-1 = highest (hour).
per_feature_aggs_map (Optional[List[Literal['mean', 'sum']]], default None)
    Per-feature aggregation method ("mean"/"sum").
down_aggregate_all_to_each_branch (bool, default True)
    If True, branch f includes all features with bucket >= f (down-aggregate); else only == f.
pretrained_day_path (Optional[str], default None)
    Optional path to a pretrained checkpoint. If set, loads weights for the daily (f1) LSTM and head.

Exceptions:

AssertionError
    If the configuration is inconsistent.
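
As a constructor sketch for the legacy two-frequency auto-build path (all values below are illustrative assumptions, not defaults from the source): a single daily input is passed and the weekly branch is aggregated automatically with build_factor=7.

    import torch

    model = MTSLSTM(
        input_sizes=10,             # same feature count for both branches
        hidden_sizes=64,
        output_size=1,
        auto_build_lowfreq=True,    # nf = 2; the low branch is built from the high one
        build_factor=7,             # 7 days -> 1 week
        seq_lengths=[52, 365],      # [low, high]
        frequency_factors=[7],
    )
    x_day = torch.rand(365, 8, 10)  # (T_high, batch, features)
    y = model(x_day)                # (365, 8, 1): highest-frequency output only
    outs = model(x_day, return_all=True)
    print(sorted(outs))             # ['f0', 'f1']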

Source code in torchhydro/models/mtslstm.py
def __init__(
    self,
    input_sizes: Union[int, List[int], None] = None,
    hidden_sizes: Union[int, List[int]] = 128,
    output_size: int = 1,
    shared_mtslstm: bool = False,
    transfer: Union[
        None, str, Dict[str, Optional[Literal["identity", "linear"]]]
    ] = "linear",
    dropout: float = 0.0,
    return_all: bool = False,
    add_freq_one_hot_if_shared: bool = True,
    auto_build_lowfreq: bool = False,
    build_factor: int = 7,
    agg_reduce: Literal["mean", "sum"] = "mean",
    per_feature_aggs: Optional[List[Literal["mean", "sum"]]] = None,
    truncate_incomplete: bool = True,
    slice_transfer: bool = True,
    slice_use_ceil: bool = True,
    seq_lengths: Optional[List[int]] = None,
    frequency_factors: Optional[List[int]] = None,
    feature_buckets: Optional[List[int]] = None,
    per_feature_aggs_map: Optional[List[Literal["mean", "sum"]]] = None,
    down_aggregate_all_to_each_branch: bool = True,
    pretrained_day_path: Optional[str] = None,
    pretrained_lstm_prefix: Optional[str] = None,
    pretrained_head_prefix: Optional[str] = None,
    pretrained_flag: bool = False,
    linear1_size: Optional[int] = None,
    linear2_size: Optional[int] = None
):
    """Initializes an MTSLSTM model.

    Args:
        input_sizes: Input feature dimension(s). Can be:
            * int: shared across all frequency branches
            * list: per-frequency input sizes
            * None: inferred from `feature_buckets`
        hidden_sizes: Hidden dimension(s) for each LSTM branch.
        output_size: Output dimension per timestep.
        shared_mtslstm: If True, all frequency branches share one LSTM.
        transfer: Hidden state transfer mode between frequencies.
            * None: no transfer
            * "identity": copy states directly (same dim required)
            * "linear": learn linear projection between dims
        dropout: Dropout probability applied before heads.
        return_all: If True, return all branch outputs (dict f0,f1,...).
            If False, return only the highest-frequency output.
        add_freq_one_hot_if_shared: If True and `shared_mtslstm=True`,
            append frequency one-hot encoding to inputs.
        auto_build_lowfreq: Legacy 2-frequency path (high->low).
        build_factor: Aggregation factor for auto low-frequency.
        agg_reduce: Aggregation method for downsampling ("mean" or "sum").
        per_feature_aggs: Optional list of per-feature aggregation methods.
        truncate_incomplete: Whether to drop remainder timesteps when
            aggregating (vs. zero-padding).
        slice_transfer: If True, transfer LSTM states at slice boundaries
            computed by seq_lengths × frequency_factors.
        slice_use_ceil: If True, use ceil for slice length calculation.
        seq_lengths: Per-frequency sequence lengths [low, ..., high].
        frequency_factors: Multipliers between adjacent frequencies.
            Example: [7,24] means week->day ×7, day->hour ×24.
        feature_buckets: Per-feature frequency assignment (len = D).
            0 = lowest (week), nf-1 = highest (hour).
        per_feature_aggs_map: Per-feature aggregation method ("mean"/"sum").
        down_aggregate_all_to_each_branch: If True, branch f includes all
            features with bucket >= f (down-aggregate); else only == f.
        pretrained_day_path: Optional path to pretrained checkpoint. If set,
            loads weights for the daily (f1) LSTM and head.

    Raises:
        AssertionError: If configuration is inconsistent.
    """
    super().__init__()

    # Store configuration parameters
    self.pretrained_day_path = pretrained_day_path
    self.pretrained_flag = pretrained_flag
    self.linear1_size = linear1_size
    self.linear2_size = linear2_size
    self.output_size = output_size
    self.shared = shared_mtslstm
    self.return_all_default = return_all
    self.feature_buckets = list(feature_buckets) if feature_buckets is not None else None
    self.per_feature_aggs_map = list(per_feature_aggs_map) if per_feature_aggs_map is not None else None
    self.down_agg_all = down_aggregate_all_to_each_branch
    self.auto_build_lowfreq = auto_build_lowfreq

    # Aggregation and slicing parameters
    assert build_factor >= 2, "build_factor must be >=2"
    self.build_factor = int(build_factor)
    self.agg_reduce = agg_reduce
    self.per_feature_aggs = per_feature_aggs
    self.truncate_incomplete = truncate_incomplete
    self.slice_transfer = slice_transfer
    self.slice_use_ceil = slice_use_ceil
    self.seq_lengths = list(seq_lengths) if seq_lengths is not None else None
    self._warned_slice_fallback = False

    # Setup frequency configuration
    self._setup_frequency_config(feature_buckets, input_sizes, auto_build_lowfreq)

    # Validate seq_lengths
    if self.seq_lengths is not None:
        assert len(self.seq_lengths) == self.nf, "seq_lengths length must match nf"

    # Setup input sizes for each branch
    self.base_input_sizes = self._setup_input_sizes(
        self.feature_buckets, input_sizes, self.down_agg_all
    )

    # Setup hidden layer sizes
    if isinstance(hidden_sizes, int):
        self.hidden_sizes = [hidden_sizes] * self.nf
    else:
        assert len(hidden_sizes) == self.nf, "hidden_sizes length mismatch"
        self.hidden_sizes = list(hidden_sizes)

    # Setup frequency one-hot encoding
    self.add_freq1hot = add_freq_one_hot_if_shared and self.shared

    # Setup transfer configuration
    self._setup_transfer_config(transfer)

    # Setup frequency factors and slice timesteps
    self._setup_frequency_factors(frequency_factors, auto_build_lowfreq, self.build_factor)

    # Calculate effective input sizes (including one-hot if needed)
    eff_input_sizes = self.base_input_sizes[:]
    if self.add_freq1hot:
        eff_input_sizes = [d + self.nf for d in eff_input_sizes]

    if self.shared and len(set(eff_input_sizes)) != 1:
        raise ValueError("shared_mtslstm=True requires equal input sizes.")

    # Create model layers
    self._create_model_layers(eff_input_sizes)

    # Create transfer layers
    self._create_transfer_layers()

    # Setup dropout
    self.dropout = nn.Dropout(p=dropout)

    # Set unified hourly aggregation flag
    self.use_hourly_unified = self.feature_buckets is not None

    # Load pretrained weights if specified
    self._load_pretrained_weights(pretrained_lstm_prefix, pretrained_head_prefix)

forward(self, *xs, return_all=None, **kwargs)

Forward pass of the Multi-Time-Scale LSTM (MTSLSTM).

This method supports three types of input pipelines:

1. New unified hourly input (recommended):
   - Pass a single hourly tensor (T, B, D).
   - Requires feature_buckets set in the constructor.
   - Internally calls _build_from_hourly() to build multi-scale inputs.
2. Legacy path A (two-frequency auto build):
   - Pass a single high-frequency tensor (daily).
   - Requires auto_build_lowfreq=True and nf=2.
   - Automatically constructs the low-frequency branch.
3. Legacy path B (manual multi-frequency input):
   - Pass a tuple of tensors: (x_f0, x_f1, ..., x_f{nf-1}).

Parameters:

*xs (torch.Tensor)
    Input tensors. Can be one hourly tensor of shape (T, B, D), one daily tensor (legacy auto-build), or a tuple of nf tensors, each shaped (T_f, B, D_f).
return_all (Optional[bool], default None)
    Whether to return outputs from all frequency branches. If None, uses the class default.
**kwargs (Any)
    Additional unused keyword arguments.

Returns:

Dict[str, torch.Tensor] | torch.Tensor
    If return_all=True: a dictionary mapping branch names to outputs, e.g. {"f0": y_low, "f1": y_mid, "f2": y_high}.
    If return_all=False: only the highest-frequency output tensor of shape (T_high, B, output_size).

Exceptions:

AssertionError
    If the number of provided inputs does not match nf, or if input feature dimensions do not match the expected configuration.

Source code in torchhydro/models/mtslstm.py
def forward(
    self, *xs: torch.Tensor, return_all: Optional[bool] = None, **kwargs: Any
) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
    """Forward pass of the Multi-Time-Scale LSTM (MTSLSTM).

    This method supports three types of input pipelines:

    1. New unified hourly input (recommended):
        - Pass a single hourly tensor (T, B, D).
        - Requires `feature_buckets` set in the constructor.
        - Internally calls `_build_from_hourly()` to build multi-scale inputs.

    2. Legacy path A (two-frequency auto build):
        - Pass a single high-frequency tensor (daily).
        - Requires `auto_build_lowfreq=True` and `nf=2`.
        - Automatically constructs the low-frequency branch.

    3. Legacy path B (manual multi-frequency input):
        - Pass a tuple of tensors: (x_f0, x_f1, ..., x_f{nf-1}).

    Args:
        *xs (torch.Tensor): Input tensors. Can be:
            - One hourly tensor of shape (T, B, D).
            - One daily tensor (legacy auto-build).
            - A tuple of nf tensors, each shaped (T_f, B, D_f).
        return_all (Optional[bool], default=None): Whether to return outputs
            from all frequency branches. If None, uses the class default.
        **kwargs: Additional unused keyword arguments.

    Returns:
        Dict[str, torch.Tensor] | torch.Tensor:
            - If `return_all=True`: Dictionary mapping branch names to outputs,
              e.g. {"f0": y_low, "f1": y_mid, "f2": y_high}.
            - If `return_all=False`: Only returns the highest-frequency output
              tensor of shape (T_high, B, output_size).

    Raises:
        AssertionError: If the number of provided inputs does not match `nf`,
            or if input feature dimensions do not match the expected
            configuration.
    """
    if return_all is None:
        return_all = self.return_all_default

    # Prepare and validate inputs
    xs = self._prepare_inputs(xs)
    self._validate_inputs(xs)

    # Initialize processing state
    device = xs[0].device
    batch_size = xs[0].shape[1]
    h_transfer, c_transfer = self._initialize_transfer_states(device, batch_size)

    outputs: Dict[str, torch.Tensor] = {}

    # Process each frequency branch
    for i in range(self.nf):
        y_i, h_i, c_i = self._process_single_branch(i, xs, h_transfer, c_transfer, device)
        outputs[f"f{i}"] = y_i

        # Update transfer states for next branch
        if i < self.nf - 1:
            next_batch_size = xs[i + 1].shape[1]
            h_transfer, c_transfer = self._update_transfer_states(
                i, h_i, c_i, next_batch_size, device
            )

    return outputs if return_all else outputs[f"f{self.nf - 1}"]

seq2seq

Author: Wenyu Ouyang Date: 2024-04-17 12:32:26 LastEditTime: 2024-11-05 19:10:02 LastEditors: Wenyu Ouyang Description: FilePath: \torchhydro\torchhydro\models\seq2seq.py Copyright (c) 2021-2024 Wenyu Ouyang. All rights reserved.

AdditiveAttention (Module)

Source code in torchhydro/models/seq2seq.py
class AdditiveAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(AdditiveAttention, self).__init__()
        self.W_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, encoder_outputs, hidden):
        seq_len = encoder_outputs.shape[1]
        hidden_transformed = self.W_q(hidden).repeat(seq_len, 1, 1).transpose(0, 1)
        encoder_outputs_transformed = self.W_k(encoder_outputs)
        combined = torch.tanh(hidden_transformed + encoder_outputs_transformed)
        scores = self.v(combined).squeeze(2)
        return F.softmax(scores, dim=1)
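
A shape-check sketch for AdditiveAttention; the (batch, seq, hidden) / (batch, hidden) input shapes below are inferred from the forward code, not documented by the source:

    import torch

    B, T, H = 2, 6, 16
    attn = AdditiveAttention(hidden_dim=H)
    encoder_outputs = torch.rand(B, T, H)     # encoder states over the sequence
    hidden = torch.rand(B, H)                 # current decoder hidden state
    weights = attn(encoder_outputs, hidden)   # (B, T); softmax over the sequence dimension
    print(weights.shape, weights.sum(dim=1))  # each row sums to 1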

forward(self, encoder_outputs, hidden)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, encoder_outputs, hidden):
    seq_len = encoder_outputs.shape[1]
    hidden_transformed = self.W_q(hidden).repeat(seq_len, 1, 1).transpose(0, 1)
    encoder_outputs_transformed = self.W_k(encoder_outputs)
    combined = torch.tanh(hidden_transformed + encoder_outputs_transformed)
    scores = self.v(combined).squeeze(2)
    return F.softmax(scores, dim=1)

Attention (Module)

Source code in torchhydro/models/seq2seq.py
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_dim * 2, 1, bias=False)

    def forward(self, encoder_outputs, hidden):
        seq_len = encoder_outputs.shape[1]
        hidden = hidden.repeat(seq_len, 1, 1).transpose(0, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        return F.softmax(energy.squeeze(2), dim=1)

forward(self, encoder_outputs, hidden)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, encoder_outputs, hidden):
    seq_len = encoder_outputs.shape[1]
    hidden = hidden.repeat(seq_len, 1, 1).transpose(0, 1)
    energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
    return F.softmax(energy.squeeze(2), dim=1)

DataEnhancedModel (GeneralSeq2Seq)

Source code in torchhydro/models/seq2seq.py
class DataEnhancedModel(GeneralSeq2Seq):
    def __init__(self, hidden_length, **kwargs):
        super(DataEnhancedModel, self).__init__(**kwargs)
        self.lstm = nn.LSTM(1, hidden_length, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_length, 6)

    def forward(self, *src):
        src1, src2, token = src
        processed_src1 = torch.unsqueeze(src1[:, :, 0], dim=2)
        out_src1, _ = self.lstm(processed_src1)
        out_src1 = self.fc(out_src1)
        combined_input = torch.cat((out_src1, src1[:, :, 1:]), dim=2)
        return super(DataEnhancedModel, self).forward(combined_input, src2, token)

forward(self, *src)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, *src):
    src1, src2, token = src
    processed_src1 = torch.unsqueeze(src1[:, :, 0], dim=2)
    out_src1, _ = self.lstm(processed_src1)
    out_src1 = self.fc(out_src1)
    combined_input = torch.cat((out_src1, src1[:, :, 1:]), dim=2)
    return super(DataEnhancedModel, self).forward(combined_input, src2, token)

DataFusionModel (DataEnhancedModel)

Source code in torchhydro/models/seq2seq.py
class DataFusionModel(DataEnhancedModel):
    def __init__(self, input_dim, **kwargs):
        super(DataFusionModel, self).__init__(**kwargs)
        self.input_dim = input_dim

        self.fusion_layer = nn.Conv1d(
            in_channels=input_dim, out_channels=1, kernel_size=1
        )

    def forward(self, *src):
        src1, src2, token = src
        if self.input_dim == 3:
            processed_src1 = self.fusion_layer(
                src1[:, :, 0:3].permute(0, 2, 1)
            ).permute(0, 2, 1)
            combined_input = torch.cat((processed_src1, src1[:, :, 3:]), dim=2)
        else:
            processed_src1 = self.fusion_layer(
                src1[:, :, 0:2].permute(0, 2, 1)
            ).permute(0, 2, 1)
            combined_input = torch.cat((processed_src1, src1[:, :, 2:]), dim=2)

        return super(DataFusionModel, self).forward(combined_input, src2, token)

forward(self, *src)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, *src):
    src1, src2, token = src
    if self.input_dim == 3:
        processed_src1 = self.fusion_layer(
            src1[:, :, 0:3].permute(0, 2, 1)
        ).permute(0, 2, 1)
        combined_input = torch.cat((processed_src1, src1[:, :, 3:]), dim=2)
    else:
        processed_src1 = self.fusion_layer(
            src1[:, :, 0:2].permute(0, 2, 1)
        ).permute(0, 2, 1)
        combined_input = torch.cat((processed_src1, src1[:, :, 2:]), dim=2)

    return super(DataFusionModel, self).forward(combined_input, src2, token)

Decoder (Module)

Source code in torchhydro/models/seq2seq.py
class Decoder(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, num_layers=1, dropout=0.3):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.pre_fc = nn.Linear(input_dim, hidden_dim)
        self.pre_relu = nn.ReLU()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, hidden, cell):
        x0 = self.pre_fc(input)
        x1 = self.pre_relu(x0)
        output_, (hidden_, cell_) = self.lstm(x1, (hidden, cell))
        output_dr = self.dropout(output_)
        output = self.fc_out(output_dr)
        return output, hidden_, cell_

forward(self, input, hidden, cell)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, input, hidden, cell):
    x0 = self.pre_fc(input)
    x1 = self.pre_relu(x0)
    output_, (hidden_, cell_) = self.lstm(x1, (hidden, cell))
    output_dr = self.dropout(output_)
    output = self.fc_out(output_dr)
    return output, hidden_, cell_

DotProductAttention (Module)

Source code in torchhydro/models/seq2seq.py
class DotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_outputs, hidden):
        hidden_dim = encoder_outputs.shape[2]
        hidden_expanded = hidden.unsqueeze(1)
        scores = torch.bmm(
            hidden_expanded, encoder_outputs.transpose(1, 2)
        ) / math.sqrt(hidden_dim)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        return attention_weights.squeeze(1)

forward(self, encoder_outputs, hidden)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, encoder_outputs, hidden):
    hidden_dim = encoder_outputs.shape[2]
    hidden_expanded = hidden.unsqueeze(1)
    scores = torch.bmm(
        hidden_expanded, encoder_outputs.transpose(1, 2)
    ) / math.sqrt(hidden_dim)
    attention_weights = F.softmax(scores, dim=-1)
    attention_weights = self.dropout(attention_weights)
    return attention_weights.squeeze(1)
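
A minimal shape sketch (not from the source) for this scaled dot-product variant; sizes and the import path are illustrative assumptions. Unlike AdditiveAttention, there are no learned projections here, so the encoder outputs and the hidden state must already share the same feature dimension.

import torch
from torchhydro.models.seq2seq import DotProductAttention

attn = DotProductAttention(dropout=0.0)
encoder_outputs = torch.randn(8, 30, 64)   # (batch, seq_len, hidden_dim)
hidden = torch.randn(8, 64)                # (batch, hidden_dim)
weights = attn(encoder_outputs, hidden)    # (batch, seq_len); scores scaled by sqrt(hidden_dim)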

Encoder (Module)

Source code in torchhydro/models/seq2seq.py
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.3):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.pre_fc = nn.Linear(input_dim, hidden_dim)
        self.pre_relu = nn.ReLU()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # a nonlinear layer to transform the input
        x0 = self.pre_fc(x)
        x1 = self.pre_relu(x0)
        # the LSTM layer
        outputs_, (hidden, cell) = self.lstm(x1)
        # a dropout layer
        dr_outputs = self.dropout(outputs_)
        # final linear layer
        outputs = self.fc(dr_outputs)
        return outputs, hidden, cell

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, x):
    # a nonlinear layer to transform the input
    x0 = self.pre_fc(x)
    x1 = self.pre_relu(x0)
    # the LSTM layer
    outputs_, (hidden, cell) = self.lstm(x1)
    # a dropout layer
    dr_outputs = self.dropout(outputs_)
    # final linear layer
    outputs = self.fc(dr_outputs)
    return outputs, hidden, cell
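
A minimal sketch (not from the source) of wiring the Encoder to the Decoder for one decoding step, as GeneralSeq2Seq below does in a loop; tensors are sequence-first because the underlying nn.LSTM uses the default batch_first=False, and all sizes are illustrative assumptions.

import torch
from torchhydro.models.seq2seq import Encoder, Decoder

enc = Encoder(input_dim=16, hidden_dim=64, output_dim=2)
dec = Decoder(input_dim=18, output_dim=2, hidden_dim=64)

x = torch.randn(24, 4, 16)        # (seq_len, batch, input_dim), sequence-first
enc_out, h, c = enc(x)            # (24, 4, 2), (1, 4, 64), (1, 4, 64)

step_in = torch.randn(1, 4, 18)   # one decoding step: covariates + previous output
y, h, c = dec(step_in, h, c)      # (1, 4, 2)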

GeneralSeq2Seq (Module)

Source code in torchhydro/models/seq2seq.py
class GeneralSeq2Seq(nn.Module):
    def __init__(
        self,
        en_input_size,
        de_input_size,
        output_size,
        hidden_size,
        forecast_length,
        hindcast_output_window=0,
        teacher_forcing_ratio=0.5,
    ):
        """General Seq2Seq model

        Parameters
        ----------
        en_input_size : _type_
            the size of the input of the encoder
        de_input_size : _type_
            the size of the input of the decoder
        output_size : _type_
            the size of the output, same for encoder and decoder
        hidden_size : _type_
            the size of the hidden state of LSTMs
        forecast_length : _type_
            the length of the forecast, i.e., the periods of decoder outputs
        hindcast_output_window : int, optional
            the encoder's final several outputs in the final output;
            default is 0 which means no encoder output is included in the final output;
        teacher_forcing_ratio : float, optional
            the probability of using teacher forcing
        """
        super(GeneralSeq2Seq, self).__init__()
        self.trg_len = forecast_length
        self.hindcast_output_window = hindcast_output_window
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.output_size = output_size
        self.encoder = Encoder(
            input_dim=en_input_size, hidden_dim=hidden_size, output_dim=output_size
        )
        self.decoder = Decoder(
            input_dim=de_input_size, hidden_dim=hidden_size, output_dim=output_size
        )
        self.transfer = StateTransferNetwork(hidden_dim=hidden_size)

    def _teacher_forcing_preparation(self, trgs):
        # teacher forcing preparation
        valid_mask = ~torch.isnan(trgs)
        random_vals = torch.rand_like(valid_mask, dtype=torch.float)
        return (random_vals < self.teacher_forcing_ratio) * valid_mask

    def forward(self, *src):
        if len(src) == 3:
            encoder_input, decoder_input, trgs = src
        else:
            encoder_input, decoder_input = src
            device = decoder_input.device
            trgs = torch.full(
                (
                    self.hindcast_output_window + self.trg_len,  # seq
                    decoder_input.shape[1],  # batch_size
                    self.output_size,  # features
                ),
                float("nan"),
            ).to(device)
        trgs_q = trgs[:, :, :1]
        trgs_s = trgs[:, :, 1:]
        trgs = torch.cat((trgs_s, trgs_q), dim=2)  # sq
        encoder_outputs, hidden_, cell_ = self.encoder(encoder_input)  # sq
        hidden, cell = self.transfer(hidden_, cell_)
        outputs = []
        prev_output = encoder_outputs[-1, :, :].unsqueeze(0)  # sq
        _, batch_size, _ = decoder_input.size()

        outputs = torch.zeros(self.trg_len, batch_size, self.output_size).to(
            decoder_input.device
        )
        use_teacher_forcing = self._teacher_forcing_preparation(trgs)
        for t in range(self.trg_len):
            pc = decoder_input[t : t + 1, :, :]  # sq
            obs = trgs[self.hindcast_output_window + t, :, :].unsqueeze(0)  # sq
            safe_obs = torch.where(torch.isnan(obs), torch.zeros_like(obs), obs)
            prev_output = torch.where(  # sq
                use_teacher_forcing[t : t + 1, :, :],
                safe_obs,
                prev_output,
            )
            current_input = torch.cat((pc, prev_output), dim=2)  # pcsq
            output, hidden, cell = self.decoder(current_input, hidden, cell)
            outputs[t, :, :] = output.squeeze(0)  # sq
        if self.hindcast_output_window > 0:
            prec_outputs = encoder_outputs[-self.hindcast_output_window :, :, :]
            outputs = torch.cat((prec_outputs, outputs), dim=0)
        outputs_s = outputs[:, :, :1]
        outputs_q = outputs[:, :, 1:]
        outputs = torch.cat((outputs_q, outputs_s), dim=2)  # qs
        return outputs

__init__(self, en_input_size, de_input_size, output_size, hidden_size, forecast_length, hindcast_output_window=0, teacher_forcing_ratio=0.5) special

General Seq2Seq model

Parameters

en_input_size : the size of the input of the encoder
de_input_size : the size of the input of the decoder
output_size : the size of the output, the same for the encoder and the decoder
hidden_size : the size of the hidden state of the LSTMs
forecast_length : the length of the forecast, i.e., the number of decoder output periods
hindcast_output_window : int, optional; how many of the encoder's final outputs are included in the final output (default 0, meaning no encoder output is included)
teacher_forcing_ratio : float, optional; the probability of using teacher forcing (default 0.5)
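
A minimal, illustrative instantiation (sizes are assumptions, not from the source). As the forward code below suggests, each decoder step receives the decoder covariates concatenated with the previous output, so de_input_size should equal the decoder covariate count plus output_size; tensors are sequence-first.

import torch
from torchhydro.models.seq2seq import GeneralSeq2Seq

model = GeneralSeq2Seq(
    en_input_size=16,   # encoder covariates
    de_input_size=18,   # 16 decoder covariates + output_size previous outputs
    output_size=2,
    hidden_size=64,
    forecast_length=8,
)
encoder_input = torch.randn(24, 4, 16)     # (hindcast_len, batch, en_input_size)
decoder_input = torch.randn(8, 4, 16)      # (forecast_length, batch, de_input_size - output_size)
out = model(encoder_input, decoder_input)  # no targets passed, so teacher forcing is skipped
print(out.shape)                           # torch.Size([8, 4, 2])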

Source code in torchhydro/models/seq2seq.py
def __init__(
    self,
    en_input_size,
    de_input_size,
    output_size,
    hidden_size,
    forecast_length,
    hindcast_output_window=0,
    teacher_forcing_ratio=0.5,
):
    """General Seq2Seq model

    Parameters
    ----------
    en_input_size : _type_
        the size of the input of the encoder
    de_input_size : _type_
        the size of the input of the decoder
    output_size : _type_
        the size of the output, same for encoder and decoder
    hidden_size : _type_
        the size of the hidden state of LSTMs
    forecast_length : _type_
        the length of the forecast, i.e., the periods of decoder outputs
    hindcast_output_window : int, optional
        the encoder's final several outputs in the final output;
        default is 0 which means no encoder output is included in the final output;
    teacher_forcing_ratio : float, optional
        the probability of using teacher forcing
    """
    super(GeneralSeq2Seq, self).__init__()
    self.trg_len = forecast_length
    self.hindcast_output_window = hindcast_output_window
    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.output_size = output_size
    self.encoder = Encoder(
        input_dim=en_input_size, hidden_dim=hidden_size, output_dim=output_size
    )
    self.decoder = Decoder(
        input_dim=de_input_size, hidden_dim=hidden_size, output_dim=output_size
    )
    self.transfer = StateTransferNetwork(hidden_dim=hidden_size)

forward(self, *src)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, *src):
    if len(src) == 3:
        encoder_input, decoder_input, trgs = src
    else:
        encoder_input, decoder_input = src
        device = decoder_input.device
        trgs = torch.full(
            (
                self.hindcast_output_window + self.trg_len,  # seq
                decoder_input.shape[1],  # batch_size
                self.output_size,  # features
            ),
            float("nan"),
        ).to(device)
    trgs_q = trgs[:, :, :1]
    trgs_s = trgs[:, :, 1:]
    trgs = torch.cat((trgs_s, trgs_q), dim=2)  # sq
    encoder_outputs, hidden_, cell_ = self.encoder(encoder_input)  # sq
    hidden, cell = self.transfer(hidden_, cell_)
    outputs = []
    prev_output = encoder_outputs[-1, :, :].unsqueeze(0)  # sq
    _, batch_size, _ = decoder_input.size()

    outputs = torch.zeros(self.trg_len, batch_size, self.output_size).to(
        decoder_input.device
    )
    use_teacher_forcing = self._teacher_forcing_preparation(trgs)
    for t in range(self.trg_len):
        pc = decoder_input[t : t + 1, :, :]  # sq
        obs = trgs[self.hindcast_output_window + t, :, :].unsqueeze(0)  # sq
        safe_obs = torch.where(torch.isnan(obs), torch.zeros_like(obs), obs)
        prev_output = torch.where(  # sq
            use_teacher_forcing[t : t + 1, :, :],
            safe_obs,
            prev_output,
        )
        current_input = torch.cat((pc, prev_output), dim=2)  # pcsq
        output, hidden, cell = self.decoder(current_input, hidden, cell)
        outputs[t, :, :] = output.squeeze(0)  # sq
    if self.hindcast_output_window > 0:
        prec_outputs = encoder_outputs[-self.hindcast_output_window :, :, :]
        outputs = torch.cat((prec_outputs, outputs), dim=0)
    outputs_s = outputs[:, :, :1]
    outputs_q = outputs[:, :, 1:]
    outputs = torch.cat((outputs_q, outputs_s), dim=2)  # qs
    return outputs

StateTransferNetwork (Module)

Source code in torchhydro/models/seq2seq.py
class StateTransferNetwork(nn.Module):
    def __init__(self, hidden_dim):
        super(StateTransferNetwork, self).__init__()
        self.fc_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.fc_cell = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden, cell):
        transfer_hidden = torch.tanh(self.fc_hidden(hidden))
        transfer_cell = self.fc_cell(cell)
        return transfer_hidden, transfer_cell

forward(self, hidden, cell)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, hidden, cell):
    transfer_hidden = torch.tanh(self.fc_hidden(hidden))
    transfer_cell = self.fc_cell(cell)
    return transfer_hidden, transfer_cell

Transformer (Module)

Source code in torchhydro/models/seq2seq.py
class Transformer(nn.Module):
    def __init__(
        self,
        n_encoder_inputs,
        n_decoder_inputs,
        n_decoder_output,
        channels=256,
        num_embeddings=512,
        nhead=8,
        num_layers=8,
        dropout=0.1,
        hindcast_output_window=0,
    ):
        """TODO: hindcast_output_window seems not used

        Parameters
        ----------
        n_encoder_inputs : _type_
            _description_
        n_decoder_inputs : _type_
            _description_
        n_decoder_output : _type_
            _description_
        channels : int, optional
            _description_, by default 256
        num_embeddings : int, optional
            _description_, by default 512
        nhead : int, optional
            _description_, by default 8
        num_layers : int, optional
            _description_, by default 8
        dropout : float, optional
            _description_, by default 0.1
        hindcast_output_window : int, optional
            _description_, by default 0
        """
        super().__init__()

        self.input_pos_embedding = torch.nn.Embedding(num_embeddings, channels)
        self.target_pos_embedding = torch.nn.Embedding(num_embeddings, channels)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=channels,
            nhead=nhead,
            dropout=dropout,
            dim_feedforward=4 * channels,
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=channels,
            nhead=nhead,
            dropout=dropout,
            dim_feedforward=4 * channels,
        )

        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers)

        self.input_projection = nn.Linear(n_encoder_inputs, channels)
        self.output_projection = nn.Linear(n_decoder_inputs, channels)

        self.linear = nn.Linear(channels, n_decoder_output)

        self.do = nn.Dropout(p=dropout)

    def encode_src(self, src):
        src_start = self.input_projection(src)

        in_sequence_len, batch_size = src_start.size(0), src_start.size(1)
        pos_encoder = (
            torch.arange(0, in_sequence_len, device=src.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        pos_encoder = self.input_pos_embedding(pos_encoder).permute(1, 0, 2)

        src = src_start + pos_encoder
        src = self.encoder(src) + src_start
        src = self.do(src)
        return src

    def decode_trg(self, trg, memory):
        trg_start = self.output_projection(trg)

        out_sequence_len, batch_size = trg_start.size(0), trg_start.size(1)
        pos_decoder = (
            torch.arange(0, out_sequence_len, device=trg.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        pos_decoder = self.target_pos_embedding(pos_decoder).permute(1, 0, 2)

        trg = pos_decoder + trg_start
        trg_mask = gen_trg_mask(out_sequence_len, trg.device)
        out = self.decoder(tgt=trg, memory=memory, tgt_mask=trg_mask) + trg_start
        out = self.do(out)
        out = self.linear(out)
        return out

    def forward(self, *x):
        src, trg = x
        src = self.encode_src(src)
        return self.decode_trg(trg=trg, memory=src)

__init__(self, n_encoder_inputs, n_decoder_inputs, n_decoder_output, channels=256, num_embeddings=512, nhead=8, num_layers=8, dropout=0.1, hindcast_output_window=0) special

TODO: hindcast_output_window seems not used

Parameters

n_encoder_inputs : the number of encoder input features
n_decoder_inputs : the number of decoder input features
n_decoder_output : the number of decoder output features
channels : int, optional; the model (embedding) dimension, by default 256
num_embeddings : int, optional; the size of the positional-embedding tables, which bounds the supported sequence length, by default 512
nhead : int, optional; the number of attention heads, by default 8
num_layers : int, optional; the number of encoder and decoder layers, by default 8
dropout : float, optional; by default 0.1
hindcast_output_window : int, optional; by default 0 (see the TODO above: it appears unused)
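
A minimal usage sketch (not from the source); sizes are illustrative assumptions. channels must be divisible by nhead, and neither sequence length may exceed num_embeddings, since the positional embeddings are looked up by position index.

import torch
from torchhydro.models.seq2seq import Transformer

model = Transformer(
    n_encoder_inputs=16,
    n_decoder_inputs=8,
    n_decoder_output=1,
    channels=64,
    nhead=4,
    num_layers=2,
)
src = torch.randn(240, 4, 16)   # (seq_len, batch, encoder features), sequence-first
trg = torch.randn(24, 4, 8)     # (forecast_len, batch, decoder features)
out = model(src, trg)           # (24, 4, 1)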

Source code in torchhydro/models/seq2seq.py
def __init__(
    self,
    n_encoder_inputs,
    n_decoder_inputs,
    n_decoder_output,
    channels=256,
    num_embeddings=512,
    nhead=8,
    num_layers=8,
    dropout=0.1,
    hindcast_output_window=0,
):
    """TODO: hindcast_output_window seems not used

    Parameters
    ----------
    n_encoder_inputs : _type_
        _description_
    n_decoder_inputs : _type_
        _description_
    n_decoder_output : _type_
        _description_
    channels : int, optional
        _description_, by default 256
    num_embeddings : int, optional
        _description_, by default 512
    nhead : int, optional
        _description_, by default 8
    num_layers : int, optional
        _description_, by default 8
    dropout : float, optional
        _description_, by default 0.1
    hindcast_output_window : int, optional
        _description_, by default 0
    """
    super().__init__()

    self.input_pos_embedding = torch.nn.Embedding(num_embeddings, channels)
    self.target_pos_embedding = torch.nn.Embedding(num_embeddings, channels)

    encoder_layer = nn.TransformerEncoderLayer(
        d_model=channels,
        nhead=nhead,
        dropout=dropout,
        dim_feedforward=4 * channels,
    )
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=channels,
        nhead=nhead,
        dropout=dropout,
        dim_feedforward=4 * channels,
    )

    self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers)
    self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers)

    self.input_projection = nn.Linear(n_encoder_inputs, channels)
    self.output_projection = nn.Linear(n_decoder_inputs, channels)

    self.linear = nn.Linear(channels, n_decoder_output)

    self.do = nn.Dropout(p=dropout)

forward(self, *x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seq2seq.py
def forward(self, *x):
    src, trg = x
    src = self.encode_src(src)
    return self.decode_trg(trg=trg, memory=src)

seqforecast

FeatureEmbedding (Module)

Source code in torchhydro/models/seqforecast.py
class FeatureEmbedding(nn.Module):
    def __init__(
        self, input_dim, embedding_dim, hidden_size=0, dropout=0.0, activation="relu"
    ):
        super(FeatureEmbedding, self).__init__()
        self.embedding = Mlp(
            input_dim,
            embedding_dim,
            hidden_size=hidden_size,
            dr=dropout,
            activation=activation,
        )

    def forward(self, static_features):
        return self.embedding(static_features)
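
A minimal sketch (not from the source) showing the embedding of a static attribute vector; the sizes are illustrative assumptions, and it assumes the underlying Mlp builds a single hidden layer when hidden_size is a positive int.

import torch
from torchhydro.models.seqforecast import FeatureEmbedding

embed = FeatureEmbedding(input_dim=10, embedding_dim=5, hidden_size=16, dropout=0.0)
static = torch.randn(4, 10)   # (batch, static attributes)
z = embed(static)             # (batch, embedding_dim) -> (4, 5)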

forward(self, static_features)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seqforecast.py
def forward(self, static_features):
    return self.embedding(static_features)

ForecastLSTM (Module)

Source code in torchhydro/models/seqforecast.py
class ForecastLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout=0):
        super(ForecastLSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, dropout=dropout)

    def forward(self, x, h, c):
        output, _ = self.lstm(x, (h, c))
        return output

forward(self, x, h, c)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seqforecast.py
def forward(self, x, h, c):
    output, _ = self.lstm(x, (h, c))
    return output

HiddenStateTransferNet (Module)

Source code in torchhydro/models/seqforecast.py
class HiddenStateTransferNet(nn.Module):
    def __init__(
        self, hindcast_hidden_dim, forecast_hidden_dim, dropout=0.0, activation="relu"
    ):
        super(HiddenStateTransferNet, self).__init__()
        self.linear_transfer = nn.Linear(hindcast_hidden_dim, forecast_hidden_dim)
        self.nonlinear_transfer = Mlp(
            hindcast_hidden_dim,
            forecast_hidden_dim,
            hidden_size=0,
            dr=dropout,
            activation=activation,
        )

    def forward(self, hidden, cell):
        transfer_hidden = self.nonlinear_transfer(hidden)
        transfer_cell = self.linear_transfer(cell)
        return transfer_hidden, transfer_cell

forward(self, hidden, cell)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seqforecast.py
def forward(self, hidden, cell):
    transfer_hidden = self.nonlinear_transfer(hidden)
    transfer_cell = self.linear_transfer(cell)
    return transfer_hidden, transfer_cell

HindcastLSTM (Module)

Source code in torchhydro/models/seqforecast.py
class HindcastLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout=0):
        super(HindcastLSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, dropout=dropout)

    def forward(self, x):
        output, (h, c) = self.lstm(x)
        return output, h, c

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seqforecast.py
def forward(self, x):
    output, (h, c) = self.lstm(x)
    return output, h, c

ModelOutputHead (Module)

Source code in torchhydro/models/seqforecast.py
class ModelOutputHead(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(ModelOutputHead, self).__init__()
        self.head = nn.Sequential(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        return self.head(x)

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seqforecast.py
def forward(self, x):
    return self.head(x)

SequentialForecastLSTM (Module)

Source code in torchhydro/models/seqforecast.py
class SequentialForecastLSTM(nn.Module):
    def __init__(
        self,
        static_input_dim,
        dynamic_input_dim,
        static_embedding_dim,
        sta_embed_hidden_dim,
        dynamic_embedding_dim,
        dyn_embed_hidden_dim,
        hindcast_hidden_dim,
        forecast_hidden_dim,
        output_dim,
        hindcast_output_window,
        embedding_dropout,
        handoff_dropout,
        lstm_dropout,
        activation="relu",
    ):
        """_summary_

        Parameters
        ----------
        static_input_dim : int
            _description_
        dynamic_input_dim : int
            _description_
        static_embedding_dim : int
            output size of static embedding
        sta_embed_hidden_dim: int
            hidden size of static embedding
        dynamic_embedding_dim : int
            output size of dynamic embedding
        dyn_embed_hidden_dim: int
            hidden size of dynamic embedding
        hidden_dim : _type_
            _description_
        output_dim : _type_
            _description_
        hindcast_output_window : int
            length of hindcast output to calculate loss
        """
        super(SequentialForecastLSTM, self).__init__()
        self.output_dim = output_dim
        self.hindcast_output_window = hindcast_output_window
        self.dynamic_embedding_dim = dynamic_embedding_dim
        if static_embedding_dim > 0:
            self.static_embedding = FeatureEmbedding(
                static_input_dim,
                static_embedding_dim,
                sta_embed_hidden_dim,
                embedding_dropout,
                activation,
            )
        if dynamic_embedding_dim > 0:
            self.dynamic_embedding = FeatureEmbedding(
                dynamic_input_dim,
                dynamic_embedding_dim,
                dyn_embed_hidden_dim,
                embedding_dropout,
                activation,
            )
        self.static_embedding_dim = static_embedding_dim
        self.dynamic_embedding_dim = dynamic_embedding_dim
        hindcast_input_dim = (
            dynamic_embedding_dim if dynamic_embedding_dim != 0 else dynamic_input_dim
        ) + (static_embedding_dim if static_embedding_dim != 0 else static_input_dim)
        forecast_input_dim = (
            dynamic_embedding_dim if dynamic_embedding_dim != 0 else dynamic_input_dim
        ) + (static_embedding_dim if static_embedding_dim != 0 else static_input_dim)
        self.hindcast_lstm = HindcastLSTM(
            hindcast_input_dim, hindcast_hidden_dim, lstm_dropout
        )
        self.forecast_lstm = ForecastLSTM(
            forecast_input_dim, forecast_hidden_dim, lstm_dropout
        )
        self.hiddenstatetransfer = HiddenStateTransferNet(
            hindcast_hidden_dim,
            forecast_hidden_dim,
            dropout=handoff_dropout,
            activation=activation,
        )
        self.hindcast_output_head = ModelOutputHead(hindcast_hidden_dim, output_dim)
        self.forecast_output_head = ModelOutputHead(forecast_hidden_dim, output_dim)

    def _perform_embedding(self, static_features, dynamic_features):
        if self.dynamic_embedding_dim > 0:
            dynamic_embedded = self.dynamic_embedding(dynamic_features)
        else:
            dynamic_embedded = dynamic_features
        if self.static_embedding_dim > 0:
            static_embedded = self.static_embedding(static_features)
        else:
            static_embedded = static_features
        static_embedded = static_embedded.unsqueeze(1).expand(
            -1, dynamic_embedded.size(1), -1
        )
        return torch.cat([dynamic_embedded, static_embedded], dim=-1)

    def forward(self, *src):
        (
            hindcast_features,
            forecast_features,
            static_features,
        ) = src

        # Hindcast LSTM
        hindcast_input = self._perform_embedding(static_features, hindcast_features)
        hincast_output, h, c = self.hindcast_lstm(hindcast_input)

        if self.hindcast_output_window > 0:
            hincast_output = self.hindcast_output_head(
                hincast_output[:, -self.hindcast_output_window :, :]
            )

        h, c = self.hiddenstatetransfer(h, c)

        # Forecast LSTM
        forecast_input = self._perform_embedding(static_features, forecast_features)
        forecast_output = self.forecast_lstm(forecast_input, h, c)
        forecast_output = self.forecast_output_head(forecast_output)
        if self.hindcast_output_window > 0:
            forecast_output = torch.cat([hincast_output, forecast_output], dim=1)
        return forecast_output

__init__(self, static_input_dim, dynamic_input_dim, static_embedding_dim, sta_embed_hidden_dim, dynamic_embedding_dim, dyn_embed_hidden_dim, hindcast_hidden_dim, forecast_hidden_dim, output_dim, hindcast_output_window, embedding_dropout, handoff_dropout, lstm_dropout, activation='relu') special

A hindcast-forecast LSTM that embeds static and dynamic features, runs a hindcast LSTM, hands its states to a forecast LSTM, and outputs the forecast (optionally preceded by the hindcast-window outputs).

Parameters

static_input_dim : int; the number of static input features
dynamic_input_dim : int; the number of dynamic input features
static_embedding_dim : int; the output size of the static embedding
sta_embed_hidden_dim : int; the hidden size of the static embedding
dynamic_embedding_dim : int; the output size of the dynamic embedding
dyn_embed_hidden_dim : int; the hidden size of the dynamic embedding
hindcast_hidden_dim, forecast_hidden_dim : the hidden sizes of the hindcast and forecast LSTMs
output_dim : the number of output features
hindcast_output_window : int; the length of the hindcast output used to calculate the loss
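
A minimal, illustrative instantiation (all sizes are assumptions, not from the source). The hindcast and forecast LSTMs are built with batch_first=True, so dynamic inputs are (batch, time, features) and static inputs are (batch, features).

import torch
from torchhydro.models.seqforecast import SequentialForecastLSTM

model = SequentialForecastLSTM(
    static_input_dim=10, dynamic_input_dim=7,
    static_embedding_dim=5, sta_embed_hidden_dim=16,
    dynamic_embedding_dim=6, dyn_embed_hidden_dim=16,
    hindcast_hidden_dim=64, forecast_hidden_dim=32,
    output_dim=1, hindcast_output_window=0,
    embedding_dropout=0.0, handoff_dropout=0.0, lstm_dropout=0.0,
)
hindcast = torch.randn(4, 240, 7)       # (batch, hindcast_len, dynamic features)
forecast = torch.randn(4, 24, 7)        # (batch, forecast_len, dynamic features)
static = torch.randn(4, 10)             # (batch, static features)
y = model(hindcast, forecast, static)   # (4, 24, 1); with hindcast_output_window > 0 the hindcast tail is prepended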

Source code in torchhydro/models/seqforecast.py
def __init__(
    self,
    static_input_dim,
    dynamic_input_dim,
    static_embedding_dim,
    sta_embed_hidden_dim,
    dynamic_embedding_dim,
    dyn_embed_hidden_dim,
    hindcast_hidden_dim,
    forecast_hidden_dim,
    output_dim,
    hindcast_output_window,
    embedding_dropout,
    handoff_dropout,
    lstm_dropout,
    activation="relu",
):
    """_summary_

    Parameters
    ----------
    static_input_dim : int
        _description_
    dynamic_input_dim : int
        _description_
    static_embedding_dim : int
        output size of static embedding
    sta_embed_hidden_dim: int
        hidden size of static embedding
    dynamic_embedding_dim : int
        output size of dynamic embedding
    dyn_embed_hidden_dim: int
        hidden size of dynamic embedding
    hidden_dim : _type_
        _description_
    output_dim : _type_
        _description_
    hindcast_output_window : int
        length of hindcast output to calculate loss
    """
    super(SequentialForecastLSTM, self).__init__()
    self.output_dim = output_dim
    self.hindcast_output_window = hindcast_output_window
    self.dynamic_embedding_dim = dynamic_embedding_dim
    if static_embedding_dim > 0:
        self.static_embedding = FeatureEmbedding(
            static_input_dim,
            static_embedding_dim,
            sta_embed_hidden_dim,
            embedding_dropout,
            activation,
        )
    if dynamic_embedding_dim > 0:
        self.dynamic_embedding = FeatureEmbedding(
            dynamic_input_dim,
            dynamic_embedding_dim,
            dyn_embed_hidden_dim,
            embedding_dropout,
            activation,
        )
    self.static_embedding_dim = static_embedding_dim
    self.dynamic_embedding_dim = dynamic_embedding_dim
    hindcast_input_dim = (
        dynamic_embedding_dim if dynamic_embedding_dim != 0 else dynamic_input_dim
    ) + (static_embedding_dim if static_embedding_dim != 0 else static_input_dim)
    forecast_input_dim = (
        dynamic_embedding_dim if dynamic_embedding_dim != 0 else dynamic_input_dim
    ) + (static_embedding_dim if static_embedding_dim != 0 else static_input_dim)
    self.hindcast_lstm = HindcastLSTM(
        hindcast_input_dim, hindcast_hidden_dim, lstm_dropout
    )
    self.forecast_lstm = ForecastLSTM(
        forecast_input_dim, forecast_hidden_dim, lstm_dropout
    )
    self.hiddenstatetransfer = HiddenStateTransferNet(
        hindcast_hidden_dim,
        forecast_hidden_dim,
        dropout=handoff_dropout,
        activation=activation,
    )
    self.hindcast_output_head = ModelOutputHead(hindcast_hidden_dim, output_dim)
    self.forecast_output_head = ModelOutputHead(forecast_hidden_dim, output_dim)

forward(self, *src)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/seqforecast.py
def forward(self, *src):
    (
        hindcast_features,
        forecast_features,
        static_features,
    ) = src

    # Hindcast LSTM
    hindcast_input = self._perform_embedding(static_features, hindcast_features)
    hincast_output, h, c = self.hindcast_lstm(hindcast_input)

    if self.hindcast_output_window > 0:
        hincast_output = self.hindcast_output_head(
            hincast_output[:, -self.hindcast_output_window :, :]
        )

    h, c = self.hiddenstatetransfer(h, c)

    # Forecast LSTM
    forecast_input = self._perform_embedding(static_features, forecast_features)
    forecast_output = self.forecast_lstm(forecast_input, h, c)
    forecast_output = self.forecast_output_head(forecast_output)
    if self.hindcast_output_window > 0:
        forecast_output = torch.cat([hincast_output, forecast_output], dim=1)
    return forecast_output

simple_lstm

Author: Wenyu Ouyang Date: 2023-09-19 09:36:25 LastEditTime: 2025-11-08 15:55:14 LastEditors: Wenyu Ouyang Description: Some self-made LSTMs FilePath: \torchhydro\torchhydro\models\simple_lstm.py Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.

HFLSTM (Module)

Source code in torchhydro/models/simple_lstm.py
class HFLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        dr: float = 0.0,
        teacher_forcing_ratio: float = 0,
        hindcast_with_output: bool = True,
    ):
        """

        Parameters
        ----------
        input_size : int
            without streamflow
        output_size : int
            streamflow
        hidden_size : int
        dr : float, optional
            dropout, by default 0.0
        teacher_forcing_ratio : float, optional
            by default 0
        hindcast_with_output : bool, optional
            whether to use the output of the model as input for the next time step, by default True
        """
        super(HFLSTM, self).__init__()
        self.linearIn = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
        )
        self.dropout = nn.Dropout(p=dr)
        self.linearOut = nn.Linear(hidden_size, output_size)
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.hindcast_with_output = hindcast_with_output
        self.hidden_size = hidden_size
        self.output_size = output_size

    def _teacher_forcing_preparation(self, xq_hor: torch.Tensor) -> torch.Tensor:
        # teacher forcing preparation
        valid_mask = ~torch.isnan(xq_hor)
        random_vals = torch.rand_like(valid_mask, dtype=torch.float)
        return (random_vals < self.teacher_forcing_ratio) * valid_mask

    def _rho_forward(
        self, x_rho: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        x0_rho = F.relu(self.linearIn(x_rho))
        out_lstm_rho, (hn_rho, cn_rho) = self.lstm(x0_rho)
        out_lstm_rho_dr = self.dropout(out_lstm_rho)
        out_lstm_rho_lnout = self.linearOut(out_lstm_rho_dr)
        prev_output = out_lstm_rho_lnout[-1:, :, :]
        return out_lstm_rho_lnout, hn_rho, cn_rho, prev_output

    def forward(self, *x: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        xfc_rho, xfc_hor, xq_rho, xq_hor = x

        x_rho = torch.cat((xfc_rho, xq_rho), dim=-1)
        hor_len, batch_size, _ = xfc_hor.size()

        # hindcast-forecast, we do not have forecast-hindcast situation
        # do rho forward first, prev_output is the last output of rho (seq_length = 1, batch_size, feature = output_size)
        if self.hindcast_with_output:
            _, h_n, c_n, prev_output = self._rho_forward(x_rho)
            seq_len = hor_len
        else:
            # TODO: need more test
            seq_len = xfc_rho.shape[0] + hor_len
            xfc_hor = torch.cat((xfc_rho, xfc_hor), dim=0)
            xq_hor = torch.cat((xq_rho, xq_hor), dim=0)
            h_n = torch.randn(1, batch_size, self.hidden_size).to(xfc_rho.device) * 0.1
            c_n = torch.randn(1, batch_size, self.hidden_size).to(xfc_rho.device) * 0.1
            prev_output = (
                torch.randn(1, batch_size, self.output_size).to(xfc_rho.device) * 0.1
            )

        use_teacher_forcing = self._teacher_forcing_preparation(xq_hor)

        # do hor forward
        outputs = torch.zeros(seq_len, batch_size, self.output_size).to(xfc_rho.device)
        # TODO: too slow here when seq_len is large, need to optimize
        for t in range(seq_len):
            real_streamflow_input = xq_hor[t : t + 1, :, :]
            prev_output = torch.where(
                use_teacher_forcing[t : t + 1, :, :],
                real_streamflow_input,
                prev_output,
            )
            input_concat = torch.cat((xfc_hor[t : t + 1, :, :], prev_output), dim=-1)

            # Pass through the initial linear layer
            x0 = F.relu(self.linearIn(input_concat))

            # LSTM step
            out_lstm, (h_n, c_n) = self.lstm(x0, (h_n, c_n))

            # Generate the current output
            prev_output = self.linearOut(out_lstm)
            outputs[t, :, :] = prev_output.squeeze(0)
        # Return the outputs
        return outputs[-hor_len:, :, :]

__init__(self, input_size, output_size, hidden_size, dr=0.0, teacher_forcing_ratio=0, hindcast_with_output=True) special

Parameters

input_size : int; without streamflow
output_size : int; streamflow
hidden_size : int
dr : float, optional; dropout, by default 0.0
teacher_forcing_ratio : float, optional; by default 0
hindcast_with_output : bool, optional; whether to use the output of the model as input for the next time step, by default True
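
A minimal usage sketch (not from the source); shapes and sizes are illustrative assumptions. Tensors are sequence-first, and because forward concatenates the forcings with the fed-back streamflow before linearIn, input_size is set here to the forcing count plus output_size.

import torch
from torchhydro.models.simple_lstm import HFLSTM

model = HFLSTM(input_size=8, output_size=1, hidden_size=32)   # 7 forcings + 1 streamflow
xfc_rho = torch.randn(240, 4, 7)                # hindcast forcings (seq, batch, features)
xq_rho = torch.randn(240, 4, 1)                 # hindcast streamflow
xfc_hor = torch.randn(24, 4, 7)                 # forecast forcings
xq_hor = torch.full((24, 4, 1), float("nan"))   # no future observations -> no teacher forcing
y = model(xfc_rho, xfc_hor, xq_rho, xq_hor)     # (24, 4, 1)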

Source code in torchhydro/models/simple_lstm.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    hidden_size: int,
    dr: float = 0.0,
    teacher_forcing_ratio: float = 0,
    hindcast_with_output: bool = True,
):
    """

    Parameters
    ----------
    input_size : int
        without streamflow
    output_size : int
        streamflow
    hidden_size : int
    dr : float, optional
        dropout, by default 0.0
    teacher_forcing_ratio : float, optional
        by default 0
    hindcast_with_output : bool, optional
        whether to use the output of the model as input for the next time step, by default True
    """
    super(HFLSTM, self).__init__()
    self.linearIn = nn.Linear(input_size, hidden_size)
    self.lstm = nn.LSTM(
        hidden_size,
        hidden_size,
    )
    self.dropout = nn.Dropout(p=dr)
    self.linearOut = nn.Linear(hidden_size, output_size)
    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.hindcast_with_output = hindcast_with_output
    self.hidden_size = hidden_size
    self.output_size = output_size

forward(self, *x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/simple_lstm.py
def forward(self, *x: Tuple[torch.Tensor, ...]) -> torch.Tensor:
    xfc_rho, xfc_hor, xq_rho, xq_hor = x

    x_rho = torch.cat((xfc_rho, xq_rho), dim=-1)
    hor_len, batch_size, _ = xfc_hor.size()

    # hindcast-forecast, we do not have forecast-hindcast situation
    # do rho forward first, prev_output is the last output of rho (seq_length = 1, batch_size, feature = output_size)
    if self.hindcast_with_output:
        _, h_n, c_n, prev_output = self._rho_forward(x_rho)
        seq_len = hor_len
    else:
        # TODO: need more test
        seq_len = xfc_rho.shape[0] + hor_len
        xfc_hor = torch.cat((xfc_rho, xfc_hor), dim=0)
        xq_hor = torch.cat((xq_rho, xq_hor), dim=0)
        h_n = torch.randn(1, batch_size, self.hidden_size).to(xfc_rho.device) * 0.1
        c_n = torch.randn(1, batch_size, self.hidden_size).to(xfc_rho.device) * 0.1
        prev_output = (
            torch.randn(1, batch_size, self.output_size).to(xfc_rho.device) * 0.1
        )

    use_teacher_forcing = self._teacher_forcing_preparation(xq_hor)

    # do hor forward
    outputs = torch.zeros(seq_len, batch_size, self.output_size).to(xfc_rho.device)
    # TODO: too slow here when seq_len is large, need to optimize
    for t in range(seq_len):
        real_streamflow_input = xq_hor[t : t + 1, :, :]
        prev_output = torch.where(
            use_teacher_forcing[t : t + 1, :, :],
            real_streamflow_input,
            prev_output,
        )
        input_concat = torch.cat((xfc_hor[t : t + 1, :, :], prev_output), dim=-1)

        # Pass through the initial linear layer
        x0 = F.relu(self.linearIn(input_concat))

        # LSTM step
        out_lstm, (h_n, c_n) = self.lstm(x0, (h_n, c_n))

        # Generate the current output
        prev_output = self.linearOut(out_lstm)
        outputs[t, :, :] = prev_output.squeeze(0)
    # Return the outputs
    return outputs[-hor_len:, :, :]

LinearMultiLayerLSTMModel (MultiLayerLSTM)

This model is nonlinear layer + MultiLayerLSTM.

Source code in torchhydro/models/simple_lstm.py
class LinearMultiLayerLSTMModel(MultiLayerLSTM):
    """
    This model is nonlinear layer + MultiLayerLSTM.
    """

    def __init__(self, linear_size: int, **kwargs: Any):
        """

        Parameters
        ----------
        linear_size
            the number of input features for the first input linear layer
        """
        super(LinearMultiLayerLSTMModel, self).__init__(**kwargs)
        self.former_linear = nn.Linear(linear_size, kwargs["input_size"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = F.relu(self.former_linear(x))
        return super(LinearMultiLayerLSTMModel, self).forward(x0)

__init__(self, linear_size, **kwargs) special

Parameters

linear_size : the number of input features for the first input linear layer
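
A minimal usage sketch (not from the source); keyword values are illustrative assumptions. The **kwargs are forwarded to MultiLayerLSTM, and the extra front linear layer maps linear_size features down to input_size.

import torch
from torchhydro.models.simple_lstm import LinearMultiLayerLSTMModel

model = LinearMultiLayerLSTMModel(
    linear_size=20, input_size=12, output_size=1, hidden_size=64, num_layers=2, dr=0.2
)
x = torch.randn(240, 4, 20)   # (seq_len, batch, linear_size), sequence-first
y = model(x)                  # (240, 4, 1)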

Source code in torchhydro/models/simple_lstm.py
def __init__(self, linear_size: int, **kwargs: Any):
    """

    Parameters
    ----------
    linear_size
        the number of input features for the first input linear layer
    """
    super(LinearMultiLayerLSTMModel, self).__init__(**kwargs)
    self.former_linear = nn.Linear(linear_size, kwargs["input_size"])

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/simple_lstm.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x0 = F.relu(self.former_linear(x))
    return super(LinearMultiLayerLSTMModel, self).forward(x0)

LinearSimpleLSTMModel (SimpleLSTM)

This model is nonlinear layer + SimpleLSTM.

Source code in torchhydro/models/simple_lstm.py
class LinearSimpleLSTMModel(SimpleLSTM):
    """
    This model is nonlinear layer + SimpleLSTM.
    """

    def __init__(self, linear_size: int, **kwargs: Any):
        """

        Parameters
        ----------
        linear_size
            the number of input features for the first input linear layer
        """
        super(LinearSimpleLSTMModel, self).__init__(**kwargs)
        self.former_linear = nn.Linear(linear_size, kwargs["input_size"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the LinearSimpleLSTMModel.

        Args:
            x: Input tensor which will be passed through a linear layer first.

        Returns:
            The output of the underlying SimpleLSTM model.
        """
        x0 = F.relu(self.former_linear(x))
        return super(LinearSimpleLSTMModel, self).forward(x0)

__init__(self, linear_size, **kwargs) special

Parameters

linear_size : the number of input features for the first input linear layer

Source code in torchhydro/models/simple_lstm.py
def __init__(self, linear_size: int, **kwargs: Any):
    """

    Parameters
    ----------
    linear_size
        the number of input features for the first input linear layer
    """
    super(LinearSimpleLSTMModel, self).__init__(**kwargs)
    self.former_linear = nn.Linear(linear_size, kwargs["input_size"])

forward(self, x)

Forward pass for the LinearSimpleLSTMModel.

Parameters:

x : Tensor, required
    Input tensor which will be passed through a linear layer first.

Returns:

Tensor
    The output of the underlying SimpleLSTM model.

Source code in torchhydro/models/simple_lstm.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass for the LinearSimpleLSTMModel.

    Args:
        x: Input tensor which will be passed through a linear layer first.

    Returns:
        The output of the underlying SimpleLSTM model.
    """
    x0 = F.relu(self.former_linear(x))
    return super(LinearSimpleLSTMModel, self).forward(x0)

MinLSTM (Module)

Only "parallel mode" is supported, for conciseness; the scan is computed in log space. Written by Yang Wang

https://arxiv.org/pdf/2410.01201 https://github.com/axion66/minLSTM-implementation

input shape: [batch, seq_len, in_chn]; output shape: [batch, seq_len, out_chn]
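
A minimal shape check (not from the source); sizes are illustrative assumptions. Note that the output feature dimension equals hidden_size (the out_chn above), since there is no separate output projection.

import torch
from torchhydro.models.simple_lstm import MinLSTM

m = MinLSTM(input_size=8, hidden_size=16)
x = torch.randn(4, 100, 8)   # (batch, seq_len, in_chn)
h = m(x)                     # (batch, seq_len, hidden_size) -> (4, 100, 16)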

Source code in torchhydro/models/simple_lstm.py
class MinLSTM(nn.Module):
    """
    Only "parallel mode" is supported for conciseness.
    use log space.
    Written by Yang Wang

    https://arxiv.org/pdf/2410.01201
    https://github.com/axion66/minLSTM-implementation

    input shape: [batch, seq_len, in_chn]
    output shape: [batch,seq_len, out_chn]
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.linear = nn.Linear(
            input_size, hidden_size * 3, bias=False, device=device, dtype=dtype
        )

    def forward(
        self, x_t: torch.Tensor, h_prev: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if h_prev is None:
            h_prev = torch.zeros(
                x_t.size(0), self.hidden_size, device=x_t.device, dtype=x_t.dtype
            )
        seq_len = x_t.shape[1]
        f, i, h = torch.chunk(self.linear(x_t), chunks=3, dim=-1)
        diff = F.softplus(-f) - F.softplus(-i)
        log_f = -F.softplus(diff)
        log_i = -F.softplus(-diff)
        log_h_0 = self.log_g(h_prev)
        log_tilde_h = self.log_g(h)
        log_coeff = log_f.unsqueeze(1)
        log_val = torch.cat([log_h_0.unsqueeze(1), (log_i + log_tilde_h)], dim=1)
        h_t = self.parallel_scan_log(log_coeff, log_val)
        return h_t[:, -seq_len:]

    def parallel_scan_log(
        self, log_coeffs: torch.Tensor, log_values: torch.Tensor
    ) -> torch.Tensor:
        a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0)).squeeze(1)
        log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1).squeeze(1)
        log_h = a_star + log_h0_plus_b_star
        return torch.exp(log_h)  # will return [batch, seq + 1, chn]

    def g(self, x: torch.Tensor) -> torch.Tensor:
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    def log_g(self, x: torch.Tensor) -> torch.Tensor:
        return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

forward(self, x_t, h_prev=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/simple_lstm.py
def forward(
    self, x_t: torch.Tensor, h_prev: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if h_prev is None:
        h_prev = torch.zeros(
            x_t.size(0), self.hidden_size, device=x_t.device, dtype=x_t.dtype
        )
    seq_len = x_t.shape[1]
    f, i, h = torch.chunk(self.linear(x_t), chunks=3, dim=-1)
    diff = F.softplus(-f) - F.softplus(-i)
    log_f = -F.softplus(diff)
    log_i = -F.softplus(-diff)
    log_h_0 = self.log_g(h_prev)
    log_tilde_h = self.log_g(h)
    log_coeff = log_f.unsqueeze(1)
    log_val = torch.cat([log_h_0.unsqueeze(1), (log_i + log_tilde_h)], dim=1)
    h_t = self.parallel_scan_log(log_coeff, log_val)
    return h_t[:, -seq_len:]

MultiLayerLSTM (Module)

Source code in torchhydro/models/simple_lstm.py
class MultiLayerLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int = 1,
        dr: float = 0.0,
    ):
        super(MultiLayerLSTM, self).__init__()
        self.linearIn = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers,
            dropout=dr,
        )
        self.linearOut = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = F.relu(self.linearIn(x))
        out_lstm, (hn, cn) = self.lstm(x0)
        return self.linearOut(out_lstm)

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/simple_lstm.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x0 = F.relu(self.linearIn(x))
    out_lstm, (hn, cn) = self.lstm(x0)
    return self.linearOut(out_lstm)

SimpleLSTM (Module)

Source code in torchhydro/models/simple_lstm.py
class SimpleLSTM(nn.Module):
    def __init__(
        self, input_size: int, output_size: int, hidden_size: int, dr: float = 0.0
    ):
        super(SimpleLSTM, self).__init__()
        self.linearIn = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
        )
        self.dropout = nn.Dropout(p=dr)
        self.linearOut = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward pass of the SimpleLSTM model.

        Args:
            x: Input tensor.
            **kwargs: Optional keyword arguments:
                - seq_lengths: Sequence lengths for PackedSequence.
                - mask: Mask tensor for manual masking.
                - use_manual_mask: Prioritize manual masking over PackedSequence.

        Returns:
            The output tensor from the model.
        """
        x0 = F.relu(self.linearIn(x))

        # Extract parameters from kwargs
        seq_lengths = kwargs.get("seq_lengths", None)
        mask = kwargs.get("mask", None)
        use_manual_mask = kwargs.get("use_manual_mask", False)

        # Determine processing method based on available parameters
        if use_manual_mask and mask is not None:
            # Use manual masking
            out_lstm, (hn, cn) = self.lstm(x0)

            # Apply mask to LSTM output
            # Ensure mask has the correct shape for broadcasting
            if mask.dim() == 2:  # [batch_size, seq_len]
                # Convert to [seq_len, batch_size, 1] for seq_first format
                mask = mask.transpose(0, 1).unsqueeze(-1)
            elif mask.dim() == 3 and mask.size(-1) != 1:
                # If mask is [seq_len, batch_size, features], take only the first feature
                mask = mask[:, :, :1]

            # Apply mask: set masked positions to zero
            out_lstm = out_lstm * mask

        elif seq_lengths is not None and not use_manual_mask:
            # Use PackedSequence (original behavior)
            packed_x = pack_padded_sequence(
                x0, seq_lengths, batch_first=False, enforce_sorted=False
            )
            packed_out, (hn, cn) = self.lstm(packed_x)
            out_lstm, _ = pad_packed_sequence(packed_out, batch_first=False)

        else:
            # Standard processing without masking
            out_lstm, (hn, cn) = self.lstm(x0)

            # Apply mask if provided (even without use_manual_mask flag)
            if mask is not None:
                if mask.dim() == 2:  # [batch_size, seq_len]
                    mask = mask.transpose(0, 1).unsqueeze(-1)
                elif mask.dim() == 3 and mask.size(-1) != 1:
                    mask = mask[:, :, :1]
                out_lstm = out_lstm * mask

        out_lstm_dr = self.dropout(out_lstm)
        return self.linearOut(out_lstm_dr)

forward(self, x, **kwargs)

Forward pass of the SimpleLSTM model.

Parameters:

x (Tensor, required): Input tensor.

**kwargs (Any, default {}): Optional keyword arguments:
- seq_lengths: Sequence lengths for PackedSequence.
- mask: Mask tensor for manual masking.
- use_manual_mask: Prioritize manual masking over PackedSequence.

Returns:

Tensor: The output tensor from the model.

Source code in torchhydro/models/simple_lstm.py
def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Forward pass of the SimpleLSTM model.

    Args:
        x: Input tensor.
        **kwargs: Optional keyword arguments:
            - seq_lengths: Sequence lengths for PackedSequence.
            - mask: Mask tensor for manual masking.
            - use_manual_mask: Prioritize manual masking over PackedSequence.

    Returns:
        The output tensor from the model.
    """
    x0 = F.relu(self.linearIn(x))

    # Extract parameters from kwargs
    seq_lengths = kwargs.get("seq_lengths", None)
    mask = kwargs.get("mask", None)
    use_manual_mask = kwargs.get("use_manual_mask", False)

    # Determine processing method based on available parameters
    if use_manual_mask and mask is not None:
        # Use manual masking
        out_lstm, (hn, cn) = self.lstm(x0)

        # Apply mask to LSTM output
        # Ensure mask has the correct shape for broadcasting
        if mask.dim() == 2:  # [batch_size, seq_len]
            # Convert to [seq_len, batch_size, 1] for seq_first format
            mask = mask.transpose(0, 1).unsqueeze(-1)
        elif mask.dim() == 3 and mask.size(-1) != 1:
            # If mask is [seq_len, batch_size, features], take only the first feature
            mask = mask[:, :, :1]

        # Apply mask: set masked positions to zero
        out_lstm = out_lstm * mask

    elif seq_lengths is not None and not use_manual_mask:
        # Use PackedSequence (original behavior)
        packed_x = pack_padded_sequence(
            x0, seq_lengths, batch_first=False, enforce_sorted=False
        )
        packed_out, (hn, cn) = self.lstm(packed_x)
        out_lstm, _ = pad_packed_sequence(packed_out, batch_first=False)

    else:
        # Standard processing without masking
        out_lstm, (hn, cn) = self.lstm(x0)

        # Apply mask if provided (even without use_manual_mask flag)
        if mask is not None:
            if mask.dim() == 2:  # [batch_size, seq_len]
                mask = mask.transpose(0, 1).unsqueeze(-1)
            elif mask.dim() == 3 and mask.size(-1) != 1:
                mask = mask[:, :, :1]
            out_lstm = out_lstm * mask

    out_lstm_dr = self.dropout(out_lstm)
    return self.linearOut(out_lstm_dr)
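A minimal usage sketch with hypothetical sizes, showing the three processing branches; the input is sequence-first, matching the PackedSequence call with batch_first=False.

import torch
from torchhydro.models.simple_lstm import SimpleLSTM

model = SimpleLSTM(input_size=5, output_size=1, hidden_size=32, dr=0.2)
x = torch.randn(20, 4, 5)                    # [seq_len, batch, input_size]

# 1) standard processing without masking
y = model(x)                                 # [20, 4, 1]

# 2) variable-length sequences via PackedSequence
seq_lengths = torch.tensor([20, 18, 15, 9])  # one length per batch element
y_packed = model(x, seq_lengths=seq_lengths)

# 3) manual masking: mask is [batch, seq_len] with 1 for valid steps, 0 for padding
mask = torch.ones(4, 20)
mask[3, 9:] = 0
y_masked = model(x, mask=mask, use_manual_mask=True)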

SimpleLSTMForecast (SimpleLSTM)

Source code in torchhydro/models/simple_lstm.py
class SimpleLSTMForecast(SimpleLSTM):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        forecast_length: int,
        dr: float = 0.0,
    ):
        super(SimpleLSTMForecast, self).__init__(
            input_size, output_size, hidden_size, dr
        )
        self.forecast_length = forecast_length

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the SimpleLSTMForecast model.

        This method calls the parent's forward method and returns only the
        final part of the output sequence corresponding to the forecast length.

        Args:
            x: Input tensor.

        Returns:
            A tensor containing the forecast part of the output sequence.
        """
        # Call the parent class's forward method to get the full output
        full_output = super(SimpleLSTMForecast, self).forward(x)

        return full_output[-self.forecast_length :, :, :]

forward(self, x)

Forward pass for the SimpleLSTMForecast model.

This method calls the parent's forward method and returns only the final part of the output sequence corresponding to the forecast length.

Parameters:

x (Tensor, required): Input tensor.

Returns:

Tensor: A tensor containing the forecast part of the output sequence.

Source code in torchhydro/models/simple_lstm.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass for the SimpleLSTMForecast model.

    This method calls the parent's forward method and returns only the
    final part of the output sequence corresponding to the forecast length.

    Args:
        x: Input tensor.

    Returns:
        A tensor containing the forecast part of the output sequence.
    """
    # Call the parent class's forward method to get the full output
    full_output = super(SimpleLSTMForecast, self).forward(x)

    return full_output[-self.forecast_length :, :, :]
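A minimal usage sketch with hypothetical sizes; with the sequence-first output of the parent model, the slice keeps only the last forecast_length time steps.

import torch
from torchhydro.models.simple_lstm import SimpleLSTMForecast

model = SimpleLSTMForecast(input_size=5, output_size=1, hidden_size=32, forecast_length=7, dr=0.1)
x = torch.randn(30, 4, 5)   # [seq_len, batch, input_size]
y = model(x)                # [forecast_length, batch, output_size] -> torch.Size([7, 4, 1])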

SlowLSTM (Module)

A pedagogic implementation of Hochreiter & Schmidhuber: 'Long Short-Term Memory' http://www.bioinf.jku.at/publications/older/2604.pdf

Source code in torchhydro/models/simple_lstm.py
class SlowLSTM(nn.Module):
    """
    A pedagogic implementation of Hochreiter & Schmidhuber:
    'Long Short-Term Memory'
    http://www.bioinf.jku.at/publications/older/2604.pdf
    """

    def __init__(
        self, input_size: int, hidden_size: int, bias: bool = True, dropout: float = 0.0
    ):
        super(SlowLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.dropout = dropout
        # input to hidden weights
        self.w_xi = P(T(hidden_size, input_size))
        self.w_xf = P(T(hidden_size, input_size))
        self.w_xo = P(T(hidden_size, input_size))
        self.w_xc = P(T(hidden_size, input_size))
        # hidden to hidden weights
        self.w_hi = P(T(hidden_size, hidden_size))
        self.w_hf = P(T(hidden_size, hidden_size))
        self.w_ho = P(T(hidden_size, hidden_size))
        self.w_hc = P(T(hidden_size, hidden_size))
        # bias terms
        self.b_i = T(hidden_size).fill_(0)
        self.b_f = T(hidden_size).fill_(0)
        self.b_o = T(hidden_size).fill_(0)
        self.b_c = T(hidden_size).fill_(0)

        # Wrap biases as parameters if desired, else as variables without gradients
        W = P if bias else (lambda x: P(x, requires_grad=False))
        self.b_i = W(self.b_i)
        self.b_f = W(self.b_f)
        self.b_o = W(self.b_o)
        self.b_c = W(self.b_c)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(
        self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        h, c = hidden
        h = h.view(h.size(0), -1)
        c = c.view(h.size(0), -1)
        x = x.view(x.size(0), -1)
        # Linear mappings
        i_t = th.mm(x, self.w_xi) + th.mm(h, self.w_hi) + self.b_i
        f_t = th.mm(x, self.w_xf) + th.mm(h, self.w_hf) + self.b_f
        o_t = th.mm(x, self.w_xo) + th.mm(h, self.w_ho) + self.b_o
        # activations
        i_t.sigmoid_()
        f_t.sigmoid_()
        o_t.sigmoid_()
        # cell computations
        c_t = th.mm(x, self.w_xc) + th.mm(h, self.w_hc) + self.b_c
        c_t.tanh_()
        c_t = th.mul(c, f_t) + th.mul(i_t, c_t)
        h_t = th.mul(o_t, th.tanh(c_t))
        # Reshape for compatibility
        h_t = h_t.view(h_t.size(0), 1, -1)
        c_t = c_t.view(c_t.size(0), 1, -1)
        if self.dropout > 0.0:
            F.dropout(h_t, p=self.dropout, training=self.training, inplace=True)
        return h_t, (h_t, c_t)

    def sample_mask(self) -> None:
        pass

forward(self, x, hidden)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/simple_lstm.py
def forward(
    self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    h, c = hidden
    h = h.view(h.size(0), -1)
    c = c.view(h.size(0), -1)
    x = x.view(x.size(0), -1)
    # Linear mappings
    i_t = th.mm(x, self.w_xi) + th.mm(h, self.w_hi) + self.b_i
    f_t = th.mm(x, self.w_xf) + th.mm(h, self.w_hf) + self.b_f
    o_t = th.mm(x, self.w_xo) + th.mm(h, self.w_ho) + self.b_o
    # activations
    i_t.sigmoid_()
    f_t.sigmoid_()
    o_t.sigmoid_()
    # cell computations
    c_t = th.mm(x, self.w_xc) + th.mm(h, self.w_hc) + self.b_c
    c_t.tanh_()
    c_t = th.mul(c, f_t) + th.mul(i_t, c_t)
    h_t = th.mul(o_t, th.tanh(c_t))
    # Reshape for compatibility
    h_t = h_t.view(h_t.size(0), 1, -1)
    c_t = c_t.view(c_t.size(0), 1, -1)
    if self.dropout > 0.0:
        F.dropout(h_t, p=self.dropout, training=self.training, inplace=True)
    return h_t, (h_t, c_t)
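A minimal usage sketch with hypothetical sizes. The cell processes one time step per call and returns the hidden state plus the (h, c) pair for the next step; because the weight matrices are used untransposed in th.mm, this sketch keeps input_size equal to hidden_size.

import torch
from torchhydro.models.simple_lstm import SlowLSTM

cell = SlowLSTM(input_size=16, hidden_size=16, dropout=0.1)
h = torch.zeros(4, 1, 16)    # [batch, 1, hidden_size]
c = torch.zeros(4, 1, 16)

# step through a sequence of length T, one time step at a time
T = 10
seq = torch.randn(T, 4, 16)  # [T, batch, input_size]
for t in range(T):
    out, (h, c) = cell(seq[t], (h, c))   # out: [batch, 1, hidden_size]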

spplstm

Author: Xinzhuo Wu Date: 2023-09-30 1:20:18 LastEditTime: 2024-05-27 16:26:06 LastEditors: Wenyu Ouyang Description: SPP LSTM model FilePath: \torchhydro\torchhydro\models\spplstm.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

SPP_LSTM_Model (Module)

Source code in torchhydro/models/spplstm.py
class SPP_LSTM_Model(nn.Module):
    def __init__(
        self, hindcast_length, forecast_length, n_output, n_hidden_states, dropout
    ):
        super(SPP_LSTM_Model, self).__init__()

        self.conv1 = TimeDistributed(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=1,
                padding="same",
            )
        )

        self.conv2 = TimeDistributed(
            nn.Conv2d(
                in_channels=32,
                out_channels=16,
                kernel_size=3,
                padding="same",
                bias=True,
            )
        )

        self.conv3 = TimeDistributed(
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=3,
                padding="same",
                bias=True,
            )
        )

        self.maxpool1 = TimeDistributed(nn.MaxPool2d(kernel_size=2, stride=(2, 2)))

        self.conv4 = TimeDistributed(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                padding="same",
                bias=True,
            )
        )

        self.conv5 = TimeDistributed(
            nn.Conv2d(
                in_channels=32,
                out_channels=32,
                kernel_size=3,
                padding="same",
                bias=True,
            )
        )

        self.maxpool2 = TimeDistributed(SppLayer([4, 2, 1]))

        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=21 * 32, hidden_size=n_hidden_states, batch_first=True
        )

        self.dense = nn.Linear(in_features=n_hidden_states, out_features=n_output)

        self.forecast_length = forecast_length

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.conv3(x)
        x = torch.relu(x)
        x = self.maxpool1(x)
        x = self.conv4(x)
        x = torch.relu(x)
        x = self.conv5(x)
        x = torch.relu(x)
        x = self.maxpool2(x)
        x = self.dropout(x)
        x = x.view(x.shape[0], x.shape[1], -1)
        x, _ = self.lstm(x)
        x = self.dense(x)
        x = x[:, -self.forecast_length :, :]
        return x

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/spplstm.py
def forward(self, x):
    x = self.conv1(x)
    x = torch.relu(x)
    x = self.conv2(x)
    x = torch.relu(x)
    x = self.conv3(x)
    x = torch.relu(x)
    x = self.maxpool1(x)
    x = self.conv4(x)
    x = torch.relu(x)
    x = self.conv5(x)
    x = torch.relu(x)
    x = self.maxpool2(x)
    x = self.dropout(x)
    x = x.view(x.shape[0], x.shape[1], -1)
    x, _ = self.lstm(x)
    x = self.dense(x)
    x = x[:, -self.forecast_length :, :]
    return x
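A minimal usage sketch with hypothetical sizes. The input is a 5-D tensor [batch, time, channels, height, width] with a single channel; the SPP levels [4, 2, 1] produce 16 + 4 + 1 = 21 bins per channel, which matches the LSTM input size of 21 * 32.

import torch
from torchhydro.models.spplstm import SPP_LSTM_Model

model = SPP_LSTM_Model(
    hindcast_length=24, forecast_length=6, n_output=1, n_hidden_states=64, dropout=0.2
)
x = torch.randn(8, 30, 1, 16, 16)   # [batch, hindcast + forecast, 1, H, W]
y = model(x)                        # [batch, forecast_length, n_output] -> torch.Size([8, 6, 1])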

SPP_LSTM_Model_2 (Module)

Source code in torchhydro/models/spplstm.py
class SPP_LSTM_Model_2(nn.Module):
    def __init__(
        self,
        hindcast_length,
        forecast_length,
        p_n_output,
        p_n_hidden_states,
        p_dropout,
        p_in_channels,
        p_out_channels,
        len_c=None,
        s_hindcast_length=None,
        s_n_output=None,
        s_n_hidden_states=None,
        s_dropout=None,
        s_in_channels=None,
        s_out_channels=None,
    ):
        """Initializes the SPP_LSTM_Model_2.

        A custom neural network model for handling and integrating various types
        of meteorological and geographical data, including precipitation (p),
        soil (s), and basin attributes (c).

        Args:
            hindcast_length: The length of the input sequence for precipitation data.
            forecast_length: The length of the forecast period.
            p_n_output: Output dimension for the precipitation (p) data path.
            p_n_hidden_states: Number of hidden states in the LSTM for the precipitation path.
            p_dropout: Dropout rate applied in the precipitation path.
            p_in_channels: Number of input channels for the conv layer in the precipitation path.
            p_out_channels: Number of output channels for the conv layer in the precipitation path.
            len_c: Optional, the number of basin attribute (c) features.
            s_hindcast_length: Optional, hindcast length for the soil (s) data path.
            s_n_output: Optional, output dimension for the soil path.
            s_n_hidden_states: Optional, number of hidden states for the soil path LSTM.
            s_dropout: Optional, dropout rate for the soil path.
            s_in_channels: Optional, input channels for the soil path conv layer.
            s_out_channels: Optional, output channels for the soil path conv layer.
        """
        super(SPP_LSTM_Model_2, self).__init__()
        self.conv_p = nn.Conv2d(
            in_channels=p_in_channels,
            out_channels=p_out_channels,
            kernel_size=(3, 3),
            padding="same",
        )

        self.leaky_relu_p = nn.LeakyReLU(0.01)

        self.lstm_p = nn.LSTM(
            input_size=p_out_channels * 5 + len_c,
            hidden_size=p_n_hidden_states,
            batch_first=True,
        )

        self.dropout_p = nn.Dropout(p_dropout)

        self.fc_p = nn.Linear(p_n_hidden_states, p_n_output)

        self.spp_p = SppLayer([2, 1])

        self.p_length = hindcast_length + forecast_length
        self.forecast_length = forecast_length

        if s_hindcast_length is not None:
            self.conv_s = nn.Conv2d(
                in_channels=s_in_channels,
                out_channels=s_out_channels,
                kernel_size=(3, 3),
                padding="same",
            )

            self.leaky_relu_s = nn.LeakyReLU(0.01)
            self.sigmoid_s = nn.Sigmoid()

            self.lstm_s = nn.LSTM(
                input_size=s_out_channels * 5,
                hidden_size=s_n_hidden_states,
                batch_first=True,
            )

            self.dropout_s = nn.Dropout(s_dropout)

            self.fc_s = nn.Linear(s_n_hidden_states, s_n_output)

            self.spp_s = SppLayer([2, 1])

            self.s_length = s_hindcast_length

    def forward(self, *x_lst):
        # c and s must be None, g might be None
        if len(x_lst) == 1:
            x = x_lst[0]
            x = x.view(-1, x.shape[2], x.shape[3], x.shape[4])
            x = self.conv_p(x)
            x = self.leaky_relu_p(x)
            x = self.spp_p(x)
            x = x.view(x.shape[0], -1)
            x = x.view(int(x.shape[0] / (self.p_length)), self.p_length, -1)
            x, _ = self.lstm_p(x)
            x = self.dropout_p(x)
            x = self.fc_p(x)
        # g might be None. either c or s must be None, but not both
        elif len(x_lst) == 2:
            p = x_lst[0]
            m = x_lst[1].permute(1, 0, 2)
            # c is not None
            if m.dim() == 3:
                p = p.view(-1, p.shape[2], p.shape[3], p.shape[4])
                p = self.conv_p(p)
                p = self.leaky_relu_p(p)
                p = self.spp_p(p)
                p = p.view(p.shape[0], -1)
                p = p.view(int(p.shape[0] / (self.p_length)), self.p_length, -1)
                x = torch.cat([p, m], dim=2)
                x, _ = self.lstm_p(x)
                x = self.dropout_p(x)
                x = self.fc_p(x)
            # s is not None
            else:
                p = p.view(-1, p.shape[2], p.shape[3], p.shape[4])
                p = self.conv_p(p)
                p = self.leaky_relu_p(p)
                p = self.spp_p(p)
                p = p.view(p.shape[0], -1)
                p = p.view(int(p.shape[0] / (self.p_length)), self.p_length, -1)
                p, _ = self.lstm_p(p)
                p = self.dropout_p(p)
                p = self.fc_p(p)

                m = m.view(-1, m.shape[2], m.shape[3], m.shape[4])
                m = self.conv_s(m)
                m = self.leaky_relu_s(m)
                m = self.spp_s(m)
                m = m.view(m.shape[0], -1)
                m = m.view(int(m.shape[0] / (self.s_length)), self.s_length, -1)
                m, _ = self.lstm_s(m)
                m = m[:, -1:, :]
                m = self.dropout_s(m)
                m = self.fc_s(m)
                m = self.sigmoid_s(m)

                x = m * p
        # g might be None. Both s and c are not None
        elif len(x_lst) == 3:
            p = x_lst[0]
            c = x_lst[1].permute(1, 0, 2)
            s = x_lst[2]

            p = p.view(-1, p.shape[2], p.shape[3], p.shape[4])
            p = self.conv_p(p)
            p = self.leaky_relu_p(p)
            p = self.spp_p(p)
            p = p.view(p.shape[0], -1)
            p = p.view(int(p.shape[0] / (self.p_length)), self.p_length, -1)
            p_c = torch.cat([p, c], dim=2)
            p_c, _ = self.lstm_p(p_c)
            p_c = self.dropout_p(p_c)
            p_c = self.fc_p(p_c)

            s = s.view(-1, s.shape[2], s.shape[3], s.shape[4])
            s = self.conv_s(s)
            s = self.leaky_relu_s(s)
            s = self.spp_s(s)
            s = s.view(s.shape[0], -1)
            s = s.view(int(s.shape[0] / (self.s_length)), self.s_length, -1)
            s, _ = self.lstm_s(s)
            s = s[:, -1:, :]
            s = self.dropout_s(s)
            s = self.fc_s(s)
            s = self.sigmoid_s(s)

            x = s * p_c
        return x[:, -self.forecast_length :, :]

__init__(self, hindcast_length, forecast_length, p_n_output, p_n_hidden_states, p_dropout, p_in_channels, p_out_channels, len_c=None, s_hindcast_length=None, s_n_output=None, s_n_hidden_states=None, s_dropout=None, s_in_channels=None, s_out_channels=None) special

Initializes the SPP_LSTM_Model_2.

A custom neural network model for handling and integrating various types of meteorological and geographical data, including precipitation (p), soil (s), and basin attributes (c).

Parameters:

hindcast_length (required): The length of the input sequence for precipitation data.
forecast_length (required): The length of the forecast period.
p_n_output (required): Output dimension for the precipitation (p) data path.
p_n_hidden_states (required): Number of hidden states in the LSTM for the precipitation path.
p_dropout (required): Dropout rate applied in the precipitation path.
p_in_channels (required): Number of input channels for the conv layer in the precipitation path.
p_out_channels (required): Number of output channels for the conv layer in the precipitation path.
len_c (default None): Optional, the number of basin attribute (c) features.
s_hindcast_length (default None): Optional, hindcast length for the soil (s) data path.
s_n_output (default None): Optional, output dimension for the soil path.
s_n_hidden_states (default None): Optional, number of hidden states for the soil path LSTM.
s_dropout (default None): Optional, dropout rate for the soil path.
s_in_channels (default None): Optional, input channels for the soil path conv layer.
s_out_channels (default None): Optional, output channels for the soil path conv layer.
Source code in torchhydro/models/spplstm.py
def __init__(
    self,
    hindcast_length,
    forecast_length,
    p_n_output,
    p_n_hidden_states,
    p_dropout,
    p_in_channels,
    p_out_channels,
    len_c=None,
    s_hindcast_length=None,
    s_n_output=None,
    s_n_hidden_states=None,
    s_dropout=None,
    s_in_channels=None,
    s_out_channels=None,
):
    """Initializes the SPP_LSTM_Model_2.

    A custom neural network model for handling and integrating various types
    of meteorological and geographical data, including precipitation (p),
    soil (s), and basin attributes (c).

    Args:
        hindcast_length: The length of the input sequence for precipitation data.
        forecast_length: The length of the forecast period.
        p_n_output: Output dimension for the precipitation (p) data path.
        p_n_hidden_states: Number of hidden states in the LSTM for the precipitation path.
        p_dropout: Dropout rate applied in the precipitation path.
        p_in_channels: Number of input channels for the conv layer in the precipitation path.
        p_out_channels: Number of output channels for the conv layer in the precipitation path.
        len_c: Optional, the number of basin attribute (c) features.
        s_hindcast_length: Optional, hindcast length for the soil (s) data path.
        s_n_output: Optional, output dimension for the soil path.
        s_n_hidden_states: Optional, number of hidden states for the soil path LSTM.
        s_dropout: Optional, dropout rate for the soil path.
        s_in_channels: Optional, input channels for the soil path conv layer.
        s_out_channels: Optional, output channels for the soil path conv layer.
    """
    super(SPP_LSTM_Model_2, self).__init__()
    self.conv_p = nn.Conv2d(
        in_channels=p_in_channels,
        out_channels=p_out_channels,
        kernel_size=(3, 3),
        padding="same",
    )

    self.leaky_relu_p = nn.LeakyReLU(0.01)

    self.lstm_p = nn.LSTM(
        input_size=p_out_channels * 5 + len_c,
        hidden_size=p_n_hidden_states,
        batch_first=True,
    )

    self.dropout_p = nn.Dropout(p_dropout)

    self.fc_p = nn.Linear(p_n_hidden_states, p_n_output)

    self.spp_p = SppLayer([2, 1])

    self.p_length = hindcast_length + forecast_length
    self.forecast_length = forecast_length

    if s_hindcast_length is not None:
        self.conv_s = nn.Conv2d(
            in_channels=s_in_channels,
            out_channels=s_out_channels,
            kernel_size=(3, 3),
            padding="same",
        )

        self.leaky_relu_s = nn.LeakyReLU(0.01)
        self.sigmoid_s = nn.Sigmoid()

        self.lstm_s = nn.LSTM(
            input_size=s_out_channels * 5,
            hidden_size=s_n_hidden_states,
            batch_first=True,
        )

        self.dropout_s = nn.Dropout(s_dropout)

        self.fc_s = nn.Linear(s_n_hidden_states, s_n_output)

        self.spp_s = SppLayer([2, 1])

        self.s_length = s_hindcast_length

forward(self, *x_lst)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/spplstm.py
def forward(self, *x_lst):
    # c and s must be None, g might be None
    if len(x_lst) == 1:
        x = x_lst[0]
        x = x.view(-1, x.shape[2], x.shape[3], x.shape[4])
        x = self.conv_p(x)
        x = self.leaky_relu_p(x)
        x = self.spp_p(x)
        x = x.view(x.shape[0], -1)
        x = x.view(int(x.shape[0] / (self.p_length)), self.p_length, -1)
        x, _ = self.lstm_p(x)
        x = self.dropout_p(x)
        x = self.fc_p(x)
    # g might be None. either c or s must be None, but not both
    elif len(x_lst) == 2:
        p = x_lst[0]
        m = x_lst[1].permute(1, 0, 2)
        # c is not None
        if m.dim() == 3:
            p = p.view(-1, p.shape[2], p.shape[3], p.shape[4])
            p = self.conv_p(p)
            p = self.leaky_relu_p(p)
            p = self.spp_p(p)
            p = p.view(p.shape[0], -1)
            p = p.view(int(p.shape[0] / (self.p_length)), self.p_length, -1)
            x = torch.cat([p, m], dim=2)
            x, _ = self.lstm_p(x)
            x = self.dropout_p(x)
            x = self.fc_p(x)
        # s is not None
        else:
            p = p.view(-1, p.shape[2], p.shape[3], p.shape[4])
            p = self.conv_p(p)
            p = self.leaky_relu_p(p)
            p = self.spp_p(p)
            p = p.view(p.shape[0], -1)
            p = p.view(int(p.shape[0] / (self.p_length)), self.p_length, -1)
            p, _ = self.lstm_p(p)
            p = self.dropout_p(p)
            p = self.fc_p(p)

            m = m.view(-1, m.shape[2], m.shape[3], m.shape[4])
            m = self.conv_s(m)
            m = self.leaky_relu_s(m)
            m = self.spp_s(m)
            m = m.view(m.shape[0], -1)
            m = m.view(int(m.shape[0] / (self.s_length)), self.s_length, -1)
            m, _ = self.lstm_s(m)
            m = m[:, -1:, :]
            m = self.dropout_s(m)
            m = self.fc_s(m)
            m = self.sigmoid_s(m)

            x = m * p
    # g might be None. Both s and c are not None
    elif len(x_lst) == 3:
        p = x_lst[0]
        c = x_lst[1].permute(1, 0, 2)
        s = x_lst[2]

        p = p.view(-1, p.shape[2], p.shape[3], p.shape[4])
        p = self.conv_p(p)
        p = self.leaky_relu_p(p)
        p = self.spp_p(p)
        p = p.view(p.shape[0], -1)
        p = p.view(int(p.shape[0] / (self.p_length)), self.p_length, -1)
        p_c = torch.cat([p, c], dim=2)
        p_c, _ = self.lstm_p(p_c)
        p_c = self.dropout_p(p_c)
        p_c = self.fc_p(p_c)

        s = s.view(-1, s.shape[2], s.shape[3], s.shape[4])
        s = self.conv_s(s)
        s = self.leaky_relu_s(s)
        s = self.spp_s(s)
        s = s.view(s.shape[0], -1)
        s = s.view(int(s.shape[0] / (self.s_length)), self.s_length, -1)
        s, _ = self.lstm_s(s)
        s = s[:, -1:, :]
        s = self.dropout_s(s)
        s = self.fc_s(s)
        s = self.sigmoid_s(s)

        x = s * p_c
    return x[:, -self.forecast_length :, :]
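A minimal usage sketch of the precipitation-plus-attributes path, with hypothetical sizes; the attribute tensor is passed sequence-first and permuted to batch-first inside forward before being concatenated with the pooled precipitation features.

import torch
from torchhydro.models.spplstm import SPP_LSTM_Model_2

model = SPP_LSTM_Model_2(
    hindcast_length=24,
    forecast_length=6,
    p_n_output=1,
    p_n_hidden_states=64,
    p_dropout=0.2,
    p_in_channels=1,
    p_out_channels=8,
    len_c=5,
)
B, T = 4, 24 + 6                  # T = hindcast_length + forecast_length
p = torch.randn(B, T, 1, 16, 16)  # gridded precipitation [batch, T, channels, H, W]
c = torch.randn(T, B, 5)          # basin attributes [T, batch, len_c]
y = model(p, c)                   # [batch, forecast_length, p_n_output] -> torch.Size([4, 6, 1])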

SppLayer (Module)

Source code in torchhydro/models/spplstm.py
class SppLayer(nn.Module):
    def __init__(self, out_pool_size):
        """
        out_pool_size: a list of ints giving the expected output sizes of the max pooling layers
        """
        super(SppLayer, self).__init__()
        self.out_pool_size = out_pool_size
        self.pools = []
        for i in range(len(out_pool_size)):
            pool_i = nn.AdaptiveMaxPool2d(out_pool_size[i])
            self.pools.append(pool_i)

    def forward(self, previous_conv):
        """
        Parameters
        ----------
        previous_conv
            a tensor vector of previous convolution layer

        Returns
        -------
        torch.Tensor
            a tensor that is the concatenation of the multi-level pooling outputs
        """

        num_sample = previous_conv.size(0)
        channel_size = previous_conv.size(1)
        out_pool_size = self.out_pool_size
        for i in range(len(out_pool_size)):
            maxpool = self.pools[i]
            x = maxpool(previous_conv)
            if i == 0:
                spp = x.view(num_sample, channel_size, -1)
            else:
                spp = torch.cat((spp, x.view(num_sample, channel_size, -1)), -1)
        return spp

__init__(self, out_pool_size) special

out_pool_size: a list of ints giving the expected output sizes of the max pooling layers

Source code in torchhydro/models/spplstm.py
def __init__(self, out_pool_size):
    """
    out_pool_size: a list of ints giving the expected output sizes of the max pooling layers
    """
    super(SppLayer, self).__init__()
    self.out_pool_size = out_pool_size
    self.pools = []
    for i in range(len(out_pool_size)):
        pool_i = nn.AdaptiveMaxPool2d(out_pool_size[i])
        self.pools.append(pool_i)

forward(self, previous_conv)

Parameters

previous_conv: a tensor output by the previous convolution layer

Returns

torch.Tensor: a tensor that is the concatenation of the multi-level pooling outputs

Source code in torchhydro/models/spplstm.py
def forward(self, previous_conv):
    """
    Parameters
    ----------
    previous_conv
        a tensor vector of previous convolution layer

    Returns
    -------
    torch.Tensor
        a tensor that is the concatenation of the multi-level pooling outputs
    """

    num_sample = previous_conv.size(0)
    channel_size = previous_conv.size(1)
    out_pool_size = self.out_pool_size
    for i in range(len(out_pool_size)):
        maxpool = self.pools[i]
        x = maxpool(previous_conv)
        if i == 0:
            spp = x.view(num_sample, channel_size, -1)
        else:
            spp = torch.cat((spp, x.view(num_sample, channel_size, -1)), -1)
    return spp
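A minimal usage sketch with hypothetical sizes; each adaptive pooling level of size s contributes s * s values per channel, so out_pool_size=[2, 1] yields 4 + 1 = 5 values per channel.

import torch
from torchhydro.models.spplstm import SppLayer

spp = SppLayer([2, 1])
feat = torch.randn(4, 32, 8, 8)   # [batch, channels, H, W] from a previous conv layer
out = spp(feat)                   # [batch, channels, 5] -> torch.Size([4, 32, 5])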

TimeDistributed (Module)

Source code in torchhydro/models/spplstm.py
class TimeDistributed(nn.Module):
    def __init__(self, layer):
        super(TimeDistributed, self).__init__()
        self.layer = layer

    def forward(self, x):
        outputs = []
        for t in range(x.size(1)):
            xt = x[:, t, :]
            output = self.layer(xt)
            outputs.append(output.unsqueeze(1))
        outputs = torch.cat(outputs, dim=1)
        return outputs

forward(self, x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in torchhydro/models/spplstm.py
def forward(self, x):
    outputs = []
    for t in range(x.size(1)):
        xt = x[:, t, :]
        output = self.layer(xt)
        outputs.append(output.unsqueeze(1))
    outputs = torch.cat(outputs, dim=1)
    return outputs
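A minimal usage sketch with hypothetical sizes; the wrapped layer is applied independently at every step along the time dimension of a [batch, time, ...] tensor.

import torch
import torch.nn as nn
from torchhydro.models.spplstm import TimeDistributed

layer = TimeDistributed(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding="same"))
x = torch.randn(8, 30, 1, 16, 16)   # [batch, time, channels, H, W]
y = layer(x)                        # [batch, time, 32, 16, 16]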