
Trainers API

deep_hydro

Author: Wenyu Ouyang Date: 2024-04-08 18:15:48 LastEditTime: 2025-07-13 16:25:31 LastEditors: Wenyu Ouyang Description: HydroDL model class FilePath: torchhydro\torchhydro\trainers\deep_hydro.py Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.

DeepHydro (DeepHydroInterface)

The Base Trainer class for Hydrological Deep Learning models

Source code in torchhydro/trainers/deep_hydro.py
class DeepHydro(DeepHydroInterface):
    """
    The Base Trainer class for Hydrological Deep Learning models
    """

    def __init__(
        self,
        cfgs: Dict,
        pre_model=None,
    ):
        """
        Parameters
        ----------
        cfgs
            configs for the model
        pre_model
            a pre-trained model, if it is not None,
            we will use its weights to initialize this model
            by default None
        """
        super().__init__(cfgs)
        # Initialize fabric based on configuration
        self.fabric = create_fabric_wrapper(cfgs.get("training_cfgs", {}))
        self.pre_model = pre_model
        self.model = self.fabric.setup_module(self.load_model())
        if cfgs["training_cfgs"]["train_mode"]:
            self.traindataset = self.make_dataset("train")
            if cfgs["data_cfgs"]["t_range_valid"] is not None:
                self.validdataset = self.make_dataset("valid")
        self.testdataset: BaseDataset = self.make_dataset("test")

    @property
    def device(self):
        """Get the device from fabric wrapper"""
        return self.fabric._device

    def load_model(self, mode="train"):
        """
        Load a time series forecast model in pytorch_model_dict in model_dict_function.py

        Returns
        -------
        object
            model in pytorch_model_dict in model_dict_function.py
        """
        if mode == "infer":
            if self.weight_path is None or self.cfgs["model_cfgs"]["continue_train"]:
                # if no weight path is provided
                # or weight file is provided but continue train again,
                # we will use the trained model in the new case_dir directory
                self.weight_path = self._get_trained_model()
        elif mode != "train":
            raise ValueError("Invalid mode; must be 'train' or 'infer'")
        model_cfgs = self.cfgs["model_cfgs"]
        model_name = model_cfgs["model_name"]
        if model_name not in pytorch_model_dict:
            raise NotImplementedError(
                f"Error the model {model_name} was not found in the model dict. Please add it."
            )
        if self.pre_model is not None:
            return self._load_pretrain_model()
        elif self.weight_path is not None:
            return self._load_model_from_pth()
        else:
            return pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])

    def _load_pretrain_model(self):
        """load a pretrained model as the initial model"""
        return self.pre_model

    def _load_model_from_pth(self):
        weight_path = self.weight_path
        model_cfgs = self.cfgs["model_cfgs"]
        model_name = model_cfgs["model_name"]
        model = pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
        checkpoint = torch.load(weight_path, map_location=self.device)
        model.load_state_dict(checkpoint)
        print("Weights sucessfully loaded")
        return model

    def make_dataset(self, is_tra_val_te: str):
        """
        Initializes a pytorch dataset.

        Parameters
        ----------
        is_tra_val_te
            train or valid or test

        Returns
        -------
        object
            an object initializing from class in datasets_dict in data_dict.py
        """
        data_cfgs = self.cfgs["data_cfgs"]
        dataset_name = data_cfgs["dataset"]

        if dataset_name in list(datasets_dict.keys()):
            dataset = datasets_dict[dataset_name](self.cfgs, is_tra_val_te)
        else:
            raise NotImplementedError(
                f"Error the dataset {str(dataset_name)} was not found in the dataset dict. Please add it."
            )
        return dataset

    def model_train(self) -> Tuple[Dict, float]:
        """train a hydrological DL model"""
        # A dictionary of the necessary parameters for training
        training_cfgs = self.cfgs["training_cfgs"]
        # The file path to load model weights from; defaults to "model_save"
        model_filepath = self.cfgs["data_cfgs"]["case_dir"]
        data_cfgs = self.cfgs["data_cfgs"]
        es = None
        if training_cfgs["early_stopping"]:
            es = EarlyStopper(training_cfgs["patience"])
        criterion = self._get_loss_func(training_cfgs)
        opt = self._get_optimizer(training_cfgs)
        scheduler = self._get_scheduler(training_cfgs, opt)
        max_epochs = training_cfgs["epochs"]
        start_epoch = training_cfgs["start_epoch"]
        # use PyTorch's DataLoader to load the data into batches in each epoch
        data_loader, validation_data_loader = self._get_dataloader(
            training_cfgs, data_cfgs
        )
        logger = TrainLogger(model_filepath, self.cfgs, opt)
        for epoch in range(start_epoch, max_epochs + 1):
            with logger.log_epoch_train(epoch) as train_logs:
                total_loss, n_iter_ep = torch_single_train(
                    self.model,
                    opt,
                    criterion,
                    data_loader,
                    device=self.device,
                    which_first_tensor=training_cfgs["which_first_tensor"],
                )
                train_logs["train_loss"] = total_loss
                train_logs["model"] = self.model

            valid_loss = None
            valid_metrics = None
            if data_cfgs["t_range_valid"] is not None:
                with logger.log_epoch_valid(epoch) as valid_logs:
                    valid_loss, valid_metrics = self._1epoch_valid(
                        training_cfgs, criterion, validation_data_loader, valid_logs
                    )

            self._scheduler_step(training_cfgs, scheduler, valid_loss)
            logger.save_session_param(
                epoch, total_loss, n_iter_ep, valid_loss, valid_metrics
            )
            logger.save_model_and_params(self.model, epoch, self.cfgs)
            if es and not es.check_loss(
                self.model,
                valid_loss,
                self.cfgs["data_cfgs"]["case_dir"],
            ):
                print("Stopping model now")
                break
        # logger.plot_model_structure(self.model)
        logger.tb.close()

        # return the trained model weights and bias and the epoch loss
        return self.model.state_dict(), sum(logger.epoch_loss) / len(logger.epoch_loss)

    def _get_scheduler(self, training_cfgs, opt):
        lr_scheduler_cfg = training_cfgs["lr_scheduler"]

        if "lr" in lr_scheduler_cfg and "lr_factor" not in lr_scheduler_cfg:
            scheduler = LambdaLR(opt, lr_lambda=lambda epoch: 1.0)
        elif isinstance(lr_scheduler_cfg, dict) and all(
            isinstance(epoch, int) for epoch in lr_scheduler_cfg
        ):
            # piecewise constant learning rate
            epochs = sorted(lr_scheduler_cfg.keys())
            values = [lr_scheduler_cfg[e] for e in epochs]

            def lr_lambda(epoch):
                idx = bisect.bisect_right(epochs, epoch) - 1
                return 1.0 if idx < 0 else values[idx]

            scheduler = LambdaLR(opt, lr_lambda=lr_lambda)
        elif "lr_factor" in lr_scheduler_cfg and "lr_patience" not in lr_scheduler_cfg:
            scheduler = ExponentialLR(opt, gamma=lr_scheduler_cfg["lr_factor"])
        elif "lr_factor" in lr_scheduler_cfg:
            scheduler = ReduceLROnPlateau(
                opt,
                mode="min",
                factor=lr_scheduler_cfg["lr_factor"],
                patience=lr_scheduler_cfg["lr_patience"],
            )
        else:
            raise ValueError("Invalid lr_scheduler configuration")

        return scheduler

    def _scheduler_step(self, training_cfgs, scheduler, valid_loss):
        lr_scheduler_cfg = training_cfgs["lr_scheduler"]
        required_keys = {"lr_factor", "lr_patience"}
        if required_keys.issubset(lr_scheduler_cfg.keys()):
            scheduler.step(valid_loss)
        else:
            scheduler.step()

    def _1epoch_valid(
        self, training_cfgs, criterion, validation_data_loader, valid_logs
    ):
        valid_obss_np, valid_preds_np, valid_loss = compute_validation(
            self.model,
            criterion,
            validation_data_loader,
            device=self.device,
            which_first_tensor=training_cfgs["which_first_tensor"],
        )
        valid_logs["valid_loss"] = valid_loss
        if (
            self.cfgs["training_cfgs"]["valid_batch_mode"] == "test"
            and self.cfgs["training_cfgs"]["calc_metrics"]
        ):
            # NOTE: Now we only evaluate the metrics for test-mode validation
            target_col = self.cfgs["data_cfgs"]["target_cols"]
            valid_metrics = evaluate_validation(
                validation_data_loader,
                valid_preds_np,
                valid_obss_np,
                self.cfgs["evaluation_cfgs"],
                target_col,
            )
            valid_logs["valid_metrics"] = valid_metrics
            return valid_loss, valid_metrics
        return valid_loss, None

    def _get_trained_model(self):
        model_loader = self.cfgs["evaluation_cfgs"]["model_loader"]
        model_pth_dir = self.cfgs["data_cfgs"]["case_dir"]
        return read_pth_from_model_loader(model_loader, model_pth_dir)

    def model_evaluate(self) -> Tuple[xr.Dataset, xr.Dataset]:
        """
        A function to evaluate a model, called at end of training.

        Returns
        -------
        tuple[xr.Dataset, xr.Dataset]
            denormalized predictions and observations
        """
        self.model = self.load_model(mode="infer").to(self.device)
        preds_xr, obss_xr = self.inference()
        return preds_xr, obss_xr

    def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
        """infer using trained model and unnormalized results"""
        data_cfgs = self.cfgs["data_cfgs"]
        training_cfgs = self.cfgs["training_cfgs"]
        test_dataloader = self._get_dataloader(training_cfgs, data_cfgs, mode="infer")
        seq_first = training_cfgs["which_first_tensor"] == "sequence"
        self.model.eval()
        # here the batch is just an index of lookup table, so any batch size could be chosen
        test_preds = []
        obss = []
        with torch.no_grad():
            for i, batch in enumerate(
                tqdm(test_dataloader, desc="Model inference", unit="batch")
            ):
                ys, pred = model_infer(
                    seq_first,
                    self.device,
                    self.model,
                    batch,
                    variable_length_cfgs=None,
                    return_key=(
                        self.cfgs.get("evaluation_cfgs", {})
                        .get("evaluator", {})
                        .get("return_key", None)
                    ),
                )
                test_preds.append(pred.cpu())
                obss.append(ys.cpu())
                if i % 100 == 0:
                    torch.cuda.empty_cache()
            pred = torch.cat(test_preds, dim=0).numpy()  # convert to numpy at the end
            obs = torch.cat(obss, dim=0).numpy()  # convert to numpy at the end
        if pred.ndim == 2:
            # TODO: check
            # the ndim is 2 meaning we use an Nto1 mode
            # as lookup table is (basin 1's all time length, basin 2's all time length, ...)
            # params of reshape should be (basin size, time length)
            pred = pred.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
            obs = obs.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
        evaluation_cfgs = self.cfgs["evaluation_cfgs"]
        obs_xr, pred_xr = get_preds_to_be_eval(
            test_dataloader,
            evaluation_cfgs,
            pred,
            obs,
        )
        return pred_xr, obs_xr

    def _get_optimizer(self, training_cfgs):
        params_in_opt = self.model.parameters()
        return pytorch_opt_dict[training_cfgs["optimizer"]](
            params_in_opt, **training_cfgs["optim_params"]
        )

    def _get_loss_func(self, training_cfgs):
        criterion_init_params = {}
        if "criterion_params" in training_cfgs:
            loss_param = training_cfgs["criterion_params"]
            if loss_param is not None:
                for key in loss_param.keys():
                    if key == "loss_funcs":
                        criterion_init_params[key] = pytorch_criterion_dict[
                            loss_param[key]
                        ]()
                    else:
                        criterion_init_params[key] = loss_param[key]
        return pytorch_criterion_dict[training_cfgs["criterion"]](
            **criterion_init_params
        )

    def _flood_event_collate_fn(self, batch):
        """自定义的洪水事件 collate 函数,确保所有样本长度一致"""

        # find the longest sequence length in this batch
        max_len = max(tensor_data[0].shape[0] for tensor_data in batch)

        # adjust all samples to the same length
        processed_batch = []
        for tensor_data in batch:
            # get x and y (assuming tensor_data[0] is x and tensor_data[1] is y)
            x = tensor_data[0]
            y = tensor_data[1] if len(tensor_data) > 1 else None

            current_len = x.shape[0]
            if current_len < max_len:
                # pad x by repeating its last value
                padding_x = x[-1:].repeat(max_len - current_len, 1)
                padded_x = torch.cat([x, padding_x], dim=0)

                # if y exists, pad it as well
                if y is not None:
                    padding_y = y[-1:].repeat(max_len - current_len, 1)
                    padded_y = torch.cat([y, padding_y], dim=0)
                else:
                    padded_y = None
            else:
                # truncate if longer
                padded_x = x[:max_len]
                padded_y = y[:max_len] if y is not None else None

            if padded_y is not None:
                processed_batch.append((padded_x, padded_y))
            else:
                processed_batch.append(padded_x)

        # return the stacked results according to the data structure
        if len(processed_batch) > 0 and isinstance(processed_batch[0], tuple):
            return (
                torch.stack([x for x, _ in processed_batch], 0),
                torch.stack([y for _, y in processed_batch], 0)
            )
        else:
            return torch.stack(processed_batch, 0)

    def _get_dataloader(self, training_cfgs, data_cfgs, mode="train"):
        if mode == "infer":
            _collate_fn = None
            # Use GNN collate function for GNN datasets in inference mode
            if hasattr(self.testdataset, '__class__') and 'GNN' in self.testdataset.__class__.__name__:
                _collate_fn = gnn_collate_fn
            # use the custom collate function for FloodEventDataset
            elif hasattr(self.testdataset, '__class__') and 'FloodEvent' in self.testdataset.__class__.__name__:
                _collate_fn = self._flood_event_collate_fn
            return DataLoader(
                self.testdataset,
                batch_size=training_cfgs["batch_size"],
                shuffle=False,
                sampler=None,
                batch_sampler=None,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                collate_fn=_collate_fn,
            )
        worker_num = 0
        pin_memory = False
        if "num_workers" in training_cfgs:
            worker_num = training_cfgs["num_workers"]
            print(f"using {str(worker_num)} workers")
        if "pin_memory" in training_cfgs:
            pin_memory = training_cfgs["pin_memory"]
            print(f"Pin memory set to {str(pin_memory)}")
        sampler = self._get_sampler(data_cfgs, training_cfgs, self.traindataset)
        _collate_fn = None
        if training_cfgs["variable_length_cfgs"]["use_variable_length"]:
            _collate_fn = varied_length_collate_fn
        # Use GNN collate function for GNN datasets
        elif hasattr(self.traindataset, '__class__') and 'GNN' in self.traindataset.__class__.__name__:
            _collate_fn = gnn_collate_fn
        data_loader = DataLoader(
            self.traindataset,
            batch_size=training_cfgs["batch_size"],
            shuffle=(sampler is None),
            sampler=sampler,
            num_workers=worker_num,
            pin_memory=pin_memory,
            timeout=0,
            collate_fn=_collate_fn,
        )
        if data_cfgs["t_range_valid"] is not None:
            # Use the same collate function for validation dataset
            _val_collate_fn = None
            if training_cfgs["variable_length_cfgs"]["use_variable_length"]:
                _val_collate_fn = varied_length_collate_fn
            elif hasattr(self.validdataset, '__class__') and 'GNN' in self.validdataset.__class__.__name__:
                _val_collate_fn = gnn_collate_fn

            validation_data_loader = DataLoader(
                self.validdataset,
                batch_size=training_cfgs["batch_size"],
                shuffle=False,
                num_workers=worker_num,
                pin_memory=pin_memory,
                timeout=0,
                collate_fn=_val_collate_fn,
            )
            return data_loader, validation_data_loader

        return data_loader, None

    def _get_sampler(self, data_cfgs, training_cfgs, train_dataset):
        """
        Return a data sampler based on the provided configuration and training dataset.

        Parameters
        ----------
        data_cfgs : dict
            Configuration dictionary containing parameters for data sampling. Expected keys are:
            - "sampler": str or None, name of the sampler to use; it must be a key of
              data_sampler_dict, and its hyperparameters are derived from the dataset and batch size.
        training_cfgs: dict
            Configuration dictionary containing parameters for training. Expected keys are:
            - "batch_size": int, size of each batch.
        train_dataset : Dataset
            The training dataset object which contains the data to be sampled. Expected attributes are:
            - ngrid: int, number of grids in the dataset.
            - nt: int, number of time steps in the dataset.
            - rho: int, length of the input sequence.
            - warmup_length: int, length of the warmup period.
            - horizon: int, length of the forecast horizon.

        Returns
        -------
        sampler_class
            An instance of the specified sampler class, initialized with the provided dataset and hyperparameters.

        Raises
        ------
        NotImplementedError
            If the specified sampler name is not found in the `data_sampler_dict`.
        """
        if data_cfgs["sampler"] is None:
            return None
        batch_size = training_cfgs["batch_size"]
        rho = train_dataset.rho
        warmup_length = train_dataset.warmup_length
        horizon = train_dataset.horizon
        ngrid = train_dataset.ngrid
        nt = train_dataset.nt
        sampler_name = data_cfgs["sampler"]
        if sampler_name not in data_sampler_dict:
            raise NotImplementedError(f"Sampler {sampler_name} not implemented yet")
        sampler_class = data_sampler_dict[sampler_name]
        sampler_hyperparam = {}
        if sampler_name == "KuaiSampler":
            sampler_hyperparam |= {
                "batch_size": batch_size,
                "warmup_length": warmup_length,
                "rho_horizon": rho + horizon,
                "ngrid": ngrid,
                "nt": nt,
            }
        elif sampler_name == "WindowLenBatchSampler":
            sampler_hyperparam |= {
                "batch_size": batch_size,
            }

        return sampler_class(train_dataset, **sampler_hyperparam)
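
A minimal usage sketch is shown below. It is illustrative only: the config keys are the ones read by the code above, but the concrete names and values (the model name "LSTM", the dataset name "BaseDataset", the paths) are assumptions, and a real torchhydro configuration contains many more entries (model_train reads further keys such as optimizer, criterion and lr_scheduler; see the sketch after model_train below).

# Illustrative sketch only; keys shown are those read by DeepHydro above,
# and the concrete values are assumptions, not torchhydro defaults.
cfgs = {
    "model_cfgs": {
        "model_name": "LSTM",              # assumed name; must be a key of pytorch_model_dict
        "model_hyperparam": {},            # kwargs passed to the model constructor
        "weight_path": None,               # or a path to a saved .pth checkpoint
        "continue_train": False,
    },
    "data_cfgs": {
        "dataset": "BaseDataset",          # assumed name; must be a key of datasets_dict
        "case_dir": "results/exp1",        # assumed output directory
        "t_range_valid": None,             # None disables the validation loader
        "sampler": None,
    },
    "training_cfgs": {
        "train_mode": True,
        "epochs": 20,
        "start_epoch": 1,
        "batch_size": 256,
        "which_first_tensor": "sequence",
    },
    "evaluation_cfgs": {"model_loader": {"load_way": "best"}},
}

trainer = DeepHydro(cfgs)
state_dict, mean_epoch_loss = trainer.model_train()   # weights and mean epoch loss
pred_xr, obs_xr = trainer.model_evaluate()            # denormalized predictions/observations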

device property readonly

Get the device from fabric wrapper

__init__(self, cfgs, pre_model=None) special

Parameters

cfgs
    configs for the model
pre_model
    a pre-trained model; if it is not None, we will use its weights to initialize this model. By default None.

Source code in torchhydro/trainers/deep_hydro.py
def __init__(
    self,
    cfgs: Dict,
    pre_model=None,
):
    """
    Parameters
    ----------
    cfgs
        configs for the model
    pre_model
        a pre-trained model, if it is not None,
        we will use its weights to initialize this model
        by default None
    """
    super().__init__(cfgs)
    # Initialize fabric based on configuration
    self.fabric = create_fabric_wrapper(cfgs.get("training_cfgs", {}))
    self.pre_model = pre_model
    self.model = self.fabric.setup_module(self.load_model())
    if cfgs["training_cfgs"]["train_mode"]:
        self.traindataset = self.make_dataset("train")
        if cfgs["data_cfgs"]["t_range_valid"] is not None:
            self.validdataset = self.make_dataset("valid")
    self.testdataset: BaseDataset = self.make_dataset("test")
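
The pre_model argument can be used to start from an already trained network instead of a fresh one; a hedged one-liner (pretrained is a placeholder for any compatible torch.nn.Module, e.g. the model of a previous DeepHydro run):

# pretrained is a placeholder for an existing, compatible torch.nn.Module
trainer = DeepHydro(cfgs, pre_model=pretrained)  # its weights initialize the new model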

inference(self)

Infer with the trained model and return denormalized results

Source code in torchhydro/trainers/deep_hydro.py
def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
    """infer using trained model and unnormalized results"""
    data_cfgs = self.cfgs["data_cfgs"]
    training_cfgs = self.cfgs["training_cfgs"]
    test_dataloader = self._get_dataloader(training_cfgs, data_cfgs, mode="infer")
    seq_first = training_cfgs["which_first_tensor"] == "sequence"
    self.model.eval()
    # here the batch is just an index of lookup table, so any batch size could be chosen
    test_preds = []
    obss = []
    with torch.no_grad():
        for i, batch in enumerate(
            tqdm(test_dataloader, desc="Model inference", unit="batch")
        ):
            ys, pred = model_infer(
                seq_first,
                self.device,
                self.model,
                batch,
                variable_length_cfgs=None,
                return_key=(
                    self.cfgs.get("evaluation_cfgs", {})
                    .get("evaluator", {})
                    .get("return_key", None)
                ),
            )
            test_preds.append(pred.cpu())
            obss.append(ys.cpu())
            if i % 100 == 0:
                torch.cuda.empty_cache()
        pred = torch.cat(test_preds, dim=0).numpy()  # convert to numpy at the end
        obs = torch.cat(obss, dim=0).numpy()  # convert to numpy at the end
    if pred.ndim == 2:
        # TODO: check
        # the ndim is 2 meaning we use an Nto1 mode
        # as lookup table is (basin 1's all time length, basin 2's all time length, ...)
        # params of reshape should be (basin size, time length)
        pred = pred.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
        obs = obs.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
    evaluation_cfgs = self.cfgs["evaluation_cfgs"]
    obs_xr, pred_xr = get_preds_to_be_eval(
        test_dataloader,
        evaluation_cfgs,
        pred,
        obs,
    )
    return pred_xr, obs_xr
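
Both returned objects are xarray Datasets of denormalized values, so they can be handled with standard xarray tooling; a small, hedged example (the file names are assumptions, and the variable and coordinate names inside the Datasets depend on the dataset class and target_cols):

# Hedged example: persist the denormalized results returned by inference()/model_evaluate().
pred_xr, obs_xr = trainer.model_evaluate()
pred_xr.to_netcdf("results/exp1/pred.nc")
obs_xr.to_netcdf("results/exp1/obs.nc")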

load_model(self, mode='train')

Load a time series forecast model in pytorch_model_dict in model_dict_function.py

Returns

object model in pytorch_model_dict in model_dict_function.py

Source code in torchhydro/trainers/deep_hydro.py
def load_model(self, mode="train"):
    """
    Load a time series forecast model in pytorch_model_dict in model_dict_function.py

    Returns
    -------
    object
        model in pytorch_model_dict in model_dict_function.py
    """
    if mode == "infer":
        if self.weight_path is None or self.cfgs["model_cfgs"]["continue_train"]:
            # if no weight path is provided
            # or weight file is provided but continue train again,
            # we will use the trained model in the new case_dir directory
            self.weight_path = self._get_trained_model()
    elif mode != "train":
        raise ValueError("Invalid mode; must be 'train' or 'infer'")
    model_cfgs = self.cfgs["model_cfgs"]
    model_name = model_cfgs["model_name"]
    if model_name not in pytorch_model_dict:
        raise NotImplementedError(
            f"Error the model {model_name} was not found in the model dict. Please add it."
        )
    if self.pre_model is not None:
        return self._load_pretrain_model()
    elif self.weight_path is not None:
        return self._load_model_from_pth()
    else:
        return pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
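
load_model resolves model_name through pytorch_model_dict, so supporting a new architecture amounts to registering its class under a name. A hedged sketch follows; MyNewModel and the in-place registration step are illustrative assumptions, the real mapping lives in model_dict_function.py.

# Hedged sketch of the registry pattern used above: pytorch_model_dict maps a string
# name to a torch.nn.Module subclass, instantiated with model_cfgs["model_hyperparam"].
import torch.nn as nn


class MyNewModel(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.head(out)


# assumed registration step; afterwards model_cfgs["model_name"] = "MyNewModel" selects it
pytorch_model_dict["MyNewModel"] = MyNewModel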

make_dataset(self, is_tra_val_te)

Initializes a pytorch dataset.

Parameters

is_tra_val_te train or valid or test

Returns

object an object initializing from class in datasets_dict in data_dict.py

Source code in torchhydro/trainers/deep_hydro.py
def make_dataset(self, is_tra_val_te: str):
    """
    Initializes a pytorch dataset.

    Parameters
    ----------
    is_tra_val_te
        train or valid or test

    Returns
    -------
    object
        an object initializing from class in datasets_dict in data_dict.py
    """
    data_cfgs = self.cfgs["data_cfgs"]
    dataset_name = data_cfgs["dataset"]

    if dataset_name in list(datasets_dict.keys()):
        dataset = datasets_dict[dataset_name](self.cfgs, is_tra_val_te)
    else:
        raise NotImplementedError(
            f"Error the dataset {str(dataset_name)} was not found in the dataset dict. Please add it."
        )
    return dataset
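
make_dataset follows the same registry pattern: data_cfgs["dataset"] is looked up in datasets_dict and the class is built with the full config plus the split name. A hedged sketch of that lookup (cfgs is the full config dict):

# Hedged sketch of the lookup performed by make_dataset; is_tra_val_te is one of
# "train", "valid" or "test".
dataset_cls = datasets_dict[cfgs["data_cfgs"]["dataset"]]
train_ds = dataset_cls(cfgs, "train")
test_ds = dataset_cls(cfgs, "test")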

model_evaluate(self)

A function to evaluate a model, called at end of training.

Returns

tuple[xr.Dataset, xr.Dataset] denormalized predictions and observations

Source code in torchhydro/trainers/deep_hydro.py
def model_evaluate(self) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    A function to evaluate a model, called at end of training.

    Returns
    -------
    tuple[xr.Dataset, xr.Dataset]
        denormalized predictions and observations
    """
    self.model = self.load_model(mode="infer").to(self.device)
    preds_xr, obss_xr = self.inference()
    return preds_xr, obss_xr

model_train(self)

train a hydrological DL model

Source code in torchhydro/trainers/deep_hydro.py
def model_train(self) -> Tuple[Dict, float]:
    """train a hydrological DL model"""
    # A dictionary of the necessary parameters for training
    training_cfgs = self.cfgs["training_cfgs"]
    # The file path to load model weights from; defaults to "model_save"
    model_filepath = self.cfgs["data_cfgs"]["case_dir"]
    data_cfgs = self.cfgs["data_cfgs"]
    es = None
    if training_cfgs["early_stopping"]:
        es = EarlyStopper(training_cfgs["patience"])
    criterion = self._get_loss_func(training_cfgs)
    opt = self._get_optimizer(training_cfgs)
    scheduler = self._get_scheduler(training_cfgs, opt)
    max_epochs = training_cfgs["epochs"]
    start_epoch = training_cfgs["start_epoch"]
    # use PyTorch's DataLoader to load the data into batches in each epoch
    data_loader, validation_data_loader = self._get_dataloader(
        training_cfgs, data_cfgs
    )
    logger = TrainLogger(model_filepath, self.cfgs, opt)
    for epoch in range(start_epoch, max_epochs + 1):
        with logger.log_epoch_train(epoch) as train_logs:
            total_loss, n_iter_ep = torch_single_train(
                self.model,
                opt,
                criterion,
                data_loader,
                device=self.device,
                which_first_tensor=training_cfgs["which_first_tensor"],
            )
            train_logs["train_loss"] = total_loss
            train_logs["model"] = self.model

        valid_loss = None
        valid_metrics = None
        if data_cfgs["t_range_valid"] is not None:
            with logger.log_epoch_valid(epoch) as valid_logs:
                valid_loss, valid_metrics = self._1epoch_valid(
                    training_cfgs, criterion, validation_data_loader, valid_logs
                )

        self._scheduler_step(training_cfgs, scheduler, valid_loss)
        logger.save_session_param(
            epoch, total_loss, n_iter_ep, valid_loss, valid_metrics
        )
        logger.save_model_and_params(self.model, epoch, self.cfgs)
        if es and not es.check_loss(
            self.model,
            valid_loss,
            self.cfgs["data_cfgs"]["case_dir"],
        ):
            print("Stopping model now")
            break
    # logger.plot_model_structure(self.model)
    logger.tb.close()

    # return the trained model weights and bias and the epoch loss
    return self.model.state_dict(), sum(logger.epoch_loss) / len(logger.epoch_loss)
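
model_train reads a number of keys from training_cfgs; the snippet below collects the ones used in the loop above, including the lr_scheduler forms accepted by _get_scheduler (all values are illustrative assumptions, not defaults):

# Illustrative training_cfgs snippet (values are assumptions, not torchhydro defaults).
training_cfgs = {
    "train_mode": True,
    "epochs": 50,
    "start_epoch": 1,
    "early_stopping": True,
    "patience": 5,                               # used by EarlyStopper
    "batch_size": 256,
    "which_first_tensor": "sequence",            # or "batch"
    "optimizer": "Adam",                         # key of pytorch_opt_dict
    "optim_params": {"lr": 1e-3},
    "criterion": "RMSESum",                      # assumed key of pytorch_criterion_dict
    "criterion_params": None,
    "variable_length_cfgs": {"use_variable_length": False},
    # _get_scheduler accepts several shapes for "lr_scheduler":
    #   {"lr": 1e-3}                         -> constant learning rate (LambdaLR)
    #   {1: 1.0, 10: 0.5, 30: 0.1}           -> piecewise-constant factors keyed by epoch
    #   {"lr_factor": 0.95}                  -> ExponentialLR with gamma=lr_factor
    #   {"lr_factor": 0.5, "lr_patience": 3} -> ReduceLROnPlateau stepped with valid_loss
    "lr_scheduler": {"lr_factor": 0.5, "lr_patience": 3},
}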

DeepHydroInterface (ABC)

An abstract class used to handle different configurations of hydrological deep learning models plus hyperparameters for the training, test, and predict functions. This class assumes that the data has already been split into train, validation, and test sets.

Source code in torchhydro/trainers/deep_hydro.py
class DeepHydroInterface(ABC):
    """
    An abstract class used to handle different configurations
    of hydrological deep learning models + hyperparams for training, test, and predict functions.
    This class assumes that the data has already been split into train, validation, and test sets.
    """

    def __init__(self, cfgs: Dict):
        """
        Parameters
        ----------
        cfgs
            configs for initializing DeepHydro
        """

        self._cfgs = cfgs

    @property
    def cfgs(self):
        """all configs"""
        return self._cfgs

    @property
    def weight_path(self):
        """weight path"""
        return self._cfgs["model_cfgs"]["weight_path"]

    @weight_path.setter
    def weight_path(self, weight_path):
        self._cfgs["model_cfgs"]["weight_path"] = weight_path

    @abstractmethod
    def load_model(self, mode="train") -> object:
        """Get a Hydro DL model"""
        raise NotImplementedError

    @abstractmethod
    def make_dataset(self, is_tra_val_te: str) -> object:
        """
        Initializes a pytorch dataset.

        Parameters
        ----------
        is_tra_val_te
            train or valid or test

        Returns
        -------
        object
            a dataset class loading data from data source
        """
        raise NotImplementedError

    @abstractmethod
    def model_train(self):
        """
        Train the model
        """
        raise NotImplementedError

    @abstractmethod
    def model_evaluate(self):
        """
        Evaluate the model
        """
        raise NotImplementedError
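
A custom trainer only has to provide the four abstract methods; a minimal skeleton (illustrative, not part of torchhydro):

# Minimal, illustrative skeleton of a DeepHydroInterface implementation.
class MyTrainer(DeepHydroInterface):
    def load_model(self, mode="train") -> object:
        ...  # build or restore a torch.nn.Module from self.cfgs["model_cfgs"]

    def make_dataset(self, is_tra_val_te: str) -> object:
        ...  # return a dataset for "train", "valid" or "test"

    def model_train(self):
        ...  # run the training loop

    def model_evaluate(self):
        ...  # run inference on the test split and return the results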

cfgs property readonly

all configs

weight_path property writable

weight path

__init__(self, cfgs) special

Parameters

cfgs configs for initializing DeepHydro

Source code in torchhydro/trainers/deep_hydro.py
def __init__(self, cfgs: Dict):
    """
    Parameters
    ----------
    cfgs
        configs for initializing DeepHydro
    """

    self._cfgs = cfgs

load_model(self, mode='train')

Get a Hydro DL model

Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def load_model(self, mode="train") -> object:
    """Get a Hydro DL model"""
    raise NotImplementedError

make_dataset(self, is_tra_val_te)

Initializes a pytorch dataset.

Parameters

is_tra_val_te train or valid or test

Returns

object a dataset class loading data from data source

Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def make_dataset(self, is_tra_val_te: str) -> object:
    """
    Initializes a pytorch dataset.

    Parameters
    ----------
    is_tra_val_te
        train or valid or test

    Returns
    -------
    object
        a dataset class loading data from data source
    """
    raise NotImplementedError

model_evaluate(self)

Evaluate the model

Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def model_evaluate(self):
    """
    Evaluate the model
    """
    raise NotImplementedError

model_train(self)

Train the model

Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def model_train(self):
    """
    Train the model
    """
    raise NotImplementedError

FedLearnHydro (DeepHydro)

Federated Learning Hydrological DL model

Source code in torchhydro/trainers/deep_hydro.py
class FedLearnHydro(DeepHydro):
    """Federated Learning Hydrological DL model"""

    def __init__(self, cfgs: Dict):
        super().__init__(cfgs)
        # a user group which is a dict where the keys are the user index
        # and the values are the corresponding data for each of those users
        train_dataset = self.traindataset
        fl_hyperparam = self.cfgs["model_cfgs"]["fl_hyperparam"]
        # sample training data amongst users
        if fl_hyperparam["fl_sample"] == "basin":
            # Sample a basin for a user
            user_groups = fl_sample_basin(train_dataset)
        elif fl_hyperparam["fl_sample"] == "region":
            # Sample a region for a user
            user_groups = fl_sample_region(train_dataset)
        else:
            raise NotImplementedError()
        self.user_groups = user_groups

    @property
    def num_users(self):
        """number of users in federated learning"""
        return len(self.user_groups)

    def model_train(self) -> None:
        # BUILD MODEL
        global_model = self.model

        # copy weights
        global_weights = global_model.state_dict()

        # Training
        train_loss, train_accuracy = [], []
        print_every = 2

        training_cfgs = self.cfgs["training_cfgs"]
        model_cfgs = self.cfgs["model_cfgs"]
        max_epochs = training_cfgs["epochs"]
        start_epoch = training_cfgs["start_epoch"]
        fl_hyperparam = model_cfgs["fl_hyperparam"]
        # total rounds in a FL system is max_epochs
        for epoch in tqdm(range(start_epoch, max_epochs + 1)):
            local_weights, local_losses = [], []
            print(f"\n | Global Training Round : {epoch} |\n")

            global_model.train()
            m = max(int(fl_hyperparam["fl_frac"] * self.num_users), 1)
            # randomly select m users, they will be the clients in this round
            idxs_users = np.random.choice(range(self.num_users), m, replace=False)

            for idx in idxs_users:
                # each user will be used to train the model locally
                # user_groups[idx] gives the dataset indices for a user
                user_cfgs = self._get_a_user_cfgs(idx)
                local_model = DeepHydro(
                    user_cfgs,
                    pre_model=copy.deepcopy(global_model),
                )
                w, loss = local_model.model_train()
                local_weights.append(copy.deepcopy(w))
                local_losses.append(copy.deepcopy(loss))

            # update global weights
            global_weights = average_weights(local_weights)

            # update global weights
            global_model.load_state_dict(global_weights)

            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)

            # Calculate avg training accuracy over all users at every epoch
            list_acc = []
            global_model.eval()
            for c in range(self.num_users):
                one_user_cfg = self._get_a_user_cfgs(c)
                local_model = DeepHydro(
                    one_user_cfg,
                    pre_model=global_model,
                )
                acc, _, _ = local_model.model_evaluate()
                list_acc.append(acc)
            values = [list(d.values())[0][0] for d in list_acc]
            filtered_values = [v for v in values if not np.isnan(v)]
            train_accuracy.append(sum(filtered_values) / len(filtered_values))

            # print global training loss after every 'i' rounds
            if (epoch + 1) % print_every == 0:
                print(f" \nAvg Training Stats after {epoch+1} global rounds:")
                print(f"Training Loss : {np.mean(np.array(train_loss))}")
                print("Train Accuracy: {:.2f}% \n".format(100 * train_accuracy[-1]))

    def _get_a_user_cfgs(self, idx):
        """To get a user's configs for local training"""
        user = self.user_groups[idx]

        # update data_cfgs
        # Use defaultdict to collect dates for each basin
        basin_dates = defaultdict(list)

        for _, (basin, time) in user.items():
            basin_dates[basin].append(time)

        # Initialize a list to store distinct basins
        basins = []

        # for each basin, we can find its date range
        date_ranges = {}
        for basin, times in basin_dates.items():
            basins.append(basin)
            date_ranges[basin] = (np.min(times), np.max(times))
        # get the longest date range
        longest_date_range = max(date_ranges.values(), key=lambda x: x[1] - x[0])
        # transform the date range of numpy data into string
        longest_date_range = [
            np.datetime_as_string(dt, unit="D") for dt in longest_date_range
        ]
        user_cfgs = copy.deepcopy(self.cfgs)
        # update data_cfgs
        update_nested_dict(
            user_cfgs, ["data_cfgs", "t_range_train"], longest_date_range
        )
        # for local training in FL, we don't need a validation set
        update_nested_dict(user_cfgs, ["data_cfgs", "t_range_valid"], None)
        # for local training in FL, we don't need a test set, but we should set one to avoid error
        update_nested_dict(user_cfgs, ["data_cfgs", "t_range_test"], longest_date_range)
        update_nested_dict(user_cfgs, ["data_cfgs", "object_ids"], basins)

        # update training_cfgs
        # we also need to update some training params for local training from FL settings
        update_nested_dict(
            user_cfgs,
            ["training_cfgs", "epochs"],
            user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_ep"],
        )
        update_nested_dict(
            user_cfgs,
            ["evaluation_cfgs", "test_epoch"],
            user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_ep"],
        )
        # don't need to save model weights for local training
        update_nested_dict(
            user_cfgs,
            ["training_cfgs", "save_epoch"],
            None,
        )
        # there are two settings for batch size in configs, we need to update both of them
        update_nested_dict(
            user_cfgs,
            ["training_cfgs", "batch_size"],
            user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_bs"],
        )
        update_nested_dict(
            user_cfgs,
            ["data_cfgs", "batch_size"],
            user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_bs"],
        )

        # update model_cfgs finally
        # For local model, its model_type is Normal
        update_nested_dict(user_cfgs, ["model_cfgs", "model_type"], "Normal")
        update_nested_dict(
            user_cfgs,
            ["model_cfgs", "fl_hyperparam"],
            None,
        )
        return user_cfgs
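
FedLearnHydro drives the federated rounds from model_cfgs["fl_hyperparam"]; the keys read by the code above are sketched here (values are illustrative assumptions):

# Illustrative fl_hyperparam block; the keys come from the code above, the values are assumptions.
model_cfgs["fl_hyperparam"] = {
    "fl_sample": "basin",   # "basin" or "region": how training data is split among users
    "fl_frac": 0.1,         # fraction of users selected as clients in each global round
    "fl_local_ep": 5,       # local epochs per client (also used as the local test_epoch)
    "fl_local_bs": 128,     # local batch size (copied into both batch_size settings)
}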

num_users property readonly

number of users in federated learning

model_train(self)

train a hydrological DL model

Source code in torchhydro/trainers/deep_hydro.py
def model_train(self) -> None:
    # BUILD MODEL
    global_model = self.model

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_loss, train_accuracy = [], []
    print_every = 2

    training_cfgs = self.cfgs["training_cfgs"]
    model_cfgs = self.cfgs["model_cfgs"]
    max_epochs = training_cfgs["epochs"]
    start_epoch = training_cfgs["start_epoch"]
    fl_hyperparam = model_cfgs["fl_hyperparam"]
    # total rounds in a FL system is max_epochs
    for epoch in tqdm(range(start_epoch, max_epochs + 1)):
        local_weights, local_losses = [], []
        print(f"\n | Global Training Round : {epoch} |\n")

        global_model.train()
        m = max(int(fl_hyperparam["fl_frac"] * self.num_users), 1)
        # randomly select m users, they will be the clients in this round
        idxs_users = np.random.choice(range(self.num_users), m, replace=False)

        for idx in idxs_users:
            # each user will be used to train the model locally
            # user_groups[idx] gives the dataset indices for a user
            user_cfgs = self._get_a_user_cfgs(idx)
            local_model = DeepHydro(
                user_cfgs,
                pre_model=copy.deepcopy(global_model),
            )
            w, loss = local_model.model_train()
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc = []
        global_model.eval()
        for c in range(self.num_users):
            one_user_cfg = self._get_a_user_cfgs(c)
            local_model = DeepHydro(
                one_user_cfg,
                pre_model=global_model,
            )
            acc, _, _ = local_model.model_evaluate()
            list_acc.append(acc)
        values = [list(d.values())[0][0] for d in list_acc]
        filtered_values = [v for v in values if not np.isnan(v)]
        train_accuracy.append(sum(filtered_values) / len(filtered_values))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f" \nAvg Training Stats after {epoch+1} global rounds:")
            print(f"Training Loss : {np.mean(np.array(train_loss))}")
            print("Train Accuracy: {:.2f}% \n".format(100 * train_accuracy[-1]))

TransLearnHydro (DeepHydro)

Source code in torchhydro/trainers/deep_hydro.py
class TransLearnHydro(DeepHydro):
    def __init__(self, cfgs: Dict, pre_model=None):
        super().__init__(cfgs, pre_model)

    def load_model(self, mode="train"):
        """Load model for transfer learning"""
        model_cfgs = self.cfgs["model_cfgs"]
        if self.weight_path is None and self.pre_model is None:
            raise NotImplementedError(
                "For transfer learning, we need a pre-trained model"
            )
        if mode == "train":
            model = super().load_model(mode)
        elif mode == "infer":
            self.weight_path = self._get_trained_model()
            model = self._load_model_from_pth()
            model.to(self.device)
        if (
            "weight_path_add" in model_cfgs
            and "freeze_params" in model_cfgs["weight_path_add"]
        ):
            freeze_params = model_cfgs["weight_path_add"]["freeze_params"]
            for param in freeze_params:
                exec(f"model.{param}.requires_grad = False")
        return model

    def _load_model_from_pth(self):
        weight_path = self.weight_path
        model_cfgs = self.cfgs["model_cfgs"]
        model_name = model_cfgs["model_name"]
        model = pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
        checkpoint = torch.load(weight_path, map_location=self.device)
        if "weight_path_add" in model_cfgs:
            if "excluded_layers" in model_cfgs["weight_path_add"]:
                # delete some layers from source model if we don't need them
                excluded_layers = model_cfgs["weight_path_add"]["excluded_layers"]
                for layer in excluded_layers:
                    del checkpoint[layer]
                print("sucessfully deleted layers")
            else:
                print("directly loading identically-named layers of source model")
        model.load_state_dict(checkpoint, strict=False)
        print("Weights sucessfully loaded")
        return model
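
Transfer-learning behaviour is driven by model_cfgs["weight_path_add"]; a hedged sketch of the two options handled above (the layer and parameter names are placeholders that must match the attributes of the chosen model):

# Hedged sketch; "linearOut.*" and "lstm" are placeholders, not real torchhydro layer names.
model_cfgs["weight_path_add"] = {
    # layers dropped from the source checkpoint before load_state_dict(strict=False)
    "excluded_layers": ["linearOut.weight", "linearOut.bias"],
    # attributes frozen after loading, applied as model.<name>.requires_grad = False
    "freeze_params": ["lstm"],
}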

load_model(self, mode='train')

Load model for transfer learning

Source code in torchhydro/trainers/deep_hydro.py
def load_model(self, mode="train"):
    """Load model for transfer learning"""
    model_cfgs = self.cfgs["model_cfgs"]
    if self.weight_path is None and self.pre_model is None:
        raise NotImplementedError(
            "For transfer learning, we need a pre-trained model"
        )
    if mode == "train":
        model = super().load_model(mode)
    elif mode == "infer":
        self.weight_path = self._get_trained_model()
        model = self._load_model_from_pth()
        model.to(self.device)
    if (
        "weight_path_add" in model_cfgs
        and "freeze_params" in model_cfgs["weight_path_add"]
    ):
        freeze_params = model_cfgs["weight_path_add"]["freeze_params"]
        for param in freeze_params:
            exec(f"model.{param}.requires_grad = False")
    return model

fabric_wrapper

Author: Wenyu Ouyang Date: 2023-07-25 16:47:19 LastEditTime: 2025-06-17 10:39:32 LastEditors: Wenyu Ouyang Description: Lightning Fabric wrapper for debugging and distributed training FilePath: torchhydro\torchhydro\trainers\fabric_wrapper.py Copyright (c) 2025-2026 Wenyu Ouyang. All rights reserved.

FabricWrapper

A wrapper class that can switch between Lightning Fabric and normal PyTorch operations based on configuration settings.

TODO: the fabric wrapper is not fully used for parallel training yet

Source code in torchhydro/trainers/fabric_wrapper.py
class FabricWrapper:
    """
    A wrapper class that can switch between Lightning Fabric and normal PyTorch operations
    based on configuration settings.

    TODO: the fabric wrapper is not fully used for parallel training yet
    """

    def __init__(self, use_fabric: bool = True, fabric_config: Optional[Dict] = None):
        """
        Initialize the Fabric wrapper.

        Parameters
        ----------
        use_fabric : bool
            Whether to use Lightning Fabric or normal PyTorch operations
        fabric_config : Optional[Dict]
            Configuration for Fabric (devices, strategy, etc.)
        """
        self.use_fabric = use_fabric
        self.fabric_config = fabric_config or {}
        self._fabric: Optional[Any] = None
        self._device: Optional[torch.device] = None

        if self.use_fabric:
            self._init_fabric()
        else:
            self._init_pytorch()

    def _init_fabric(self) -> None:
        """Initialize Lightning Fabric"""
        try:
            import lightning as L

            # Default fabric configuration
            default_config = {
                "accelerator": "auto",
                "devices": "auto",
                "strategy": "auto",
                "precision": "32-true",
            }

            # Update with user config
            default_config.update(self.fabric_config)

            self._fabric = L.Fabric(**default_config)
            print("✅ Lightning Fabric initialized successfully")

        except ImportError:
            print("❌ Lightning not found, falling back to normal PyTorch")
            self.use_fabric = False
            self._init_pytorch()

    def _init_pytorch(self) -> None:
        """Initialize normal PyTorch setup"""
        self.device_num = self.fabric_config["devices"]
        self._device = get_the_device(self.device_num)
        print(f"✅ Normal PyTorch initialized, using device: {self._device}")

    def setup_module(self, model: torch.nn.Module) -> torch.nn.Module:
        """Setup model for training"""
        if self.use_fabric:
            return self._fabric.setup_module(model)
        else:
            return model.to(self._device)

    def setup_optimizers(
        self, optimizer: torch.optim.Optimizer
    ) -> torch.optim.Optimizer:
        """Setup optimizer"""
        if self.use_fabric:
            return self._fabric.setup_optimizers(optimizer)
        else:
            return optimizer

    def setup_dataloaders(
        self, *dataloaders: torch.utils.data.DataLoader
    ) -> Tuple[torch.utils.data.DataLoader, ...]:
        """Setup dataloaders"""
        if self.use_fabric:
            return self._fabric.setup_dataloaders(*dataloaders)
        else:
            return dataloaders

    def save(self, path: str, state_dict: Dict[str, Any]) -> None:
        """Save model state"""
        if self.use_fabric:
            self._fabric.save(path, state_dict)
        else:
            torch.save(state_dict, path)

    def load(self, path: str, model: Optional[torch.nn.Module] = None) -> Any:
        """Load model state"""
        if self.use_fabric:
            return self._fabric.load(path, model)
        else:
            return torch.load(path, map_location=self._device)

    def load_raw(self, path: str, model: torch.nn.Module) -> None:
        """Load raw model weights"""
        if self.use_fabric:
            checkpoint = self._fabric.load(path)
            model.load_state_dict(checkpoint)
        else:
            checkpoint = torch.load(path, map_location=self._device)
            model.load_state_dict(checkpoint)

    def launch(self, fn: Optional[Any] = None, *args: Any, **kwargs: Any) -> Any:
        """Launch training function"""
        if self.use_fabric:
            if fn is None:
                # This is called without a function, just launch fabric
                return self._fabric.launch()
            else:
                return self._fabric.launch(fn, *args, **kwargs)
        else:
            # Normal PyTorch, just call the function directly
            if fn is not None:
                return fn(*args, **kwargs)
            else:
                return None

    def backward(self, loss: torch.Tensor) -> None:
        """Backward pass"""
        if self.use_fabric:
            self._fabric.backward(loss)
        else:
            loss.backward()

    def clip_gradients(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        max_norm: float = 1.0,
    ) -> None:
        """Clip gradients"""
        if self.use_fabric:
            self._fabric.clip_gradients(model, optimizer, max_norm=max_norm)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    @property
    def device(self) -> torch.device:
        """Get current device"""
        if self.use_fabric:
            return self._fabric.device
        else:
            return self._device

    @property
    def local_rank(self) -> int:
        """Get local rank"""
        if self.use_fabric:
            return self._fabric.local_rank
        else:
            return 0

    @property
    def global_rank(self) -> int:
        """Get global rank"""
        if self.use_fabric:
            return self._fabric.global_rank
        else:
            return 0

    @property
    def world_size(self) -> int:
        """Get world size"""
        if self.use_fabric:
            return self._fabric.world_size
        else:
            return 1

    def barrier(self) -> None:
        """Synchronization barrier"""
        if self.use_fabric:
            self._fabric.barrier()
        else:
            pass  # No barrier needed for single process

    def print(self, *args: Any, **kwargs: Any) -> None:
        """Print only on rank 0"""
        if self.use_fabric:
            self._fabric.print(*args, **kwargs)
        else:
            print(*args, **kwargs)
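
FabricWrapper can also be built directly, for example to force plain PyTorch on a single device; a hedged one-liner (the device index and my_model are placeholders):

# Plain-PyTorch mode still needs a "devices" entry, because _init_pytorch reads
# fabric_config["devices"]; my_model stands for any torch.nn.Module.
wrapper = FabricWrapper(use_fabric=False, fabric_config={"devices": [0]})
model = wrapper.setup_module(my_model)  # moves the model to wrapper.device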

device: device property readonly

Get current device

global_rank: int property readonly

Get global rank

local_rank: int property readonly

Get local rank

world_size: int property readonly

Get world size

__init__(self, use_fabric=True, fabric_config=None) special

Initialize the Fabric wrapper.

Parameters

use_fabric : bool
    Whether to use Lightning Fabric or normal PyTorch operations
fabric_config : Optional[Dict]
    Configuration for Fabric (devices, strategy, etc.)

Source code in torchhydro/trainers/fabric_wrapper.py
def __init__(self, use_fabric: bool = True, fabric_config: Optional[Dict] = None):
    """
    Initialize the Fabric wrapper.

    Parameters
    ----------
    use_fabric : bool
        Whether to use Lightning Fabric or normal PyTorch operations
    fabric_config : Optional[Dict]
        Configuration for Fabric (devices, strategy, etc.)
    """
    self.use_fabric = use_fabric
    self.fabric_config = fabric_config or {}
    self._fabric: Optional[Any] = None
    self._device: Optional[torch.device] = None

    if self.use_fabric:
        self._init_fabric()
    else:
        self._init_pytorch()

backward(self, loss)

Backward pass

Source code in torchhydro/trainers/fabric_wrapper.py
def backward(self, loss: torch.Tensor) -> None:
    """Backward pass"""
    if self.use_fabric:
        self._fabric.backward(loss)
    else:
        loss.backward()

barrier(self)

Synchronization barrier

Source code in torchhydro/trainers/fabric_wrapper.py
def barrier(self) -> None:
    """Synchronization barrier"""
    if self.use_fabric:
        self._fabric.barrier()
    else:
        pass  # No barrier needed for single process

clip_gradients(self, model, optimizer, max_norm=1.0)

Clip gradients

Source code in torchhydro/trainers/fabric_wrapper.py
def clip_gradients(
    self,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    max_norm: float = 1.0,
) -> None:
    """Clip gradients"""
    if self.use_fabric:
        self._fabric.clip_gradients(model, optimizer, max_norm=max_norm)
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

launch(self, fn=None, *args, **kwargs)

Launch training function

Source code in torchhydro/trainers/fabric_wrapper.py
def launch(self, fn: Optional[Any] = None, *args: Any, **kwargs: Any) -> Any:
    """Launch training function"""
    if self.use_fabric:
        if fn is None:
            # This is called without a function, just launch fabric
            return self._fabric.launch()
        else:
            return self._fabric.launch(fn, *args, **kwargs)
    else:
        # Normal PyTorch, just call the function directly
        if fn is not None:
            return fn(*args, **kwargs)
        else:
            return None

load(self, path, model=None)

Load model state

Source code in torchhydro/trainers/fabric_wrapper.py
def load(self, path: str, model: Optional[torch.nn.Module] = None) -> Any:
    """Load model state"""
    if self.use_fabric:
        return self._fabric.load(path, model)
    else:
        return torch.load(path, map_location=self._device)

load_raw(self, path, model)

Load raw model weights

Source code in torchhydro/trainers/fabric_wrapper.py
def load_raw(self, path: str, model: torch.nn.Module) -> None:
    """Load raw model weights"""
    if self.use_fabric:
        checkpoint = self._fabric.load(path)
        model.load_state_dict(checkpoint)
    else:
        checkpoint = torch.load(path, map_location=self._device)
        model.load_state_dict(checkpoint)

print(self, *args, **kwargs)

Print only on rank 0

Source code in torchhydro/trainers/fabric_wrapper.py
def print(self, *args: Any, **kwargs: Any) -> None:
    """Print only on rank 0"""
    if self.use_fabric:
        self._fabric.print(*args, **kwargs)
    else:
        print(*args, **kwargs)

save(self, path, state_dict)

Save model state

Source code in torchhydro/trainers/fabric_wrapper.py
def save(self, path: str, state_dict: Dict[str, Any]) -> None:
    """Save model state"""
    if self.use_fabric:
        self._fabric.save(path, state_dict)
    else:
        torch.save(state_dict, path)

setup_dataloaders(self, *dataloaders)

Setup dataloaders

Source code in torchhydro/trainers/fabric_wrapper.py
def setup_dataloaders(
    self, *dataloaders: torch.utils.data.DataLoader
) -> Tuple[torch.utils.data.DataLoader, ...]:
    """Setup dataloaders"""
    if self.use_fabric:
        return self._fabric.setup_dataloaders(*dataloaders)
    else:
        return dataloaders

setup_module(self, model)

Setup model for training

Source code in torchhydro/trainers/fabric_wrapper.py
def setup_module(self, model: torch.nn.Module) -> torch.nn.Module:
    """Setup model for training"""
    if self.use_fabric:
        return self._fabric.setup_module(model)
    else:
        return model.to(self._device)

setup_optimizers(self, optimizer)

Setup optimizer

Source code in torchhydro/trainers/fabric_wrapper.py
def setup_optimizers(
    self, optimizer: torch.optim.Optimizer
) -> torch.optim.Optimizer:
    """Setup optimizer"""
    if self.use_fabric:
        return self._fabric.setup_optimizers(optimizer)
    else:
        return optimizer

create_fabric_wrapper(training_cfgs)

Create a fabric wrapper based on training configuration.

Parameters

training_cfgs : Dict Training configuration dictionary

Returns

FabricWrapper Initialized fabric wrapper

Source code in torchhydro/trainers/fabric_wrapper.py
def create_fabric_wrapper(training_cfgs: Dict) -> FabricWrapper:
    """
    Create a fabric wrapper based on training configuration.

    Parameters
    ----------
    training_cfgs : Dict
        Training configuration dictionary

    Returns
    -------
    FabricWrapper
        Initialized fabric wrapper
    """
    # Check if we should use fabric
    fabric_strategy = training_cfgs.get("fabric_strategy")
    use_fabric = fabric_strategy is not None

    # Check if we have multiple devices
    devices = training_cfgs.get("device", [0])
    if isinstance(devices, list) and len(devices) == 1 and use_fabric:
        print("📱 Single device detected - we can disable Fabric")
        use_fabric = False

    # Fabric configuration
    fabric_config = {
        "devices": devices if isinstance(devices, list) else [devices],
        "strategy": fabric_strategy,
        "precision": training_cfgs.get("precision", "32-true"),
        "accelerator": training_cfgs.get("accelerator", "auto"),
    }

    return FabricWrapper(use_fabric=use_fabric, fabric_config=fabric_config)
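
A minimal usage sketch (not taken from the library docs): the config keys mirror the ones read by create_fabric_wrapper above, while the tiny Linear model, the Adam optimizer, and the checkpoint path are purely illustrative.

import os
import tempfile
import torch
from torchhydro.trainers.fabric_wrapper import create_fabric_wrapper

training_cfgs = {
    "fabric_strategy": None,  # None keeps plain single-process PyTorch; e.g. "ddp" would enable Fabric
    "device": [0] if torch.cuda.is_available() else [-1],
    "precision": "32-true",
    "accelerator": "auto",
}
fabric = create_fabric_wrapper(training_cfgs)
model = fabric.setup_module(torch.nn.Linear(8, 1))  # moved to the wrapper's device
optimizer = fabric.setup_optimizers(torch.optim.Adam(model.parameters()))
# ... inside the training loop, after loss.backward():
fabric.clip_gradients(model, optimizer, max_norm=1.0)
fabric.save(os.path.join(tempfile.mkdtemp(), "checkpoint.pth"), model.state_dict())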

resulter

Resulter

Source code in torchhydro/trainers/resulter.py
class Resulter:
    def __init__(self, cfgs) -> None:
        self.cfgs = cfgs
        self.result_dir = cfgs["data_cfgs"]["case_dir"]
        if not os.path.exists(self.result_dir):
            os.makedirs(self.result_dir)

    @property
    def pred_name(self):
        return f"epoch{str(self.chosen_trained_epoch)}flow_pred"

    @property
    def obs_name(self):
        return f"epoch{str(self.chosen_trained_epoch)}flow_obs"

    @property
    def chosen_trained_epoch(self):
        model_loader = self.cfgs["evaluation_cfgs"]["model_loader"]
        if model_loader["load_way"] == "specified":
            epoch_name = str(model_loader["test_epoch"])
        elif model_loader["load_way"] == "best":
            # NOTE: To keep the name consistent with the model_loader["load_way"] == "pth" case, it has to be "best_model.pth"
            epoch_name = "best_model.pth"
        elif model_loader["load_way"] == "latest":
            epoch_name = str(self.cfgs["training_cfgs"]["epochs"])
        elif model_loader["load_way"] == "pth":
            epoch_name = model_loader["pth_path"].split(os.sep)[-1]
        else:
            raise ValueError("Invalid load_way")
        return epoch_name

    def save_cfg(self, cfgs):
        # save the cfgs after training
        # update the cfgs with the latest one
        self.cfgs = cfgs
        param_file_exist = any(
            (
                fnmatch.fnmatch(file, "*.json")
                and "_stat" not in file  # statistics json file
                and "_dict" not in file  # data cache json file
            )
            for file in os.listdir(self.result_dir)
        )
        if not param_file_exist:
            # Although we save a params log during training, sometimes we evaluate a model directly,
            # so we still save the params log here if no param file exists yet.
            save_model_params_log(cfgs, self.result_dir)

    def save_result(self, pred, obs):
        """
        save the pred value of testing period and obs value

        Parameters
        ----------
        pred
            predictions
        obs
            observations

        Returns
        -------
        None
        """
        save_dir = self.result_dir
        flow_pred_file = os.path.join(save_dir, self.pred_name)
        flow_obs_file = os.path.join(save_dir, self.obs_name)
        # encode the basin coordinate as a fixed-width unicode string sized to the longest basin id
        max_len = max(len(basin) for basin in pred.basin.values)
        encoding = {"basin": {"dtype": f"U{max_len}"}}
        pred.to_netcdf(flow_pred_file + ".nc", encoding=encoding)
        obs.to_netcdf(flow_obs_file + ".nc", encoding=encoding)

    def eval_result(self, preds_xr, obss_xr):
        # types of observations
        target_col = self.cfgs["data_cfgs"]["target_cols"]
        evaluation_metrics = self.cfgs["evaluation_cfgs"]["metrics"]
        basin_ids = self.cfgs["data_cfgs"]["object_ids"]
        test_path = self.cfgs["data_cfgs"]["case_dir"]
        # Assume object_ids like ['changdian_61561']
        # fill_nan: "no" means ignoring the NaN value;
        #           "sum" means calculate the sum of the following values in the NaN locations.
        #           For example, observations are [1, nan, nan, 2], and predictions are [0.3, 0.3, 0.3, 1.5].
        #           Then, "no" means [1, 2] v.s. [0.3, 1.5] while "sum" means [1, 2] v.s. [0.3 + 0.3 + 0.3, 1.5].
        #           If it is a str, then all target vars use same fill_nan method;
        #           elif it is a list, each for a var
        fill_nan = self.cfgs["evaluation_cfgs"]["fill_nan"]
        #  Then evaluate the model metrics
        if type(fill_nan) is list and len(fill_nan) != len(target_col):
            raise ValueError("length of fill_nan must be equal to target_col's")
        for i, col in enumerate(target_col):
            eval_log = {}
            obs = obss_xr[col].to_numpy()
            pred = preds_xr[col].to_numpy()

            eval_log = calculate_and_record_metrics(
                obs,
                pred,
                evaluation_metrics,
                col,
                fill_nan[i] if isinstance(fill_nan, list) else fill_nan,
                eval_log,
            )
            # Create pandas DataFrames from eval_log for each target variable (e.g., streamflow)
            # Create a dictionary to hold the data for the DataFrame
            data = {}
            # Iterate over metrics in eval_log
            for metric, values in eval_log.items():
                # Remove 'of streamflow' (or similar) from the metric name
                clean_metric = metric.replace(f"of {col}", "").strip()

                # Add the cleaned metric to the data dictionary
                data[clean_metric] = values

            # Create a DataFrame using object_ids as the index and metrics as columns
            df = pd.DataFrame(data, index=basin_ids)

            # Define the output file name based on the target variable
            output_file = os.path.join(test_path, f"metric_{col}.csv")

            # Save the DataFrame to a CSV file
            df.to_csv(output_file, index_label="basin_id")

        # Finally, try to explain model behaviour using shap
        is_shap = self.cfgs["evaluation_cfgs"]["explainer"] == "shap"
        if is_shap:
            shap_summary_plot(self.model, self.traindataset, self.testdataset)
            # deep_explain_model_summary_plot(self.model, test_data)
            # deep_explain_model_heatmap(self.model, test_data)

    def _convert_streamflow_units(self, ds):
        """convert the streamflow units to m^3/s

        Parameters
        ----------
        ds : xr.Dataset
            dataset whose first target variable is streamflow

        Returns
        -------
        xr.Dataset
            a copy of the dataset with the streamflow variable converted to m^3/s
        """
        data_cfgs = self.cfgs["data_cfgs"]
        source_name = data_cfgs["source_cfgs"]["source_name"]
        source_path = data_cfgs["source_cfgs"]["source_path"]
        other_settings = data_cfgs["source_cfgs"].get("other_settings", {})
        data_source = data_sources_dict[source_name](source_path, **other_settings)
        basin_id = data_cfgs["object_ids"]
        # NOTE: all datasource should have read_area method
        basin_area = data_source.read_area(basin_id)
        target_unit = "m^3/s"
        # Get the flow variable name dynamically from config instead of hardcoding "streamflow"
        # NOTE: the first target variable must be the flow variable
        var_flow = self.cfgs["data_cfgs"]["target_cols"][0]
        streamflow_ds = ds[[var_flow]]
        ds_ = streamflow_unit_conv(
            streamflow_ds, basin_area, target_unit=target_unit, inverse=True
        )
        new_ds = ds.copy(deep=True)
        new_ds[var_flow] = ds_[var_flow]
        return new_ds

    def load_result(self, convert_flow_unit=False) -> Tuple[np.array, np.array]:
        """load the pred value of testing period and obs value"""
        save_dir = self.result_dir
        pred_file = os.path.join(save_dir, self.pred_name + ".nc")
        obs_file = os.path.join(save_dir, self.obs_name + ".nc")
        pred = xr.open_dataset(pred_file)
        obs = xr.open_dataset(obs_file)
        if convert_flow_unit:
            pred = self._convert_streamflow_units(pred)
            obs = self._convert_streamflow_units(obs)
        return pred, obs

    def save_intermediate_results(self, **kwargs):
        """Load model weights and deal with some intermediate results"""
        is_cell_states = kwargs.get("is_cell_states", False)
        is_pbm_params = kwargs.get("is_pbm_params", False)
        cfgs = self.cfgs
        cfgs["training_cfgs"]["train_mode"] = False
        training_cfgs = cfgs["training_cfgs"]
        seq_first = training_cfgs["which_first_tensor"] == "sequence"
        if is_cell_states:
            # TODO: not support return_cell_states yet
            return cellstates_when_inference(seq_first, data_cfgs, pred)
        if is_pbm_params:
            self._save_pbm_params(cfgs, seq_first)

    def _save_pbm_params(self, cfgs, seq_first):
        training_cfgs = cfgs["training_cfgs"]
        model_loader = cfgs["evaluation_cfgs"]["model_loader"]
        model_pth_dir = cfgs["data_cfgs"]["case_dir"]
        weight_path = read_pth_from_model_loader(model_loader, model_pth_dir)
        cfgs["model_cfgs"]["weight_path"] = weight_path
        cfgs["training_cfgs"]["device"] = [0] if torch.cuda.is_available() else [-1]
        deephydro = DeepHydro(cfgs)
        device = deephydro.device
        dl_model = deephydro.model.dl_model
        pb_model = deephydro.model.pb_model
        param_func = deephydro.model.param_func
        # TODO: check for dplnnmodule model
        param_test_way = deephydro.model.param_test_way
        test_dataloader = DataLoader(
            deephydro.testdataset,
            batch_size=training_cfgs["batch_size"],
            shuffle=False,
            sampler=None,
            batch_sampler=None,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
        )
        deephydro.model.eval()
        # here the batch is just an index of lookup table, so any batch size could be chosen
        params_lst = []
        with torch.no_grad():
            for batch in test_dataloader:
                ys, gen = model_infer(seq_first, device, dl_model, batch)
                # we set all params' values in [0, 1] and will scale them when forwarding
                if param_func == "clamp":
                    params_ = torch.clamp(gen, min=0.0, max=1.0)
                elif param_func == "sigmoid":
                    params_ = F.sigmoid(gen)
                else:
                    raise NotImplementedError(
                        "We don't provide this way to limit parameters' range!! Please choose sigmoid or clamp"
                    )
                # just get one-period values, here we use the final period's values
                params = params_[:, -1, :]
                params_lst.append(params)
        pb_params = reduce(lambda a, b: torch.cat((a, b), dim=0), params_lst)
        # trans tensor to pandas dataframe
        sites = deephydro.cfgs["data_cfgs"]["object_ids"]
        params_names = pb_model.params_names
        params_df = pd.DataFrame(
            pb_params.cpu().numpy(), columns=params_names, index=sites
        )
        save_param_file = os.path.join(
            model_pth_dir, f"pb_params_{int(time.time())}.csv"
        )
        params_df.to_csv(save_param_file, index_label="GAGE_ID")

    def read_tensorboard_log(self, **kwargs):
        """read tensorboard log files"""
        is_scalar = kwargs.get("is_scalar", False)
        is_histogram = kwargs.get("is_histogram", False)
        log_dir = self.cfgs["data_cfgs"]["case_dir"]
        if is_scalar:
            scalar_file = os.path.join(log_dir, "tb_scalars.csv")
            if not os.path.exists(scalar_file):
                reader = SummaryReader(log_dir)
                df_scalar = reader.scalars
                df_scalar.to_csv(scalar_file, index=False)
            else:
                df_scalar = pd.read_csv(scalar_file)
        if is_histogram:
            histogram_file = os.path.join(log_dir, "tb_histograms.csv")
            if not os.path.exists(histogram_file):
                reader = SummaryReader(log_dir)
                df_histogram = reader.histograms
                df_histogram.to_csv(histogram_file, index=False)
            else:
                df_histogram = pd.read_csv(histogram_file)
        if is_scalar and is_histogram:
            return df_scalar, df_histogram
        elif is_scalar:
            return df_scalar
        elif is_histogram:
            return df_histogram

    # TODO: the following code is not finished yet
    def load_ensemble_result(
        self, save_dirs, test_epoch, flow_unit="m3/s", basin_areas=None
    ) -> Tuple[np.array, np.array]:
        """
        load ensemble mean value

        Parameters
        ----------
        save_dirs
        test_epoch
        flow_unit
            default is m3/s, if it is not m3/s, transform the results
        basin_areas
            if unit is mm/day it will be used, default is None

        Returns
        -------

        """
        preds = []
        obss = []
        for save_dir in save_dirs:
            pred_i, obs_i = self.load_result(save_dir, test_epoch)
            if pred_i.ndim == 3 and pred_i.shape[-1] == 1:
                pred_i = pred_i.reshape(pred_i.shape[0], pred_i.shape[1])
                obs_i = obs_i.reshape(obs_i.shape[0], obs_i.shape[1])
            preds.append(pred_i)
            obss.append(obs_i)
        preds_np = np.array(preds)
        obss_np = np.array(obss)
        pred_mean = np.mean(preds_np, axis=0)
        obs_mean = np.mean(obss_np, axis=0)
        if flow_unit == "mm/day":
            if basin_areas is None:
                raise ArithmeticError("No basin areas we cannot calculate")
            basin_areas = np.repeat(basin_areas, obs_mean.shape[1], axis=0).reshape(
                obs_mean.shape
            )
            obs_mean = obs_mean * basin_areas * 1e-3 * 1e6 / 86400
            pred_mean = pred_mean * basin_areas * 1e-3 * 1e6 / 86400
        elif flow_unit == "m3/s":
            pass
        elif flow_unit == "ft3/s":
            obs_mean = obs_mean / 35.314666721489
            pred_mean = pred_mean / 35.314666721489
        return pred_mean, obs_mean

    def eval_ensemble_result(
        self,
        save_dirs,
        test_epoch,
        return_value=False,
        flow_unit="m3/s",
        basin_areas=None,
    ) -> Tuple[np.array, np.array]:
        """calculate statistics for ensemble results

        Parameters
        ----------
        save_dirs : _type_
            where the results save
        test_epoch : _type_
            we name the results files with the test_epoch
        return_value : bool, optional
            if True, return (inds_df, pred_mean, obs_mean), by default False
        flow_unit : str, optional
            arg for load_ensemble_result, by default "m3/s"
        basin_areas : _type_, optional
            arg for load_ensemble_result, by default None

        Returns
        -------
        Tuple[np.array, np.array]
            inds_df or (inds_df, pred_mean, obs_mean)
        """
        pred_mean, obs_mean = self.load_ensemble_result(
            save_dirs, test_epoch, flow_unit=flow_unit, basin_areas=basin_areas
        )
        inds = stat_error(obs_mean, pred_mean)
        inds_df = pd.DataFrame(inds)
        return (inds_df, pred_mean, obs_mean) if return_value else inds_df
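
A minimal sketch of the typical call order (assumed usage, not from the library docs); only the keys touched by __init__ and chosen_trained_epoch are filled in, and the temporary case_dir is illustrative.

import tempfile
from torchhydro.trainers.resulter import Resulter

cfgs = {
    "data_cfgs": {"case_dir": tempfile.mkdtemp()},
    "evaluation_cfgs": {"model_loader": {"load_way": "specified", "test_epoch": 50}},
}
resulter = Resulter(cfgs)
print(resulter.pred_name)  # -> "epoch50flow_pred"
# after evaluation, pred/obs are xarray Datasets with a "basin" coordinate:
# resulter.save_result(pred, obs)
# pred, obs = resulter.load_result(convert_flow_unit=False)
# resulter.eval_result(pred, obs)  # needs the full data_cfgs/evaluation_cfgs described above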

eval_ensemble_result(self, save_dirs, test_epoch, return_value=False, flow_unit='m3/s', basin_areas=None)

calculate statistics for ensemble results

Parameters

save_dirs
    where the results are saved
test_epoch
    the result files are named with the test_epoch
return_value : bool, optional
    if True, return (inds_df, pred_mean, obs_mean), by default False
flow_unit : str, optional
    arg for load_ensemble_result, by default "m3/s"
basin_areas : optional
    arg for load_ensemble_result, by default None

Returns

Tuple[np.array, np.array] inds_df or (inds_df, pred_mean, obs_mean)

Source code in torchhydro/trainers/resulter.py
def eval_ensemble_result(
    self,
    save_dirs,
    test_epoch,
    return_value=False,
    flow_unit="m3/s",
    basin_areas=None,
) -> Tuple[np.array, np.array]:
    """calculate statistics for ensemble results

    Parameters
    ----------
    save_dirs : _type_
        where the results save
    test_epoch : _type_
        we name the results files with the test_epoch
    return_value : bool, optional
        if True, return (inds_df, pred_mean, obs_mean), by default False
    flow_unit : str, optional
        arg for load_ensemble_result, by default "m3/s"
    basin_areas : _type_, optional
        arg for load_ensemble_result, by default None

    Returns
    -------
    Tuple[np.array, np.array]
        inds_df or (inds_df, pred_mean, obs_mean)
    """
    pred_mean, obs_mean = self.load_ensemble_result(
        save_dirs, test_epoch, flow_unit=flow_unit, basin_areas=basin_areas
    )
    inds = stat_error(obs_mean, pred_mean)
    inds_df = pd.DataFrame(inds)
    return (inds_df, pred_mean, obs_mean) if return_value else inds_df

load_ensemble_result(self, save_dirs, test_epoch, flow_unit='m3/s', basin_areas=None)

load ensemble mean value

Parameters

save_dirs
test_epoch
flow_unit
    default is m3/s; if it is not m3/s, transform the results
basin_areas
    used when the unit is mm/day, by default None

Returns
Source code in torchhydro/trainers/resulter.py
def load_ensemble_result(
    self, save_dirs, test_epoch, flow_unit="m3/s", basin_areas=None
) -> Tuple[np.array, np.array]:
    """
    load ensemble mean value

    Parameters
    ----------
    save_dirs
    test_epoch
    flow_unit
        default is m3/s, if it is not m3/s, transform the results
    basin_areas
        if unit is mm/day it will be used, default is None

    Returns
    -------

    """
    preds = []
    obss = []
    for save_dir in save_dirs:
        pred_i, obs_i = self.load_result(save_dir, test_epoch)
        if pred_i.ndim == 3 and pred_i.shape[-1] == 1:
            pred_i = pred_i.reshape(pred_i.shape[0], pred_i.shape[1])
            obs_i = obs_i.reshape(obs_i.shape[0], obs_i.shape[1])
        preds.append(pred_i)
        obss.append(obs_i)
    preds_np = np.array(preds)
    obss_np = np.array(obss)
    pred_mean = np.mean(preds_np, axis=0)
    obs_mean = np.mean(obss_np, axis=0)
    if flow_unit == "mm/day":
        if basin_areas is None:
            raise ArithmeticError("No basin areas we cannot calculate")
        basin_areas = np.repeat(basin_areas, obs_mean.shape[1], axis=0).reshape(
            obs_mean.shape
        )
        obs_mean = obs_mean * basin_areas * 1e-3 * 1e6 / 86400
        pred_mean = pred_mean * basin_areas * 1e-3 * 1e6 / 86400
    elif flow_unit == "m3/s":
        pass
    elif flow_unit == "ft3/s":
        obs_mean = obs_mean / 35.314666721489
        pred_mean = pred_mean / 35.314666721489
    return pred_mean, obs_mean
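
As a quick sanity check on the mm/day branch above (assuming basin_areas are given in km², which the 1e6 factor implies): 10 mm/day over a 1000 km² basin gives 10 * 1000 * 1e-3 * 1e6 / 86400 ≈ 115.7 m³/s, i.e. depth in metres times area in m² divided by the seconds in a day. The ft3/s branch simply divides by 35.314666721489, the number of cubic feet in one cubic metre.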

load_result(self, convert_flow_unit=False)

load the pred value of testing period and obs value

Source code in torchhydro/trainers/resulter.py
def load_result(self, convert_flow_unit=False) -> Tuple[np.array, np.array]:
    """load the pred value of testing period and obs value"""
    save_dir = self.result_dir
    pred_file = os.path.join(save_dir, self.pred_name + ".nc")
    obs_file = os.path.join(save_dir, self.obs_name + ".nc")
    pred = xr.open_dataset(pred_file)
    obs = xr.open_dataset(obs_file)
    if convert_flow_unit:
        pred = self._convert_streamflow_units(pred)
        obs = self._convert_streamflow_units(obs)
    return pred, obs

read_tensorboard_log(self, **kwargs)

read tensorboard log files

Source code in torchhydro/trainers/resulter.py
def read_tensorboard_log(self, **kwargs):
    """read tensorboard log files"""
    is_scalar = kwargs.get("is_scalar", False)
    is_histogram = kwargs.get("is_histogram", False)
    log_dir = self.cfgs["data_cfgs"]["case_dir"]
    if is_scalar:
        scalar_file = os.path.join(log_dir, "tb_scalars.csv")
        if not os.path.exists(scalar_file):
            reader = SummaryReader(log_dir)
            df_scalar = reader.scalars
            df_scalar.to_csv(scalar_file, index=False)
        else:
            df_scalar = pd.read_csv(scalar_file)
    if is_histogram:
        histogram_file = os.path.join(log_dir, "tb_histograms.csv")
        if not os.path.exists(histogram_file):
            reader = SummaryReader(log_dir)
            df_histogram = reader.histograms
            df_histogram.to_csv(histogram_file, index=False)
        else:
            df_histogram = pd.read_csv(histogram_file)
    if is_scalar and is_histogram:
        return df_scalar, df_histogram
    elif is_scalar:
        return df_scalar
    elif is_histogram:
        return df_histogram

save_intermediate_results(self, **kwargs)

Load model weights and deal with some intermediate results

Source code in torchhydro/trainers/resulter.py
def save_intermediate_results(self, **kwargs):
    """Load model weights and deal with some intermediate results"""
    is_cell_states = kwargs.get("is_cell_states", False)
    is_pbm_params = kwargs.get("is_pbm_params", False)
    cfgs = self.cfgs
    cfgs["training_cfgs"]["train_mode"] = False
    training_cfgs = cfgs["training_cfgs"]
    seq_first = training_cfgs["which_first_tensor"] == "sequence"
    if is_cell_states:
        # TODO: not support return_cell_states yet
        return cellstates_when_inference(seq_first, data_cfgs, pred)
    if is_pbm_params:
        self._save_pbm_params(cfgs, seq_first)

save_result(self, pred, obs)

save the pred value of testing period and obs value

Parameters

pred
    predictions
obs
    observations

Returns

None

Source code in torchhydro/trainers/resulter.py
def save_result(self, pred, obs):
    """
    save the pred value of testing period and obs value

    Parameters
    ----------
    pred
        predictions
    obs
        observations

    Returns
    -------
    None
    """
    save_dir = self.result_dir
    flow_pred_file = os.path.join(save_dir, self.pred_name)
    flow_obs_file = os.path.join(save_dir, self.obs_name)
    # encode the basin coordinate as a fixed-width unicode string sized to the longest basin id
    max_len = max(len(basin) for basin in pred.basin.values)
    encoding = {"basin": {"dtype": f"U{max_len}"}}
    pred.to_netcdf(flow_pred_file + ".nc", encoding=encoding)
    obs.to_netcdf(flow_obs_file + ".nc", encoding=encoding)

train_logger

Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 LastEditTime: 2025-11-07 08:36:50 LastEditors: Wenyu Ouyang Description: Training function for DL models FilePath: \torchhydro\torchhydro\trainers\train_logger.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

TrainLogger

Source code in torchhydro/trainers/train_logger.py
class TrainLogger:
    def __init__(self, model_filepath, params, opt):
        self.training_cfgs = params["training_cfgs"]
        self.data_cfgs = params["data_cfgs"]
        self.evaluation_cfgs = params["evaluation_cfgs"]
        self.model_cfgs = params["model_cfgs"]
        self.opt = opt
        self.training_save_dir = model_filepath
        self.tb = SummaryWriter(self.training_save_dir)
        self.session_params = []
        self.train_time = []
        # log loss for each epoch
        self.epoch_loss = []
        # reload previous logs if continue_train is True and weight_path is not None
        if (
            self.model_cfgs["continue_train"]
            and self.model_cfgs["weight_path"] is not None
        ):
            the_logger_file = get_lastest_logger_file_in_a_dir(self.training_save_dir)
            if the_logger_file is not None:
                with open(the_logger_file, "r") as f:
                    logs = json.load(f)
            start_epoch = self.training_cfgs["start_epoch"]
            # read the logs before start_epoch and load them to session_params, train_time, epoch_loss
            for log in logs["run"]:
                if log["epoch"] < start_epoch:
                    self.session_params.append(log)
                    self.train_time.append(log["train_time"])
                    self.epoch_loss.append(float(log["train_loss"]))

    def save_session_param(
        self, epoch, total_loss, n_iter_ep, valid_loss=None, valid_metrics=None
    ):
        if valid_metrics is None:
            if valid_loss is None:
                epoch_params = {
                    "epoch": epoch,
                    "train_loss": str(total_loss),
                    "iter_num": n_iter_ep,
                }
            else:
                epoch_params = {
                    "epoch": epoch,
                    "train_loss": str(total_loss),
                    "validation_loss": str(valid_loss),
                    "iter_num": n_iter_ep,
                }
        else:
            epoch_params = {
                "epoch": epoch,
                "train_loss": str(total_loss),
                "validation_loss": str(valid_loss),
                "validation_metric": valid_metrics,
                "iter_num": n_iter_ep,
            }
        epoch_params["train_time"] = self.train_time[epoch - 1]
        self.session_params.append(epoch_params)

    @contextmanager
    def log_epoch_train(self, epoch):
        start_time = time.time()
        logs = {}
        # the content of the 'with' block is executed after this yield
        yield logs
        total_loss = logs["train_loss"]
        elapsed_time = time.time() - start_time
        lr = self.opt.param_groups[0]["lr"]
        log_str = "Epoch {} Loss {:.4f} time {:.2f} lr {}".format(
            epoch, total_loss, elapsed_time, lr
        )
        print(log_str)
        model = logs["model"]
        print(model)
        self.tb.add_scalar("Loss", total_loss, epoch)
        # self.plot_hist_img(model, epoch)
        self.train_time.append(log_str)
        self.epoch_loss.append(total_loss)

    @contextmanager
    def log_epoch_valid(self, epoch):
        logs = {}
        yield logs
        valid_loss = logs["valid_loss"]
        if (
            self.training_cfgs["valid_batch_mode"] == "test"
            and self.training_cfgs["calc_metrics"]
        ):
            # NOTE: Now we only evaluate the metrics for test-mode validation
            valid_metrics = logs["valid_metrics"]
            val_log = "Epoch {} Valid Loss {:.4f} Valid Metric {}".format(
                epoch, valid_loss, valid_metrics
            )
            print(val_log)
            self.tb.add_scalar("ValidLoss", valid_loss, epoch)
            target_col = self.data_cfgs["target_cols"]
            evaluation_metrics = self.evaluation_cfgs["metrics"]
            for i in range(len(target_col)):
                for evaluation_metric in evaluation_metrics:
                    self.tb.add_scalar(
                        f"Valid{target_col[i]}{evaluation_metric}mean",
                        np.nanmean(
                            valid_metrics[f"{evaluation_metric} of {target_col[i]}"]
                        ),
                        epoch,
                    )
                    self.tb.add_scalar(
                        f"Valid{target_col[i]}{evaluation_metric}median",
                        np.nanmedian(
                            valid_metrics[f"{evaluation_metric} of {target_col[i]}"]
                        ),
                        epoch,
                    )
        else:
            val_log = "Epoch {} Valid Loss {:.4f} ".format(epoch, valid_loss)
            print(val_log)
            self.tb.add_scalar("ValidLoss", valid_loss, epoch)

    def save_model_and_params(self, model, epoch, params):
        final_epoch = params["training_cfgs"]["epochs"]
        save_epoch = params["training_cfgs"]["save_epoch"]
        if (save_epoch is None or save_epoch == 0) and epoch != final_epoch:
            return
        if (save_epoch and epoch % save_epoch == 0) or epoch == final_epoch:
            # save for save_epoch
            model_file = os.path.join(
                self.training_save_dir, f"model_Ep{str(epoch)}.pth"
            )
            save_model(model, model_file)
        if epoch == final_epoch:
            self._save_final_epoch(params, model)

    def _save_final_epoch(self, params, model):
        # In final epoch, we save the model and params in case_dir
        final_path = params["data_cfgs"]["case_dir"]
        params["run"] = self.session_params
        time_stamp = datetime.now().strftime("%d_%B_%Y%I_%M%p")
        model_save_path = os.path.join(final_path, f"{time_stamp}_model.pth")
        save_model(model, model_save_path)
        save_model_params_log(params, final_path)
        # also save one for a training directory for one hyperparameter setting
        save_model_params_log(params, self.training_save_dir)

    def plot_hist_img(self, model, global_step):
        for tag, parm in model.named_parameters():
            self.tb.add_histogram(
                f"{tag}_hist", parm.detach().cpu().numpy(), global_step
            )
            if len(parm.shape) == 2:
                img_format = "HW"
                if parm.shape[0] > parm.shape[1]:
                    img_format = "WH"
                    self.tb.add_image(
                        f"{tag}_img",
                        parm.detach().cpu().numpy(),
                        global_step,
                        dataformats=img_format,
                    )

    def plot_model_structure(self, model):
        """plot model structure in tensorboard

        TODO: This function is not working as expected. It should be rewritten.

        Parameters
        ----------
        model :
            torch model
        """
        # input4modelplot = torch.randn(
        #     self.training_cfgs["batch_size"],
        #     self.training_cfgs["hindcast_length"],
        #     # self.model_cfgs["model_hyperparam"]["n_input_features"],
        #     self.model_cfgs["model_hyperparam"]["input_size"],
        # )
        if self.data_cfgs["model_mode"] == "single":
            input4modelplot = [
                torch.randn(
                    self.training_cfgs["batch_size"],
                    self.training_cfgs["hindcast_length"],
                    self.data_cfgs["input_features"] - 1,
                ),
                torch.randn(
                    self.training_cfgs["batch_size"],
                    self.training_cfgs["hindcast_length"],
                    self.data_cfgs["cnn_size"],
                ),
                torch.rand(
                    self.training_cfgs["batch_size"],
                    1,
                    self.data_cfgs["output_features"],
                ),
            ]
        else:
            input4modelplot = [
                torch.randn(
                    self.training_cfgs["batch_size"],
                    self.training_cfgs["hindcast_length"],
                    self.data_cfgs["input_features"],
                ),
                torch.randn(
                    self.training_cfgs["batch_size"],
                    self.training_cfgs["hindcast_length"],
                    self.data_cfgs["input_size_encoder2"],
                ),
                torch.rand(
                    self.training_cfgs["batch_size"],
                    1,
                    self.data_cfgs["output_features"],
                ),
            ]
        self.tb.add_graph(model, input4modelplot)
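
A minimal sketch of how the log_epoch_train context manager is meant to be driven (assumed usage, not from the library docs); the Linear model, optimizer, and empty config sections are illustrative, and only the model_cfgs keys read by __init__ are set.

import tempfile
import torch
from torchhydro.trainers.train_logger import TrainLogger

model = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
params = {
    "training_cfgs": {},
    "data_cfgs": {},
    "evaluation_cfgs": {},
    "model_cfgs": {"continue_train": False, "weight_path": None},
}
logger = TrainLogger(tempfile.mkdtemp(), params, opt)
with logger.log_epoch_train(epoch=1) as logs:
    # ... run one training epoch here ...
    logs["train_loss"] = 0.123  # read back by the context manager after yield
    logs["model"] = model       # also read back, for printing/plotting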

plot_model_structure(self, model)

plot model structure in tensorboard

TODO: This function is not working as expected. It should be rewritten.

Parameters

model : torch model

Source code in torchhydro/trainers/train_logger.py
def plot_model_structure(self, model):
    """plot model structure in tensorboard

    TODO: This function is not working as expected. It should be rewritten.

    Parameters
    ----------
    model :
        torch model
    """
    # input4modelplot = torch.randn(
    #     self.training_cfgs["batch_size"],
    #     self.training_cfgs["hindcast_length"],
    #     # self.model_cfgs["model_hyperparam"]["n_input_features"],
    #     self.model_cfgs["model_hyperparam"]["input_size"],
    # )
    if self.data_cfgs["model_mode"] == "single":
        input4modelplot = [
            torch.randn(
                self.training_cfgs["batch_size"],
                self.training_cfgs["hindcast_length"],
                self.data_cfgs["input_features"] - 1,
            ),
            torch.randn(
                self.training_cfgs["batch_size"],
                self.training_cfgs["hindcast_length"],
                self.data_cfgs["cnn_size"],
            ),
            torch.rand(
                self.training_cfgs["batch_size"],
                1,
                self.data_cfgs["output_features"],
            ),
        ]
    else:
        input4modelplot = [
            torch.randn(
                self.training_cfgs["batch_size"],
                self.training_cfgs["hindcast_length"],
                self.data_cfgs["input_features"],
            ),
            torch.randn(
                self.training_cfgs["batch_size"],
                self.training_cfgs["hindcast_length"],
                self.data_cfgs["input_size_encoder2"],
            ),
            torch.rand(
                self.training_cfgs["batch_size"],
                1,
                self.data_cfgs["output_features"],
            ),
        ]
    self.tb.add_graph(model, input4modelplot)

train_utils

Author: Wenyu Ouyang Date: 2024-04-08 18:16:26 LastEditTime: 2025-11-06 13:48:48 LastEditors: Wenyu Ouyang Description: Some basic functions for training FilePath: \torchhydro\torchhydro\trainers\train_utils.py Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.

EarlyStopper

Source code in torchhydro/trainers/train_utils.py
class EarlyStopper(object):
    def __init__(
        self,
        patience: int,
        min_delta: float = 0.0,
        cumulative_delta: bool = False,
    ):
        """
        EarlyStopping handler can be used to stop the training if there is no improvement after a given number of events.

        Parameters
        ----------
        patience
            Number of events to wait if no improvement and then stop the training.
        min_delta
            A minimum increase in the score to qualify as an improvement,
            i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
        cumulative_delta
            If True, `min_delta` defines an increase since the last `patience` reset; otherwise,
            it defines an increase after the last event. Default value is False.
        """

        if patience < 1:
            raise ValueError("Argument patience should be positive integer.")

        if min_delta < 0.0:
            raise ValueError("Argument min_delta should not be a negative number.")

        self.patience = patience
        self.min_delta = min_delta
        self.cumulative_delta = cumulative_delta
        self.counter = 0
        self.best_score = None

    def check_loss(self, model, validation_loss, save_dir) -> bool:
        score = validation_loss
        if self.best_score is None:
            self.save_model_checkpoint(model, save_dir)
            self.best_score = score

        elif score + self.min_delta >= self.best_score:
            self.counter += 1
            print("Epochs without Model Update:", self.counter)
            if self.counter >= self.patience:
                return False
        else:
            self.save_model_checkpoint(model, save_dir)
            print("Model Update")
            self.best_score = score
            self.counter = 0
        return True

    def save_model_checkpoint(self, model, save_dir):
        torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
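
A minimal sketch of EarlyStopper in a validation loop (assumed usage; the loss values, Linear model, and temporary save_dir are illustrative). With patience=3 the loop stops after three consecutive epochs without improvement, and the best weights remain in best_model.pth.

import tempfile
import torch
from torchhydro.trainers.train_utils import EarlyStopper

model = torch.nn.Linear(4, 1)
save_dir = tempfile.mkdtemp()
stopper = EarlyStopper(patience=3)
for epoch, valid_loss in enumerate([0.9, 0.8, 0.81, 0.82, 0.83], start=1):
    if not stopper.check_loss(model, valid_loss, save_dir):
        print(f"early stop at epoch {epoch}")
        break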

__init__(self, patience, min_delta=0.0, cumulative_delta=False) special

EarlyStopping handler can be used to stop the training if there is no improvement after a given number of events.

Parameters

patience
    Number of events to wait if no improvement and then stop the training.
min_delta
    A minimum increase in the score to qualify as an improvement, i.e. an increase of less than or equal to min_delta will count as no improvement.
cumulative_delta
    If True, min_delta defines an increase since the last patience reset; otherwise, it defines an increase after the last event. Default value is False.

Source code in torchhydro/trainers/train_utils.py
def __init__(
    self,
    patience: int,
    min_delta: float = 0.0,
    cumulative_delta: bool = False,
):
    """
    EarlyStopping handler can be used to stop the training if there is no improvement after a given number of events.

    Parameters
    ----------
    patience
        Number of events to wait if no improvement and then stop the training.
    min_delta
        A minimum increase in the score to qualify as an improvement,
        i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
    cumulative_delta
        If True, `min_delta` defines an increase since the last `patience` reset; otherwise,
        it defines an increase after the last event. Default value is False.
    """

    if patience < 1:
        raise ValueError("Argument patience should be positive integer.")

    if min_delta < 0.0:
        raise ValueError("Argument min_delta should not be a negative number.")

    self.patience = patience
    self.min_delta = min_delta
    self.cumulative_delta = cumulative_delta
    self.counter = 0
    self.best_score = None

average_weights(w)

Returns the average of the weights.

Source code in torchhydro/trainers/train_utils.py
def average_weights(w):
    """
    Returns the average of the weights.
    """
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
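
A small illustrative example (the two Linear layers are only placeholders for state dicts with matching shapes):

import torch
from torchhydro.trainers.train_utils import average_weights

m1, m2 = torch.nn.Linear(2, 1), torch.nn.Linear(2, 1)
w_avg = average_weights([m1.state_dict(), m2.state_dict()])
print(w_avg["weight"])  # element-wise mean of the two weight tensors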

cellstates_when_inference(seq_first, data_cfgs, pred)

get cell states when inference

Source code in torchhydro/trainers/train_utils.py
def cellstates_when_inference(seq_first, data_cfgs, pred):
    """get cell states when inference"""
    cs_out = (
        cs_cat_lst.detach().cpu().numpy().swapaxes(0, 1)
        if seq_first
        else cs_cat_lst.detach().cpu().numpy()
    )
    cs_out_lst = [cs_out]
    cell_state = reduce(lambda a, b: np.vstack((a, b)), cs_out_lst)
    np.save(os.path.join(data_cfgs["case_dir"], "cell_states.npy"), cell_state)
    # model.zero_grad()
    torch.cuda.empty_cache()
    return pred, cell_state

compute_loss(labels, output, criterion, **kwargs)

Function for computing the loss

Parameters

labels
    The real values for the target. Shape can be variable but should follow (batch_size, time)
output
    The output of the model
criterion
    loss function
validation_dataset
    Only passed when unscaling of data is needed.
m
    defaults to 1

Returns

torch.Tensor the computed loss

Source code in torchhydro/trainers/train_utils.py
def compute_loss(
    labels: torch.Tensor, output: torch.Tensor, criterion, **kwargs
) -> torch.Tensor:
    """
    Function for computing the loss

    Parameters
    ----------
    labels
        The real values for the target. Shape can be variable but should follow (batch_size, time)
    output
        The output of the model
    criterion
        loss function
    validation_dataset
        Only passed when unscaling of data is needed.
    m
        defaults to 1

    Returns
    -------
    torch.Tensor
        the computed loss
    """
    # a = np.sum(output.cpu().detach().numpy(),axis=1)/len(output)
    # b=[]
    # for i in a:
    #     b.append([i.tolist()])
    # output = torch.tensor(b, requires_grad=True).to(torch.device("cuda"))

    if isinstance(criterion, GaussianLoss):
        if len(output[0].shape) > 2:
            g_loss = GaussianLoss(output[0][:, :, 0], output[1][:, :, 0])
        else:
            g_loss = GaussianLoss(output[0][:, 0], output[1][:, 0])
        return g_loss(labels)
    if isinstance(criterion, FloodBaseLoss):
        # labels has one more column than output, which is the flood mask
        # so we need to remove the last column of labels to get targets
        flood_mask = labels[:, :, -1:]  # Extract flood mask from last column
        targets = labels[:, :, :-1]  # Extract targets (remove last column)
        return criterion(output, targets, flood_mask)
    if (
        isinstance(output, torch.Tensor)
        and len(labels.shape) != len(output.shape)
        and len(labels.shape) > 1
    ):
        if labels.shape[1] == output.shape[1]:
            labels = labels.unsqueeze(2)
        else:
            labels = labels.unsqueeze(0)
    assert labels.shape == output.shape
    return criterion(output, labels.float())
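
A small illustrative call (the criterion and random tensors below are only placeholders). Because labels here are (batch, time) while the output is (batch, time, 1), compute_loss unsqueezes the labels before applying the criterion.

import torch
from torchhydro.trainers.train_utils import compute_loss

criterion = torch.nn.MSELoss()
output = torch.randn(8, 30, 1)  # (batch, time, 1) model output
labels = torch.randn(8, 30)     # (batch, time) targets
loss = compute_loss(labels, output, criterion)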

compute_validation(model, criterion, data_loader, device=None, **kwargs)

Function to compute the validation loss metrics

Parameters

model
    the trained model
criterion
    torch.nn.modules.loss
data_loader
    The data-loader of either validation or test data
device
    torch.device

Returns

tuple validation observations (numpy array), predictions (numpy array) and the loss of validation

Source code in torchhydro/trainers/train_utils.py
def compute_validation(
    model,
    criterion,
    data_loader: DataLoader,
    device: torch.device = None,
    **kwargs,
):
    """
    Function to compute the validation loss metrics

    Parameters
    ----------
    model
        the trained model
    criterion
        torch.nn.modules.loss
    dataloader
        The data-loader of either validation or test-data
    device
        torch.device

    Returns
    -------
    tuple
        validation observations (numpy array), predictions (numpy array) and the loss of validation
    """
    model.eval()
    seq_first = kwargs["which_first_tensor"] != "batch"
    obs = []
    preds = []
    valid_loss = 0.0
    obs_final = None
    pred_final = None
    with torch.no_grad():
        iter_num = 0
        for batch in tqdm(data_loader, desc="Evaluating", total=len(data_loader)):
            trg, output = model_infer(seq_first, device, model, batch)
            obs.append(trg)
            preds.append(output)
            valid_loss_ = compute_loss(trg, output, criterion)
            if torch.isnan(valid_loss_):
                # in not-train mode, trg may be all NaN, so we skip such a batch
                print("NaN loss detected, skipping this batch")
                continue
            valid_loss = valid_loss + valid_loss_.item()
            iter_num = iter_num + 1

            # For flood datasets, remove the flood_mask column from observations
            # to match the prediction dimensions for evaluation
            trg_for_eval = (
                trg[:, :, :-1] if isinstance(criterion, FloodBaseLoss) else trg
            )
            # clear memory to save GPU memory
            if obs_final is None:
                obs_final = trg_for_eval.detach().cpu()
                pred_final = output.detach().cpu()
            else:
                obs_final = torch.cat([obs_final, trg_for_eval.detach().cpu()], dim=0)
                pred_final = torch.cat([pred_final, output.detach().cpu()], dim=0)
            del trg, output
            torch.cuda.empty_cache()
    valid_loss = valid_loss / iter_num
    y_obs = obs_final.numpy()
    y_pred = pred_final.numpy()
    return y_obs, y_pred, valid_loss

evaluate_validation(validation_data_loader, output, labels, evaluation_cfgs, target_col)

calculate metrics for validation

Parameters

output
    model output
labels
    model target
evaluation_cfgs
    evaluation configs
target_col
    target columns

Returns

tuple metrics

Source code in torchhydro/trainers/train_utils.py
def evaluate_validation(
    validation_data_loader,
    output,
    labels,
    evaluation_cfgs,
    target_col,
):
    """
    calculate metrics for validation

    Parameters
    ----------
    output
        model output
    labels
        model target
    evaluation_cfgs
        evaluation configs
    target_col
        target columns

    Returns
    -------
    tuple
        metrics
    """
    fill_nan = evaluation_cfgs["fill_nan"]
    if isinstance(fill_nan, list) and len(fill_nan) != len(target_col):
        raise ValueError("Length of fill_nan must be equal to length of target_col.")
    eval_log = {}
    evaluation_metrics = evaluation_cfgs["metrics"]
    obss_xr, preds_xr = get_preds_to_be_eval(
        validation_data_loader,
        evaluation_cfgs,
        output,
        labels,
    )
    # when obss_xr is a list, each element corresponds to one forecast horizon
    if isinstance(obss_xr, list):
        obss_xr_list = obss_xr
        preds_xr_list = preds_xr
        for horizon_idx in range(len(obss_xr_list)):
            obss_xr = obss_xr_list[horizon_idx]
            preds_xr = preds_xr_list[horizon_idx]
            for i, col in enumerate(target_col):
                obs = obss_xr[col].to_numpy()
                pred = preds_xr[col].to_numpy()
                # eval_log will be updated rather than completely replaced, no need to use eval_log["key"]
                eval_log = calculate_and_record_metrics(
                    obs,
                    pred,
                    evaluation_metrics,
                    col,
                    fill_nan[i] if isinstance(fill_nan, list) else fill_nan,
                    eval_log,
                    horizon_idx + 1,
                )
        return eval_log

    for i, col in enumerate(target_col):
        obs = obss_xr[col].to_numpy()
        pred = preds_xr[col].to_numpy()
        # eval_log will be updated rather than completely replaced, no need to use eval_log["key"]
        eval_log = calculate_and_record_metrics(
            obs,
            pred,
            evaluation_metrics,
            col,
            fill_nan[i] if isinstance(fill_nan, list) else fill_nan,
            eval_log,
        )
    return eval_log

get_lastest_logger_file_in_a_dir(dir_path)

Get the latest logger file in a directory

Parameters

dir_path : str the directory

Returns

str the path of the logger file

Source code in torchhydro/trainers/train_utils.py
def get_lastest_logger_file_in_a_dir(dir_path):
    """Get the last logger file in a directory

    Parameters
    ----------
    dir_path : str
        the directory

    Returns
    -------
    str
        the path of the logger file
    """
    pattern = r"^\d{1,2}_[A-Za-z]+_\d{6}_\d{2}(AM|PM)\.json$"
    pth_files_lst = [
        os.path.join(dir_path, file)
        for file in os.listdir(dir_path)
        if re.match(pattern, file)
    ]
    return get_latest_file_in_a_lst(pth_files_lst)
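
For reference, the pattern appears to match file names built from the same "%d_%B_%Y%I_%M%p"-style timestamp used by TrainLogger._save_final_epoch, plus the .json suffix; the name below is only an illustrative example:

import re

pattern = r"^\d{1,2}_[A-Za-z]+_\d{6}_\d{2}(AM|PM)\.json$"
print(bool(re.match(pattern, "13_July_202504_30PM.json")))  # True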

get_latest_pbm_param_file(param_dir)

Get the latest parameter file of physics-based models in the current directory.

Parameters

param_dir : str The directory of parameter files.

Returns

str The latest parameter file.

Source code in torchhydro/trainers/train_utils.py
def get_latest_pbm_param_file(param_dir):
    """Get the latest parameter file of physics-based models in the current directory.

    Parameters
    ----------
    param_dir : str
        The directory of parameter files.

    Returns
    -------
    str
        The latest parameter file.
    """
    param_file_lst = [
        os.path.join(param_dir, f)
        for f in os.listdir(param_dir)
        if f.startswith("pb_params") and f.endswith(".csv")
    ]
    param_files = [Path(f) for f in param_file_lst]
    param_file_names_lst = [param_file.stem.split("_") for param_file in param_files]
    ctimes = [
        int(param_file_names[param_file_names.index("params") + 1])
        for param_file_names in param_file_names_lst
    ]
    return param_files[ctimes.index(max(ctimes))] if ctimes else None
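
A small illustrative check (the temporary directory and timestamps are made up); the file names follow the pb_params_<unix-timestamp>.csv pattern written by Resulter._save_pbm_params, so the largest timestamp wins:

import os
import tempfile
from torchhydro.trainers.train_utils import get_latest_pbm_param_file

param_dir = tempfile.mkdtemp()
for name in ["pb_params_1700000000.csv", "pb_params_1710000000.csv"]:
    open(os.path.join(param_dir, name), "w").close()
print(get_latest_pbm_param_file(param_dir).name)  # pb_params_1710000000.csv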

get_latest_tensorboard_event_file(log_dir)

Get the latest event file in the log_dir directory.

Parameters

log_dir : str The directory where the event files are stored.

Returns

str The latest event file.

Source code in torchhydro/trainers/train_utils.py
def get_latest_tensorboard_event_file(log_dir):
    """Get the latest event file in the log_dir directory.

    Parameters
    ----------
    log_dir : str
        The directory where the event files are stored.

    Returns
    -------
    str
        The latest event file.
    """
    event_file_lst = [
        os.path.join(log_dir, f) for f in os.listdir(log_dir) if f.startswith("events")
    ]
    event_files = [Path(f) for f in event_file_lst]
    event_file_names_lst = [event_file.stem.split(".") for event_file in event_files]
    ctimes = [
        int(event_file_names[event_file_names.index("tfevents") + 1])
        for event_file_names in event_file_names_lst
    ]
    return event_files[ctimes.index(max(ctimes))]

get_masked_tensors(variable_length_cfgs, batch, seq_first)

Get the mask for the data

Parameters

variable_length_cfgs : dict
    The variable length configuration
batch : tuple or list
    The batch data from collate_fn or dataset
seq_first : bool
    Whether the data is in sequence first format

Returns

tuple
    For standard datasets: (xs, ys, xs_mask, ys_mask, xs_lens, ys_lens)
    For GNN datasets: (xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens)
    For GNN with batch vector: (xs, ys, edge_index, edge_weight, batch_vector, xs_mask, ys_mask, xs_lens, ys_lens)

Source code in torchhydro/trainers/train_utils.py
def get_masked_tensors(variable_length_cfgs, batch, seq_first):
    """Get the mask for the data

    Parameters
    ----------
    variable_length_cfgs : dict
        The variable length configuration
    batch : tuple or list
        The batch data from collate_fn or dataset
    seq_first : bool
        Whether the data is in sequence first format

    Returns
    -------
    tuple
        For standard datasets: (xs, ys, xs_mask, ys_mask, xs_lens, ys_lens)
        For GNN datasets: (xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens)
        For GNN with batch vector: (xs, ys, edge_index, edge_weight, batch_vector, xs_mask, ys_mask, xs_lens, ys_lens)
    """
    xs_mask = None
    ys_mask = None
    xs_lens = None
    ys_lens = None
    edge_index = None
    edge_weight = None
    # also initialize batch_vector so the GNN return path at the bottom cannot raise a NameError
    batch_vector = None

    if variable_length_cfgs is None:
        # Check batch length to determine format
        if len(batch) == 5:
            # GNN batch with batch_vector: [sxc, y, edge_index, edge_weight, batch_vector]
            xs, ys, edge_index, edge_weight, batch_vector = batch
            return (
                xs,
                ys,
                edge_index,
                edge_weight,
                batch_vector,
                xs_mask,
                ys_mask,
                xs_lens,
                ys_lens,
            )
        elif len(batch) == 4:
            # GNN batch: [sxc, y, edge_index, edge_weight]
            xs, ys, edge_index, edge_weight = batch
            return xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens
        else:
            # Standard batch: [xs, ys]
            xs, ys = batch[0], batch[1]
            return xs, ys, xs_mask, ys_mask, xs_lens, ys_lens

    if variable_length_cfgs.get("use_variable_length", False):
        # When using variable length training, batch comes from varied_length_collate_fn
        # which returns [xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask]
        if len(batch) >= 6:
            xs, ys, xs_lens, ys_lens, xs_mask_bool, ys_mask_bool = batch[:6]
        else:
            # Fallback: treat as regular batch with first two elements
            xs, ys = batch[0], batch[1]
            xs_lens = ys_lens = xs_mask_bool = ys_mask_bool = None

        if xs_mask_bool is None and ys_mask_bool is None:
            # sometimes, even if variable length training is chosen, the batch data may still be fixed length
            # so we need to return the batch data directly
            return xs, ys, xs_mask_bool, ys_mask_bool, xs_lens, ys_lens
        # Convert masks to the format expected by model (float tensor with shape [..., 1])
        xs_mask = xs_mask_bool.unsqueeze(-1).float()  # [batch, seq, 1]
        ys_mask = ys_mask_bool.unsqueeze(-1).float()  # [batch, seq, 1]

        # Convert to appropriate format for model if needed
        if seq_first:
            xs_mask = xs_mask.transpose(0, 1)  # [seq, batch, 1]
            ys_mask = ys_mask.transpose(0, 1)  # [seq, batch, 1]
    else:
        # Check batch length to determine format
        if len(batch) == 5:
            # GNN batch with batch_vector: [sxc, y, edge_index, edge_weight, batch_vector]
            xs, ys, edge_index, edge_weight, batch_vector = batch
        elif len(batch) == 4:
            # GNN batch: [sxc, y, edge_index, edge_weight]
            xs, ys, edge_index, edge_weight = batch
        else:
            # Standard batch: [xs, ys]
            xs, ys = batch[0], batch[1]

    # Return appropriate format based on what we have
    if edge_index is not None and edge_weight is not None:
        return (
            xs,
            ys,
            edge_index,
            edge_weight,
            batch_vector,
            xs_mask,
            ys_mask,
            xs_lens,
            ys_lens,
        )
    else:
        return xs, ys, xs_mask, ys_mask, xs_lens, ys_lens
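
A minimal illustrative call for the standard fixed-length case (the random tensors are placeholders): with variable_length_cfgs=None and a plain (xs, ys) batch, the masks and lengths come back as None.

import torch
from torchhydro.trainers.train_utils import get_masked_tensors

xs = torch.randn(16, 30, 5)  # (batch, seq, features)
ys = torch.randn(16, 1, 1)
xs_, ys_, xs_mask, ys_mask, xs_lens, ys_lens = get_masked_tensors(None, (xs, ys), seq_first=False)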

get_preds_to_be_eval(valorte_data_loader, evaluation_cfgs, output, labels)

Get prediction results prepared for evaluation: the denormalized data without metrics by different eval ways

Parameters

valorte_data_loader : DataLoader
    validation or test data loader
evaluation_cfgs : dict
    evaluation configs
output : np.ndarray
    model output
labels : np.ndarray
    model target

Returns

tuple
    (obss_xr, preds_xr), the denormalized observations and predictions prepared for evaluation

Source code in torchhydro/trainers/train_utils.py
def get_preds_to_be_eval(
    valorte_data_loader,
    evaluation_cfgs,
    output,
    labels,
):
    """
    Get prediction results prepared for evaluation:
    the denormalized data without metrics by different eval ways

    Parameters
    ----------
    valorte_data_loader : DataLoader
        validation or test data loader
    evaluation_cfgs : dict
        evaluation configs
    output : np.ndarray
        model output
    labels : np.ndarray
        model target

    Returns
    -------
    tuple
        (obss_xr, preds_xr), the denormalized observations and predictions prepared for evaluation
    """
    evaluator = evaluation_cfgs["evaluator"]
    # this test_rolling means how we perform prediction during testing
    test_rolling = evaluation_cfgs["rolling"]
    batch_size = valorte_data_loader.batch_size
    target_scaler = valorte_data_loader.dataset.target_scaler
    target_data = target_scaler.data_target
    rho = valorte_data_loader.dataset.rho
    horizon = valorte_data_loader.dataset.horizon
    warmup_length = valorte_data_loader.dataset.warmup_length
    hindcast_output_window = target_scaler.data_cfgs["hindcast_output_window"]
    nf = valorte_data_loader.dataset.noutputvar  # number of features
    # number of time steps after warmup as outputs typically don't include warmup period
    nt = valorte_data_loader.dataset.nt - warmup_length
    basin_num = len(target_data.basin)
    data_shape = (basin_num, nt, nf)
    if evaluator["eval_way"] == "once":
        stride = evaluator["stride"]
        if stride > 0:
            if horizon != stride:
                raise NotImplementedError(
                    "horizon should be equal to stride in evaluator if you chose eval_way to be once, or else you need to change the eval_way to be 1pace or rolling"
                )
            obs = _rolling_preds_for_once_eval(
                (basin_num, horizon, nf),
                rho,
                evaluation_cfgs["forecast_length"],
                stride,
                hindcast_output_window,
                target_data.reshape(basin_num, horizon, nf),
            )
            pred = _rolling_preds_for_once_eval(
                (basin_num, horizon, nf),
                rho,
                evaluation_cfgs["forecast_length"],
                stride,
                hindcast_output_window,
                output.reshape(batch_size, horizon, nf),
            )
        else:
            if test_rolling > 0:
                raise RuntimeError(
                    "please set rolling to 0 when you chose eval way as once and stride=0"
                )
            obs = labels.reshape(basin_num, -1, nf)
            pred = output.reshape(basin_num, -1, nf)
    elif evaluator["eval_way"] == "1pace":
        if test_rolling < 1:
            raise NotImplementedError(
                "rolling should be larger than 0 if you chose eval_way to be 1pace"
            )
        pace_idx = evaluator["pace_idx"]
        # stride = evaluator.get("stride", 1)
        # for 1pace with pace_idx meaning which value of output was chosen to show
        # 1st, we need to transpose data to 4-dim to show the whole data

        # TODO: check which recovery function should be used here
        pred = _recover_samples_to_basin(output, valorte_data_loader, pace_idx)
        obs = _recover_samples_to_basin(labels, valorte_data_loader, pace_idx)

    elif evaluator["eval_way"] == "rolling":
        # get the parameters needed for rolling prediction
        stride = evaluator.get("stride", 1)
        if stride != 1:
            raise NotImplementedError(
                "if stride is not equal to 1, we think it is meaningless"
            )
        # reassemble the predictions and observations
        basin_num = len(target_data.basin)

        # choose how the data are recovered according to the config
        recover_mode = evaluator.get("recover_mode", "bybasins")
        stride = evaluator.get("stride", 1)
        data_shape = (basin_num, nt, nf)

        if recover_mode == "bybasins":

            pred = _recover_samples_to_4d_by_basins(
                data_shape,
                valorte_data_loader,
                stride,
                hindcast_output_window,
                output,
            )
            obs = _recover_samples_to_4d_by_basins(
                data_shape,
                valorte_data_loader,
                stride,
                hindcast_output_window,
                labels,
            )
        elif recover_mode == "byforecast":
            pred = _recover_samples_to_4d_by_forecast(
                data_shape,
                valorte_data_loader,
                stride,
                hindcast_output_window,
                output,  # samples, seq_length, nf
            )
            obs = _recover_samples_to_4d_by_forecast(
                data_shape,
                valorte_data_loader,
                stride,
                hindcast_output_window,
                labels,
            )
        elif recover_mode == "byensembles":
            pred = _recover_samples_to_3d_by_4d_ensembles(
                data_shape,
                valorte_data_loader,
                stride,
                hindcast_output_window,
                output,
            )
            obs = _recover_samples_to_3d_by_4d_ensembles(
                data_shape,
                valorte_data_loader,
                stride,
                hindcast_output_window,
                labels,
            )
        else:
            raise ValueError(
                f"Unsupported recover_mode: {recover_mode}, must be 'bybasins' or 'byforecast' or 'byensembles'"
            )
    elif evaluator["eval_way"] == "floodevent":
        # For flood event evaluation, stride is not typically used, but we set it to 1 for consistency
        stride = evaluator.get("stride", 1)
        pred = _recover_samples_to_continuous_by_floodevent(
            data_shape,
            valorte_data_loader,
            stride,
            hindcast_output_window,
            output,
        )
        obs = _recover_samples_to_continuous_by_floodevent(
            data_shape,
            valorte_data_loader,
            stride,
            hindcast_output_window,
            labels,
        )
    else:
        raise ValueError("eval_way should be rolling or 1pace")

    # pace_idx = np.nan
    recover_mode = evaluator.get("recover_mode")
    valte_dataset = valorte_data_loader.dataset
    # check the data dimensionality and handle it accordingly
    if pred.ndim == 4:
        # for 4-D data, choose the handling according to the evaluation way
        if evaluator["eval_way"] == "1pace" and "pace_idx" in evaluator:
            # for the 1pace mode, select a specific forecast step
            pace_idx = evaluator["pace_idx"]
            # select the data at the chosen forecast step
            pred_3d = pred[:, :, pace_idx, :]
            obs_3d = obs[:, :, pace_idx, :]
            preds_xr = valte_dataset.denormalize(pred_3d, pace_idx)
            obss_xr = valte_dataset.denormalize(obs_3d, pace_idx)
        elif evaluator["eval_way"] == "rolling" and recover_mode == "byforecast":
            # the byforecast mode needs special handling:
            # store the result of each forecast step in a list
            preds_xr_list = []
            obss_xr_list = []
            for i in range(pred.shape[2]):
                pred_3d = pred[:, :, i, :]
                obs_3d = obs[:, :, i, :]
                the_array_pred_ = np.full(target_data.shape, np.nan)
                the_array_obs_ = np.full(target_data.shape, np.nan)
                start = rho + i  # TODO:need check
                end = start + pred_3d.shape[1]
                assert end <= the_array_pred_.shape[1]
                the_array_pred_[:, start:end, :] = pred_3d
                the_array_obs_[:, start:end, :] = obs_3d
                preds_xr_list.append(valte_dataset.denormalize(the_array_pred_, i))
                obss_xr_list.append(valte_dataset.denormalize(the_array_obs_, i))
            # merge the results
            # preds_xr = xr.concat(preds_xr_list, dim="horizon")
            # obss_xr = xr.concat(obss_xr_list, dim="horizon")
            return obss_xr_list, preds_xr_list
        elif evaluator["eval_way"] == "rolling" and recover_mode == "bybasins":
            # for the remaining cases, the 4-D data can be reduced to 3-D,
            # e.g. by taking the last forecast step
            preds_xr_list = []
            obss_xr_list = []
            for i in range(pred.shape[0]):
                pred_3d = pred[i, :, :, :]
                obs_3d = obs[i, :, :, :]
                selected_data = target_scaler.data_target
                the_array_pred_ = np.full(selected_data.shape, np.nan)
                the_array_obs_ = np.full(selected_data.shape, np.nan)
                start = rho  # TODO:need check
                end = start + pred_3d.shape[1]  # end position of the fill, computed automatically

                # optional bounds check
                assert end <= the_array_pred_.shape[1]  # the fill must not exceed the target array

                # fill in the values
                the_array_pred_[:, start:end, :] = pred_3d
                the_array_obs_[:, start:end, :] = obs_3d

                preds_xr = valte_dataset.denormalize(the_array_pred_, -1)
                obss_xr = valte_dataset.denormalize(the_array_obs_, -1)
    else:
        # for 3d data, directly process
        # TODO: maybe need more test for the pace_idx case
        preds_xr = valte_dataset.denormalize(pred)
        obss_xr = valte_dataset.denormalize(obs)

    def _align_and_order(_obs, _pred):
        # align to the common (basin, time, variable) intersection to avoid NaNs from an outer join
        _obs, _pred = xr.align(_obs, _pred, join="inner")
        # raise immediately if the time dimension is empty (no overlap) instead of running into nanmean
        if _obs.sizes.get("time", 0) == 0:
            raise ValueError(
                "No overlapping timestamps between observations and predictions "
                f"(obs.time len={_obs.sizes.get('time',0)}, pred.time len={_pred.sizes.get('time',0)})."
            )
        # sort by time (just to be safe)
        if "time" in _obs.dims:
            _obs = _obs.sortby("time")
        if "time" in _pred.dims:
            _pred = _pred.sortby("time")
        # normalize the dimension order (for the dims that exist)
        wanted = [d for d in ("basin", "time", "variable") if d in _obs.dims]
        _obs = _obs.transpose(*wanted, missing_dims="ignore")
        _pred = _pred.transpose(*wanted, missing_dims="ignore")
        return _obs, _pred

    # handle the single-object and list cases separately
    if preds_xr is not None and obss_xr is not None:
        obss_xr, preds_xr = _align_and_order(obss_xr, preds_xr)
        return obss_xr, preds_xr

    elif preds_xr_list is not None and obss_xr_list is not None:
        obss_aligned, preds_aligned = [], []
        for _o, _p in zip(obss_xr_list, preds_xr_list):
            _o2, _p2 = _align_and_order(_o, _p)
            obss_aligned.append(_o2)
            preds_aligned.append(_p2)
        return obss_aligned, preds_aligned

    else:
        # this branch should be unreachable
        raise RuntimeError("Failed to build preds_xr / obss_xr for evaluation.")

gnn_collate_fn(batch)

Custom collate function for GNN datasets that handles variable-sized graphs.

Parameters:

batch : list
    A list of samples, where each sample is a tuple of
    (sxc, y, edge_index, edge_weight). Required.

Returns:

list
    A list containing batched tensors:
    [batched_sxc, batched_y, batched_edge_index, batched_edge_weight, batch_vector]

Source code in torchhydro/trainers/train_utils.py
def gnn_collate_fn(batch):
    """Custom collate function for GNN datasets that handles variable-sized graphs.

    Args:
        batch: A list of samples, where each sample is a tuple of
            (sxc, y, edge_index, edge_weight).

    Returns:
        A list containing batched tensors:
        [batched_sxc, batched_y, batched_edge_index, batched_edge_weight, batch_vector]
    """
    import torch

    if len(batch) == 0:
        return []

    # Unpack the batch
    sxc_list, y_list, edge_index_list, edge_weight_list = zip(*batch)

    # Batch the target values (y) - these should have the same shape
    batched_y = torch.stack(y_list, dim=0)  # [batch_size, forecast_length, output_dim]

    # Find the maximum number of nodes in this batch
    max_num_nodes = max(sxc.shape[0] for sxc in sxc_list)

    # Get dimensions
    batch_size = len(sxc_list)
    seq_length = sxc_list[0].shape[1]
    feature_dim = sxc_list[0].shape[2]

    # Create padded tensor for node features
    batched_sxc = torch.zeros(batch_size, max_num_nodes, seq_length, feature_dim)

    # Create batched edge indices and weights
    # For each graph in the batch, we need to offset node indices
    batched_edge_index = []
    batched_edge_weight = []
    batch_vector = []
    node_offset = 0
    for i, (sxc, edge_index, edge_weight) in enumerate(
        zip(sxc_list, edge_index_list, edge_weight_list)
    ):
        num_nodes = sxc.shape[0]
        # Fill the padded tensor with actual node features
        batched_sxc[i, :num_nodes] = sxc
        # For edge indices, we need to offset by node_offset to make them unique across batch
        if edge_index.numel() > 0:
            if edge_index.max() >= num_nodes:
                print(
                    f"Warning: Graph {i} has edge indices {edge_index.max().item()} >= num_nodes {num_nodes}"
                )
                valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
                edge_index = edge_index[:, valid_mask]
                edge_weight = edge_weight[valid_mask]
            if edge_index.numel() > 0:
                offset_edge_index = edge_index + node_offset
                batched_edge_index.append(offset_edge_index)
                batched_edge_weight.append(edge_weight)
        # batch_vector: for each node in this graph, assign batch index i
        batch_vector.append(torch.full((num_nodes,), i, dtype=torch.long))
        node_offset += num_nodes
    # Concatenate edge indices and weights if they exist
    if batched_edge_index:
        batched_edge_index = torch.cat(batched_edge_index, dim=1)  # [2, total_edges]
        batched_edge_weight = torch.cat(batched_edge_weight, dim=0)  # [total_edges]
    else:
        batched_edge_index = torch.empty((2, 0), dtype=torch.long)
        batched_edge_weight = torch.empty(0)
    batch_vector = torch.cat(batch_vector, dim=0)  # [total_nodes]
    return [
        batched_sxc,
        batched_y,
        batched_edge_index,
        batched_edge_weight,
        batch_vector,
    ]
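
A brief sketch of what the collate function produces, using two toy graph samples with different node counts; the tensor shapes are illustrative. In practice the function is passed to a DataLoader via collate_fn=gnn_collate_fn rather than called by hand.

import torch

from torchhydro.trainers.train_utils import gnn_collate_fn

# two toy samples: (sxc, y, edge_index, edge_weight) with 3 and 5 nodes
samples = [
    (
        torch.randn(3, 10, 4),            # sxc: [num_nodes, seq_length, feature_dim]
        torch.randn(24, 1),               # y: [forecast_length, output_dim]
        torch.tensor([[0, 1], [1, 2]]),   # edge_index: [2, num_edges]
        torch.ones(2),                    # edge_weight: [num_edges]
    ),
    (
        torch.randn(5, 10, 4),
        torch.randn(24, 1),
        torch.tensor([[0, 1, 2], [1, 2, 3]]),
        torch.ones(3),
    ),
]

sxc, y, edge_index, edge_weight, batch_vector = gnn_collate_fn(samples)
# sxc is zero-padded to the largest graph: shape [2, 5, 10, 4]
# edge_index is offset so node ids are unique across the batch: shape [2, 5]
# batch_vector maps each node to its graph: tensor([0, 0, 0, 1, 1, 1, 1, 1])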

model_infer(seq_first, device, model, batch, variable_length_cfgs=None, return_key=None)

Unified model inference function with variable length support

Parameters

seq_first : bool
    if True, the input data is sequence first
device : torch.device
    cpu or gpu
model : torch.nn.Module
    the model
batch : tuple or list
    batch data from collate_fn or dataset
variable_length_cfgs : dict, optional
    variable length configuration containing mask settings
return_key : str, optional
    when the model returns a dict, choose which key (e.g., "f2") to return;
    if None, defaults to the last frequency (max key)

Source code in torchhydro/trainers/train_utils.py
def model_infer(
    seq_first, device, model, batch, variable_length_cfgs=None, return_key=None
):
    """
    Unified model inference function with variable length support

    Parameters
    ----------
    seq_first : bool
        if True, the input data is sequence first
    device : torch.device
        cpu or gpu
    model : torch.nn.Module
        the model
    batch : tuple or list
        batch data from collate_fn or dataset
    variable_length_cfgs : dict, optional
        variable length configuration containing mask settings
    return_key : str, optional
        when model returns a dict, choose which key (e.g., "f2") to return.
        if None, defaults to the last frequency (max key).
    """
    result = get_masked_tensors(variable_length_cfgs, batch, seq_first)

    # --- unpack inputs ---
    if len(result) == 9:
        (
            xs,
            ys,
            edge_index,
            edge_weight,
            batch_vector,
            xs_mask,
            ys_mask,
            xs_lens,
            ys_lens,
        ) = result
    elif len(result) == 8:
        xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens = result
        batch_vector = None
    else:
        xs, ys, xs_mask, ys_mask, xs_lens, ys_lens = result
        edge_index = edge_weight = batch_vector = None

    # --- move xs to device ---
    if isinstance(xs, list):
        xs = [
            (
                x.permute(1, 0, 2).to(device)
                if seq_first and x.ndim == 3
                else x.to(device)
            )
            for x in xs
        ]
    else:
        xs = [
            (
                xs.permute(1, 0, 2).to(device)
                if seq_first and xs.ndim == 3
                else xs.to(device)
            )
        ]

    # --- move ys to device ---
    if ys is not None:
        ys = (
            ys.permute(1, 0, 2).to(device)
            if seq_first and ys.ndim == 3
            else ys.to(device)
        )

    # --- move graph data ---
    if edge_index is not None:
        edge_index = edge_index.to(device)
    if edge_weight is not None:
        edge_weight = edge_weight.to(device)
    if batch_vector is not None:
        batch_vector = batch_vector.to(device)

    # --- forward ---
    if xs_mask is not None and ys_mask is not None:
        if edge_index is not None and edge_weight is not None:
            output = model(
                *xs,
                edge_index=edge_index,
                edge_weight=edge_weight,
                batch_vector=batch_vector,
                mask=xs_mask,
                seq_lengths=xs_lens,
            )
        else:
            output = model(*xs, mask=xs_mask, seq_lengths=xs_lens)
    else:
        if edge_index is not None and edge_weight is not None:
            output = model(
                *xs,
                edge_index=edge_index,
                edge_weight=edge_weight,
                batch_vector=batch_vector,
            )
        else:
            output = model(*xs)

    # --- handle model outputs ---
    if isinstance(output, tuple):
        output = output[0]

    if isinstance(output, dict):
        # by default, take the highest-frequency output
        if return_key is None:
            return_key = sorted(output.keys())[-1]  # e.g., "f2"
        if return_key not in output:
            raise KeyError(
                f"Model returned keys {list(output.keys())}, but return_key='{return_key}' not found"
            )
        output = output[return_key]

    if ys_mask is not None:
        ys = ys.masked_fill(ys_mask == 0, torch.nan)

    # --- seq_first transpose back ---
    if seq_first:
        output = output.transpose(0, 1)
        if ys is not None:
            ys = ys.transpose(0, 1)

    return ys, output
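
A minimal inference-loop sketch: `model` and `data_loader` are assumed to exist already (for example, built by DeepHydro), and seq_first follows the training-side rule which_first_tensor != "batch"; everything else uses only names that appear in the source above.

import torch

from torchhydro.trainers.train_utils import model_infer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_first = False  # batch-first tensors; set True when which_first_tensor != "batch"
variable_length_cfgs = data_loader.dataset.training_cfgs.get(
    "variable_length_cfgs", None
)

model.eval()
obs_list, pred_list = [], []
with torch.no_grad():
    for batch in data_loader:
        ys, output = model_infer(seq_first, device, model, batch, variable_length_cfgs)
        obs_list.append(ys.cpu())
        pred_list.append(output.cpu())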

torch_single_train(model, opt, criterion, data_loader, device=None, **kwargs)

Training function for one epoch

Parameters

model
    a PyTorch model inheriting from nn.Module
opt
    an optimizer from PyTorch (optim.Optimizer)
criterion
    loss function
data_loader
    object for loading data to the model
device
    where we put the tensors and models

Returns

tuple(torch.Tensor, int)
    loss of this epoch and number of all iterations

Raises

ValueError
    raised if a NaN loss is encountered

Source code in torchhydro/trainers/train_utils.py
def torch_single_train(
    model,
    opt: optim.Optimizer,
    criterion,
    data_loader: DataLoader,
    device=None,
    **kwargs,
):
    """
    Training function for one epoch

    Parameters
    ----------
    model
        a PyTorch model inheriting from nn.Module
    opt
        optimizer function from PyTorch optim.Optimizer
    criterion
        loss function
    data_loader
        object for loading data to the model
    device
        where we put the tensors and models

    Returns
    -------
    tuple(torch.Tensor, int)
        loss of this epoch and number of all iterations

    Raises
    --------
    ValueError
        if a NaN loss is encountered, raise a ValueError
    """
    # we will set model.eval() in the validation function so here we should set model.train()
    model.train()
    n_iter_ep = 0
    running_loss = 0.0
    which_first_tensor = kwargs["which_first_tensor"]
    seq_first = which_first_tensor != "batch"
    variable_length_cfgs = data_loader.dataset.training_cfgs.get(
        "variable_length_cfgs", None
    )
    pbar = tqdm(data_loader)

    for _, batch in enumerate(pbar):
        # mask handling is already done inside model_infer function
        trg, output = model_infer(seq_first, device, model, batch, variable_length_cfgs)
        loss = compute_loss(trg, output, criterion, **kwargs)
        if loss > 100:
            print("Warning: high loss detected")
        if torch.isnan(loss):
            raise ValueError("nan loss detected")
            # continue
        loss.backward()  # Backpropagate to compute the current gradient
        opt.step()  # Update network parameters based on gradients
        model.zero_grad()  # clear gradient
        if loss == float("inf"):
            raise ValueError(
                "Error infinite loss detected. Try normalizing data or performing interpolation"
            )
        running_loss += loss.item()
        n_iter_ep += 1
    if n_iter_ep == 0:
        raise ValueError(
            "All batch computations of loss result in NAN. Please check the data."
        )
    total_loss = running_loss / float(n_iter_ep)
    return total_loss, n_iter_ep
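
A minimal sketch of an epoch loop built around torch_single_train, assuming `model`, `criterion`, and `train_loader` already exist; the optimizer choice and keyword arguments below are illustrative, and extra kwargs are forwarded to compute_loss exactly as in the source above.

import torch
import torch.optim as optim

from torchhydro.trainers.train_utils import torch_single_train

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
opt = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, 51):
    epoch_loss, n_iters = torch_single_train(
        model,
        opt,
        criterion,
        train_loader,
        device=device,
        which_first_tensor="batch",  # batch-first tensors, i.e. seq_first=False
    )
    print(f"epoch {epoch}: mean loss {epoch_loss:.4f} over {n_iters} iterations")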

varied_length_collate_fn(batch)

Collate function for variable length training

This function is automatically used by DataLoader when variable_length_cfgs["use_variable_length"] is True. It pads sequences to the same length and generates corresponding masks.

Parameters

batch : list of tuples
    The batch data produced by the dataset __getitem__ method

Returns

list
    [xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask]
    - xs_pad: padded input sequences [batch, max_seq_len, input_dim]
    - ys_pad: padded output sequences [batch, max_seq_len, output_dim]
    - xs_lens: original sequence lengths for input
    - ys_lens: original sequence lengths for output
    - xs_mask: valid position mask for input [batch, max_seq_len]
    - ys_mask: valid position mask for output [batch, max_seq_len]

Source code in torchhydro/trainers/train_utils.py
def varied_length_collate_fn(batch):
    """Collate function for variable length training

    This function is automatically used by DataLoader when variable_length_cfgs["use_variable_length"] is True.
    It pads sequences to the same length and generates corresponding masks.

    Parameters
    ----------
    batch : list of tuples
        The batch data after the dataset __getitem__ method

    Returns
    -------
    list
        [xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask]
        - xs_pad: padded input sequences [batch, max_seq_len, input_dim]
        - ys_pad: padded output sequences [batch, max_seq_len, output_dim]
        - xs_lens: original sequence lengths for input
        - ys_lens: original sequence lengths for output
        - xs_mask: valid position mask for input [batch, max_seq_len]
        - ys_mask: valid position mask for output [batch, max_seq_len]
    """

    xs, ys = zip(*batch)
    # sometimes x is a tuple like in dpl dataset, then we can get the shape of the first element as the length
    xs_lens = [x[0].shape[0] if type(x) in [tuple, list] else x.shape[0] for x in xs]
    ys_lens = [y[0].shape[0] if type(y) in [tuple, list] else y.shape[0] for y in ys]
    # if all ys_lens are the same, use default collate_fn to create tensors
    if len(set(ys_lens)) == 1 and len(set(xs_lens)) == 1:
        xs_tensor = default_collate(xs)
        ys_tensor = default_collate(ys)
        return [xs_tensor, ys_tensor, None, None, None, None]

    # pad the batch data with padding value 0
    xs_pad = pad_sequence(xs, batch_first=True, padding_value=0)
    ys_pad = pad_sequence(ys, batch_first=True, padding_value=0)

    # generate the mask for the batch data
    # xs_mask: [batch_size, max_seq_len] or [batch_size, max_seq_len, 1]
    batch_size = len(xs_lens)
    max_xs_len = max(xs_lens)
    max_ys_len = max(ys_lens)

    # create the mask for the input sequence (True for valid positions, False for padding positions)
    xs_mask = torch.zeros(batch_size, max_xs_len, dtype=torch.bool)
    for i, length in enumerate(xs_lens):
        xs_mask[i, :length] = True

    # create the mask for the output sequence
    ys_mask = torch.zeros(batch_size, max_ys_len, dtype=torch.bool)
    for i, length in enumerate(ys_lens):
        ys_mask[i, :length] = True

    return [
        xs_pad,
        ys_pad,
        xs_lens,
        ys_lens,
        xs_mask,
        ys_mask,
    ]
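
A small sketch of the padding and masking behaviour with two toy samples of different lengths; the shapes are illustrative only. When all samples share the same length, the function falls back to default_collate and returns None for the lengths and masks.

import torch

from torchhydro.trainers.train_utils import varied_length_collate_fn

# two samples with different sequence lengths (5 and 3 time steps)
batch = [
    (torch.randn(5, 4), torch.randn(5, 1)),
    (torch.randn(3, 4), torch.randn(3, 1)),
]
xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask = varied_length_collate_fn(batch)
# xs_pad: [2, 5, 4], ys_pad: [2, 5, 1]  (zero-padded to the longest sample)
# xs_lens == [5, 3]; xs_mask[1, 3:] is all False (padding positions)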

trainer

Author: Wenyu Ouyang Date: 2021-12-05 11:21:58 LastEditTime: 2025-11-08 16:02:09 LastEditors: Wenyu Ouyang Description: Main function for training and testing FilePath: torchhydro\torchhydro\trainers\trainer.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.

ensemble_train_and_evaluate(cfgs)

Function to train and test for ensemble models

Parameters

cfgs
    Dictionary containing all configs needed to run the model

Returns

None

Source code in torchhydro/trainers/trainer.py
def ensemble_train_and_evaluate(cfgs: Dict):
    """
    Function to train and test for ensemble models

    Parameters
    ----------
    cfgs
        Dictionary containing all configs needed to run the model

    Returns
    -------
    None
    """
    # for basins and models
    ensemble = cfgs["training_cfgs"]["ensemble"]
    if not ensemble:
        raise ValueError(
            "ensemble should be True, otherwise should use train_and_evaluate rather than ensemble_train_and_evaluate"
        )
    ensemble_items = cfgs["training_cfgs"]["ensemble_items"]
    number_of_items = len(ensemble_items)
    if number_of_items == 0:
        raise ValueError("ensemble_items should not be empty")
    keys_list = list(ensemble_items.keys())
    if "kfold" in keys_list:
        _trans_kfold_to_periods(cfgs, ensemble_items, "kfold")
    _nested_loop_train_and_evaluate(keys_list, 0, ensemble_items, cfgs)
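
A hedged sketch of enabling ensemble mode: only the keys read in the source above (ensemble, ensemble_items, kfold) come from torchhydro itself; the item values shown are illustrative assumptions, and cfgs is the nested config dict used throughout this page.

from torchhydro.trainers.trainer import ensemble_train_and_evaluate

cfgs["training_cfgs"]["ensemble"] = True
cfgs["training_cfgs"]["ensemble_items"] = {
    # "kfold" entries are translated into train/valid periods internally;
    # the value 5 is illustrative, not prescribed by torchhydro
    "kfold": 5,
}
ensemble_train_and_evaluate(cfgs)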

set_random_seed(seed)

Set a random seed to guarantee the reproducibility

Parameters

seed
    an integer used as the random seed

Returns

None

Source code in torchhydro/trainers/trainer.py
def set_random_seed(seed):
    """
    Set a random seed to guarantee the reproducibility

    Parameters
    ----------
    seed
        a number

    Returns
    -------
    None
    """
    # print("Random seed:", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

train_and_evaluate(cfgs)

Function to train and test a Model

Parameters

cfgs
    Dictionary containing all configs needed to run the model

Returns

None

Source code in torchhydro/trainers/trainer.py
def train_and_evaluate(cfgs: Dict):
    """
    Function to train and test a Model

    Parameters
    ----------
    cfgs
        Dictionary containing all configs needed to run the model

    Returns
    -------
    None
    """
    random_seed = cfgs["training_cfgs"]["random_seed"]
    set_random_seed(random_seed)
    resulter = Resulter(cfgs)
    deephydro = _get_deep_hydro(cfgs)
    # if train_mode is False, we only evaluate the model
    train_mode = deephydro.cfgs["training_cfgs"]["train_mode"]
    # but if train_mode is True, we still need some conditions to train the model
    continue_train = deephydro.cfgs["model_cfgs"]["continue_train"]
    is_transfer_learning = deephydro.cfgs["model_cfgs"]["model_type"] == "TransLearn"
    is_train = train_mode and (
        (deephydro.weight_path is not None and (continue_train or is_transfer_learning))
        or (deephydro.weight_path is None)
    )
    if is_train:
        deephydro.model_train()
    preds, obss = deephydro.model_evaluate()
    resulter.save_cfg(deephydro.cfgs)
    resulter.save_result(preds, obss)
    resulter.eval_result(preds, obss)
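
A minimal, hedged sketch of the overall entry point: `cfgs` is the nested config dict used throughout this page ("data_cfgs", "model_cfgs", "training_cfgs", "evaluation_cfgs"), and how it is assembled (e.g. from a project-specific config helper) is not shown here; only keys read in the source above are set.

from torchhydro.trainers.trainer import train_and_evaluate

cfgs["training_cfgs"]["random_seed"] = 1234
cfgs["training_cfgs"]["train_mode"] = True    # train first, then evaluate
cfgs["model_cfgs"]["continue_train"] = False  # do not resume from existing weights
train_and_evaluate(cfgs)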