Trainers API¶
deep_hydro
¶
Author: Wenyu Ouyang Date: 2024-04-08 18:15:48 LastEditTime: 2025-07-13 16:25:31 LastEditors: Wenyu Ouyang Description: HydroDL model class FilePath: torchhydro\torchhydro\trainers\deep_hydro.py Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
DeepHydro (DeepHydroInterface)
¶
The Base Trainer class for Hydrological Deep Learning models
Source code in torchhydro/trainers/deep_hydro.py
class DeepHydro(DeepHydroInterface):
"""
The Base Trainer class for Hydrological Deep Learning models
"""
def __init__(
self,
cfgs: Dict,
pre_model=None,
):
"""
Parameters
----------
cfgs
configs for the model
pre_model
a pre-trained model, if it is not None,
we will use its weights to initialize this model
by default None
"""
super().__init__(cfgs)
# Initialize fabric based on configuration
self.fabric = create_fabric_wrapper(cfgs.get("training_cfgs", {}))
self.pre_model = pre_model
self.model = self.fabric.setup_module(self.load_model())
if cfgs["training_cfgs"]["train_mode"]:
self.traindataset = self.make_dataset("train")
if cfgs["data_cfgs"]["t_range_valid"] is not None:
self.validdataset = self.make_dataset("valid")
self.testdataset: BaseDataset = self.make_dataset("test")
@property
def device(self):
"""Get the device from fabric wrapper"""
return self.fabric._device
def load_model(self, mode="train"):
"""
Load a time series forecast model in pytorch_model_dict in model_dict_function.py
Returns
-------
object
model in pytorch_model_dict in model_dict_function.py
"""
if mode == "infer":
if self.weight_path is None or self.cfgs["model_cfgs"]["continue_train"]:
# if no weight path is provided
# or weight file is provided but continue train again,
# we will use the trained model in the new case_dir directory
self.weight_path = self._get_trained_model()
elif mode != "train":
raise ValueError("Invalid mode; must be 'train' or 'infer'")
model_cfgs = self.cfgs["model_cfgs"]
model_name = model_cfgs["model_name"]
if model_name not in pytorch_model_dict:
raise NotImplementedError(
f"Error the model {model_name} was not found in the model dict. Please add it."
)
if self.pre_model is not None:
return self._load_pretrain_model()
elif self.weight_path is not None:
return self._load_model_from_pth()
else:
return pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
def _load_pretrain_model(self):
"""load a pretrained model as the initial model"""
return self.pre_model
def _load_model_from_pth(self):
weight_path = self.weight_path
model_cfgs = self.cfgs["model_cfgs"]
model_name = model_cfgs["model_name"]
model = pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
checkpoint = torch.load(weight_path, map_location=self.device)
model.load_state_dict(checkpoint)
print("Weights sucessfully loaded")
return model
def make_dataset(self, is_tra_val_te: str):
"""
Initializes a pytorch dataset.
Parameters
----------
is_tra_val_te
train or valid or test
Returns
-------
object
an object initializing from class in datasets_dict in data_dict.py
"""
data_cfgs = self.cfgs["data_cfgs"]
dataset_name = data_cfgs["dataset"]
if dataset_name in list(datasets_dict.keys()):
dataset = datasets_dict[dataset_name](self.cfgs, is_tra_val_te)
else:
raise NotImplementedError(
f"Error the dataset {str(dataset_name)} was not found in the dataset dict. Please add it."
)
return dataset
def model_train(self) -> None:
"""train a hydrological DL model"""
# A dictionary of the necessary parameters for training
training_cfgs = self.cfgs["training_cfgs"]
# The file path to load model weights from; defaults to "model_save"
model_filepath = self.cfgs["data_cfgs"]["case_dir"]
data_cfgs = self.cfgs["data_cfgs"]
es = None
if training_cfgs["early_stopping"]:
es = EarlyStopper(training_cfgs["patience"])
criterion = self._get_loss_func(training_cfgs)
opt = self._get_optimizer(training_cfgs)
scheduler = self._get_scheduler(training_cfgs, opt)
max_epochs = training_cfgs["epochs"]
start_epoch = training_cfgs["start_epoch"]
# use PyTorch's DataLoader to load the data into batches in each epoch
data_loader, validation_data_loader = self._get_dataloader(
training_cfgs, data_cfgs
)
logger = TrainLogger(model_filepath, self.cfgs, opt)
for epoch in range(start_epoch, max_epochs + 1):
with logger.log_epoch_train(epoch) as train_logs:
total_loss, n_iter_ep = torch_single_train(
self.model,
opt,
criterion,
data_loader,
device=self.device,
which_first_tensor=training_cfgs["which_first_tensor"],
)
train_logs["train_loss"] = total_loss
train_logs["model"] = self.model
valid_loss = None
valid_metrics = None
if data_cfgs["t_range_valid"] is not None:
with logger.log_epoch_valid(epoch) as valid_logs:
valid_loss, valid_metrics = self._1epoch_valid(
training_cfgs, criterion, validation_data_loader, valid_logs
)
self._scheduler_step(training_cfgs, scheduler, valid_loss)
logger.save_session_param(
epoch, total_loss, n_iter_ep, valid_loss, valid_metrics
)
logger.save_model_and_params(self.model, epoch, self.cfgs)
if es and not es.check_loss(
self.model,
valid_loss,
self.cfgs["data_cfgs"]["case_dir"],
):
print("Stopping model now")
break
# logger.plot_model_structure(self.model)
logger.tb.close()
# return the trained model weights and bias and the epoch loss
return self.model.state_dict(), sum(logger.epoch_loss) / len(logger.epoch_loss)
def _get_scheduler(self, training_cfgs, opt):
lr_scheduler_cfg = training_cfgs["lr_scheduler"]
if "lr" in lr_scheduler_cfg and "lr_factor" not in lr_scheduler_cfg:
scheduler = LambdaLR(opt, lr_lambda=lambda epoch: 1.0)
elif isinstance(lr_scheduler_cfg, dict) and all(
isinstance(epoch, int) for epoch in lr_scheduler_cfg
):
# piecewise constant learning rate
epochs = sorted(lr_scheduler_cfg.keys())
values = [lr_scheduler_cfg[e] for e in epochs]
def lr_lambda(epoch):
idx = bisect.bisect_right(epochs, epoch) - 1
return 1.0 if idx < 0 else values[idx]
scheduler = LambdaLR(opt, lr_lambda=lr_lambda)
elif "lr_factor" in lr_scheduler_cfg and "lr_patience" not in lr_scheduler_cfg:
scheduler = ExponentialLR(opt, gamma=lr_scheduler_cfg["lr_factor"])
elif "lr_factor" in lr_scheduler_cfg:
scheduler = ReduceLROnPlateau(
opt,
mode="min",
factor=lr_scheduler_cfg["lr_factor"],
patience=lr_scheduler_cfg["lr_patience"],
)
else:
raise ValueError("Invalid lr_scheduler configuration")
return scheduler
def _scheduler_step(self, training_cfgs, scheduler, valid_loss):
lr_scheduler_cfg = training_cfgs["lr_scheduler"]
required_keys = {"lr_factor", "lr_patience"}
if required_keys.issubset(lr_scheduler_cfg.keys()):
scheduler.step(valid_loss)
else:
scheduler.step()
def _1epoch_valid(
self, training_cfgs, criterion, validation_data_loader, valid_logs
):
valid_obss_np, valid_preds_np, valid_loss = compute_validation(
self.model,
criterion,
validation_data_loader,
device=self.device,
which_first_tensor=training_cfgs["which_first_tensor"],
)
valid_logs["valid_loss"] = valid_loss
if (
self.cfgs["training_cfgs"]["valid_batch_mode"] == "test"
and self.cfgs["training_cfgs"]["calc_metrics"]
):
# NOTE: Now we only evaluate the metrics for test-mode validation
target_col = self.cfgs["data_cfgs"]["target_cols"]
valid_metrics = evaluate_validation(
validation_data_loader,
valid_preds_np,
valid_obss_np,
self.cfgs["evaluation_cfgs"],
target_col,
)
valid_logs["valid_metrics"] = valid_metrics
return valid_loss, valid_metrics
return valid_loss, None
def _get_trained_model(self):
model_loader = self.cfgs["evaluation_cfgs"]["model_loader"]
model_pth_dir = self.cfgs["data_cfgs"]["case_dir"]
return read_pth_from_model_loader(model_loader, model_pth_dir)
def model_evaluate(self) -> Tuple[Dict, np.array, np.array]:
"""
A function to evaluate a model, called at end of training.
Returns
-------
tuple[dict, np.array, np.array]
eval_log, denormalized predictions and observations
"""
self.model = self.load_model(mode="infer").to(self.device)
preds_xr, obss_xr = self.inference()
return preds_xr, obss_xr
def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
"""infer using trained model and unnormalized results"""
data_cfgs = self.cfgs["data_cfgs"]
training_cfgs = self.cfgs["training_cfgs"]
test_dataloader = self._get_dataloader(training_cfgs, data_cfgs, mode="infer")
seq_first = training_cfgs["which_first_tensor"] == "sequence"
self.model.eval()
# here the batch is just an index of lookup table, so any batch size could be chosen
test_preds = []
obss = []
with torch.no_grad():
test_preds = []
obss = []
for i, batch in enumerate(
tqdm(test_dataloader, desc="Model inference", unit="batch")
):
ys, pred = model_infer(
seq_first,
self.device,
self.model,
batch,
variable_length_cfgs=None,
return_key=(
self.cfgs.get("evaluation_cfgs", {})
.get("evaluator", {})
.get("return_key", None)
)
)
test_preds.append(pred.cpu())
obss.append(ys.cpu())
if i % 100 == 0:
torch.cuda.empty_cache()
pred = torch.cat(test_preds, dim=0).numpy()  # convert to numpy at the end
obs = torch.cat(obss, dim=0).numpy()  # convert to numpy at the end
if pred.ndim == 2:
# TODO: check
# the ndim is 2 meaning we use an Nto1 mode
# as lookup table is (basin 1's all time length, basin 2's all time length, ...)
# params of reshape should be (basin size, time length)
pred = pred.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
obs = obs.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
evaluation_cfgs = self.cfgs["evaluation_cfgs"]
obs_xr, pred_xr = get_preds_to_be_eval(
test_dataloader,
evaluation_cfgs,
pred,
obs,
)
return pred_xr, obs_xr
def _get_optimizer(self, training_cfgs):
params_in_opt = self.model.parameters()
return pytorch_opt_dict[training_cfgs["optimizer"]](
params_in_opt, **training_cfgs["optim_params"]
)
def _get_loss_func(self, training_cfgs):
criterion_init_params = {}
if "criterion_params" in training_cfgs:
loss_param = training_cfgs["criterion_params"]
if loss_param is not None:
for key in loss_param.keys():
if key == "loss_funcs":
criterion_init_params[key] = pytorch_criterion_dict[
loss_param[key]
]()
else:
criterion_init_params[key] = loss_param[key]
return pytorch_criterion_dict[training_cfgs["criterion"]](
**criterion_init_params
)
def _flood_event_collate_fn(self, batch):
"""自定义的洪水事件 collate 函数,确保所有样本长度一致"""
# find the longest sequence length in this batch
max_len = max(tensor_data[0].shape[0] for tensor_data in batch)
# adjust all samples to the same length
processed_batch = []
for tensor_data in batch:
# get x and y (assuming tensor_data[0] is x and tensor_data[1] is y)
x = tensor_data[0]
y = tensor_data[1] if len(tensor_data) > 1 else None
current_len = x.shape[0]
if current_len < max_len:
# pad x by repeating its last value
padding_x = x[-1:].repeat(max_len - current_len, 1)
padded_x = torch.cat([x, padding_x], dim=0)
# if y exists, pad it as well
if y is not None:
padding_y = y[-1:].repeat(max_len - current_len, 1)
padded_y = torch.cat([y, padding_y], dim=0)
else:
padded_y = None
else:
# truncate if the sample is longer
padded_x = x[:max_len]
padded_y = y[:max_len] if y is not None else None
if padded_y is not None:
processed_batch.append((padded_x, padded_y))
else:
processed_batch.append(padded_x)
# return the stacked results according to the data structure
if len(processed_batch) > 0 and isinstance(processed_batch[0], tuple):
return (
torch.stack([x for x, _ in processed_batch], 0),
torch.stack([y for _, y in processed_batch], 0)
)
else:
return torch.stack(processed_batch, 0)
def _get_dataloader(self, training_cfgs, data_cfgs, mode="train"):
if mode == "infer":
_collate_fn = None
# Use GNN collate function for GNN datasets in inference mode
if hasattr(self.testdataset, '__class__') and 'GNN' in self.testdataset.__class__.__name__:
_collate_fn = gnn_collate_fn
# use the custom collate function for FloodEventDataset
elif hasattr(self.testdataset, '__class__') and 'FloodEvent' in self.testdataset.__class__.__name__:
_collate_fn = self._flood_event_collate_fn
return DataLoader(
self.testdataset,
batch_size=training_cfgs["batch_size"],
shuffle=False,
sampler=None,
batch_sampler=None,
drop_last=False,
timeout=0,
worker_init_fn=None,
collate_fn=_collate_fn,
)
worker_num = 0
pin_memory = False
if "num_workers" in training_cfgs:
worker_num = training_cfgs["num_workers"]
print(f"using {str(worker_num)} workers")
if "pin_memory" in training_cfgs:
pin_memory = training_cfgs["pin_memory"]
print(f"Pin memory set to {str(pin_memory)}")
sampler = self._get_sampler(data_cfgs, training_cfgs, self.traindataset)
_collate_fn = None
if training_cfgs["variable_length_cfgs"]["use_variable_length"]:
_collate_fn = varied_length_collate_fn
# Use GNN collate function for GNN datasets
elif hasattr(self.traindataset, '__class__') and 'GNN' in self.traindataset.__class__.__name__:
_collate_fn = gnn_collate_fn
data_loader = DataLoader(
self.traindataset,
batch_size=training_cfgs["batch_size"],
shuffle=(sampler is None),
sampler=sampler,
num_workers=worker_num,
pin_memory=pin_memory,
timeout=0,
collate_fn=_collate_fn,
)
if data_cfgs["t_range_valid"] is not None:
# Use the same collate function for validation dataset
_val_collate_fn = None
if training_cfgs["variable_length_cfgs"]["use_variable_length"]:
_val_collate_fn = varied_length_collate_fn
elif hasattr(self.validdataset, '__class__') and 'GNN' in self.validdataset.__class__.__name__:
_val_collate_fn = gnn_collate_fn
validation_data_loader = DataLoader(
self.validdataset,
batch_size=training_cfgs["batch_size"],
shuffle=False,
num_workers=worker_num,
pin_memory=pin_memory,
timeout=0,
collate_fn=_val_collate_fn,
)
return data_loader, validation_data_loader
return data_loader, None
def _get_sampler(self, data_cfgs, training_cfgs, train_dataset):
"""
return data sampler based on the provided configuration and training dataset.
Parameters
----------
data_cfgs : dict
Configuration dictionary containing parameters for data sampling. Expected keys are:
- "sampler": dict, containing:
- "name": str, name of the sampler to use.
- "sampler_hyperparam": dict, optional hyperparameters for the sampler.
training_cfgs: dict
Configuration dictionary containing parameters for training. Expected keys are:
- "batch_size": int, size of each batch.
train_dataset : Dataset
The training dataset object which contains the data to be sampled. Expected attributes are:
- ngrid: int, number of grids in the dataset.
- nt: int, number of time steps in the dataset.
- rho: int, length of the input sequence.
- warmup_length: int, length of the warmup period.
- horizon: int, length of the forecast horizon.
Returns
-------
sampler_class
An instance of the specified sampler class, initialized with the provided dataset and hyperparameters.
Raises
------
NotImplementedError
If the specified sampler name is not found in the `data_sampler_dict`.
"""
if data_cfgs["sampler"] is None:
return None
batch_size = training_cfgs["batch_size"]
rho = train_dataset.rho
warmup_length = train_dataset.warmup_length
horizon = train_dataset.horizon
ngrid = train_dataset.ngrid
nt = train_dataset.nt
sampler_name = data_cfgs["sampler"]
if sampler_name not in data_sampler_dict:
raise NotImplementedError(f"Sampler {sampler_name} not implemented yet")
sampler_class = data_sampler_dict[sampler_name]
sampler_hyperparam = {}
if sampler_name == "KuaiSampler":
sampler_hyperparam |= {
"batch_size": batch_size,
"warmup_length": warmup_length,
"rho_horizon": rho + horizon,
"ngrid": ngrid,
"nt": nt,
}
elif sampler_name == "WindowLenBatchSampler":
sampler_hyperparam |= {
"batch_size": batch_size,
}
return sampler_class(train_dataset, **sampler_hyperparam)
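The _get_scheduler branches above accept several shapes for training_cfgs["lr_scheduler"]. A sketch of configurations that would select each branch; the numeric values are illustrative only, not project defaults:

# Illustrative lr_scheduler configurations, one per branch of _get_scheduler above
lr_scheduler_constant = {"lr": 0.001}                        # "lr" without "lr_factor" -> constant LambdaLR
lr_scheduler_piecewise = {0: 1.0, 10: 0.5, 20: 0.1}          # all-integer keys -> piecewise-constant LambdaLR
lr_scheduler_exponential = {"lr_factor": 0.95}               # "lr_factor" without "lr_patience" -> ExponentialLR
lr_scheduler_plateau = {"lr_factor": 0.5, "lr_patience": 3}  # both keys -> ReduceLROnPlateau, stepped with valid_loss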
device
property
readonly
¶
Get the device from fabric wrapper
__init__(self, cfgs, pre_model=None)
special
¶
Parameters¶
cfgs
    configs for the model
pre_model
    a pre-trained model; if it is not None, we will use its weights to initialize this model; by default None
Source code in torchhydro/trainers/deep_hydro.py
def __init__(
self,
cfgs: Dict,
pre_model=None,
):
"""
Parameters
----------
cfgs
configs for the model
pre_model
a pre-trained model, if it is not None,
we will use its weights to initialize this model
by default None
"""
super().__init__(cfgs)
# Initialize fabric based on configuration
self.fabric = create_fabric_wrapper(cfgs.get("training_cfgs", {}))
self.pre_model = pre_model
self.model = self.fabric.setup_module(self.load_model())
if cfgs["training_cfgs"]["train_mode"]:
self.traindataset = self.make_dataset("train")
if cfgs["data_cfgs"]["t_range_valid"] is not None:
self.validdataset = self.make_dataset("valid")
self.testdataset: BaseDataset = self.make_dataset("test")
inference(self)
¶
infer using trained model and unnormalized results
Source code in torchhydro/trainers/deep_hydro.py
def inference(self) -> Tuple[xr.Dataset, xr.Dataset]:
"""infer using trained model and unnormalized results"""
data_cfgs = self.cfgs["data_cfgs"]
training_cfgs = self.cfgs["training_cfgs"]
test_dataloader = self._get_dataloader(training_cfgs, data_cfgs, mode="infer")
seq_first = training_cfgs["which_first_tensor"] == "sequence"
self.model.eval()
# here the batch is just an index of lookup table, so any batch size could be chosen
test_preds = []
obss = []
with torch.no_grad():
test_preds = []
obss = []
for i, batch in enumerate(
tqdm(test_dataloader, desc="Model inference", unit="batch")
):
ys, pred = model_infer(
seq_first,
self.device,
self.model,
batch,
variable_length_cfgs=None,
return_key=(
self.cfgs.get("evaluation_cfgs", {})
.get("evaluator", {})
.get("return_key", None)
)
)
test_preds.append(pred.cpu())
obss.append(ys.cpu())
if i % 100 == 0:
torch.cuda.empty_cache()
pred = torch.cat(test_preds, dim=0).numpy()  # convert to numpy at the end
obs = torch.cat(obss, dim=0).numpy()  # convert to numpy at the end
if pred.ndim == 2:
# TODO: check
# the ndim is 2 meaning we use an Nto1 mode
# as lookup table is (basin 1's all time length, basin 2's all time length, ...)
# params of reshape should be (basin size, time length)
pred = pred.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
obs = obs.flatten().reshape(test_dataloader.test_data.y.shape[0], -1, 1)
evaluation_cfgs = self.cfgs["evaluation_cfgs"]
obs_xr, pred_xr = get_preds_to_be_eval(
test_dataloader,
evaluation_cfgs,
pred,
obs,
)
return pred_xr, obs_xr
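A short usage sketch for consuming the returned xarray datasets; "streamflow" is only an illustrative variable name, as the actual data variables follow data_cfgs["target_cols"]:

# Usage sketch: assumes `trainer` is a DeepHydro instance whose model has been loaded
pred_xr, obs_xr = trainer.inference()
pred_np = pred_xr["streamflow"].to_numpy()  # "streamflow" is a hypothetical target variable
obs_np = obs_xr["streamflow"].to_numpy()
print(pred_np.shape, obs_np.shape)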
load_model(self, mode='train')
¶
Load a time series forecast model in pytorch_model_dict in model_dict_function.py
Returns¶
object model in pytorch_model_dict in model_dict_function.py
Source code in torchhydro/trainers/deep_hydro.py
def load_model(self, mode="train"):
"""
Load a time series forecast model in pytorch_model_dict in model_dict_function.py
Returns
-------
object
model in pytorch_model_dict in model_dict_function.py
"""
if mode == "infer":
if self.weight_path is None or self.cfgs["model_cfgs"]["continue_train"]:
# if no weight path is provided
# or weight file is provided but continue train again,
# we will use the trained model in the new case_dir directory
self.weight_path = self._get_trained_model()
elif mode != "train":
raise ValueError("Invalid mode; must be 'train' or 'infer'")
model_cfgs = self.cfgs["model_cfgs"]
model_name = model_cfgs["model_name"]
if model_name not in pytorch_model_dict:
raise NotImplementedError(
f"Error the model {model_name} was not found in the model dict. Please add it."
)
if self.pre_model is not None:
return self._load_pretrain_model()
elif self.weight_path is not None:
return self._load_model_from_pth()
else:
return pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
make_dataset(self, is_tra_val_te)
¶
Initializes a pytorch dataset.
Parameters¶
is_tra_val_te train or valid or test
Returns¶
object an object initializing from class in datasets_dict in data_dict.py
Source code in torchhydro/trainers/deep_hydro.py
def make_dataset(self, is_tra_val_te: str):
"""
Initializes a pytorch dataset.
Parameters
----------
is_tra_val_te
train or valid or test
Returns
-------
object
an object initializing from class in datasets_dict in data_dict.py
"""
data_cfgs = self.cfgs["data_cfgs"]
dataset_name = data_cfgs["dataset"]
if dataset_name in list(datasets_dict.keys()):
dataset = datasets_dict[dataset_name](self.cfgs, is_tra_val_te)
else:
raise NotImplementedError(
f"Error the dataset {str(dataset_name)} was not found in the dataset dict. Please add it."
)
return dataset
model_evaluate(self)
¶
A function to evaluate a model, called at end of training.
Returns¶
tuple[dict, np.array, np.array] eval_log, denormalized predictions and observations
Source code in torchhydro/trainers/deep_hydro.py
def model_evaluate(self) -> Tuple[Dict, np.array, np.array]:
"""
A function to evaluate a model, called at end of training.
Returns
-------
tuple[dict, np.array, np.array]
eval_log, denormalized predictions and observations
"""
self.model = self.load_model(mode="infer").to(self.device)
preds_xr, obss_xr = self.inference()
return preds_xr, obss_xr
model_train(self)
¶
train a hydrological DL model
Source code in torchhydro/trainers/deep_hydro.py
def model_train(self) -> None:
"""train a hydrological DL model"""
# A dictionary of the necessary parameters for training
training_cfgs = self.cfgs["training_cfgs"]
# The file path to load model weights from; defaults to "model_save"
model_filepath = self.cfgs["data_cfgs"]["case_dir"]
data_cfgs = self.cfgs["data_cfgs"]
es = None
if training_cfgs["early_stopping"]:
es = EarlyStopper(training_cfgs["patience"])
criterion = self._get_loss_func(training_cfgs)
opt = self._get_optimizer(training_cfgs)
scheduler = self._get_scheduler(training_cfgs, opt)
max_epochs = training_cfgs["epochs"]
start_epoch = training_cfgs["start_epoch"]
# use PyTorch's DataLoader to load the data into batches in each epoch
data_loader, validation_data_loader = self._get_dataloader(
training_cfgs, data_cfgs
)
logger = TrainLogger(model_filepath, self.cfgs, opt)
for epoch in range(start_epoch, max_epochs + 1):
with logger.log_epoch_train(epoch) as train_logs:
total_loss, n_iter_ep = torch_single_train(
self.model,
opt,
criterion,
data_loader,
device=self.device,
which_first_tensor=training_cfgs["which_first_tensor"],
)
train_logs["train_loss"] = total_loss
train_logs["model"] = self.model
valid_loss = None
valid_metrics = None
if data_cfgs["t_range_valid"] is not None:
with logger.log_epoch_valid(epoch) as valid_logs:
valid_loss, valid_metrics = self._1epoch_valid(
training_cfgs, criterion, validation_data_loader, valid_logs
)
self._scheduler_step(training_cfgs, scheduler, valid_loss)
logger.save_session_param(
epoch, total_loss, n_iter_ep, valid_loss, valid_metrics
)
logger.save_model_and_params(self.model, epoch, self.cfgs)
if es and not es.check_loss(
self.model,
valid_loss,
self.cfgs["data_cfgs"]["case_dir"],
):
print("Stopping model now")
break
# logger.plot_model_structure(self.model)
logger.tb.close()
# return the trained model weights and bias and the epoch loss
return self.model.state_dict(), sum(logger.epoch_loss) / len(logger.epoch_loss)
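An end-to-end sketch of driving the trainer. The config keys are the ones referenced in the source above; in practice a complete cfgs dictionary is normally produced by torchhydro's configuration utilities rather than written by hand:

from torchhydro.trainers.deep_hydro import DeepHydro

# `cfgs` is assumed to be a full torchhydro config dict with the usual sections:
# "data_cfgs", "model_cfgs", "training_cfgs" and "evaluation_cfgs"
trainer = DeepHydro(cfgs)
if cfgs["training_cfgs"]["train_mode"]:
    state_dict, mean_epoch_loss = trainer.model_train()
pred_xr, obs_xr = trainer.model_evaluate()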
DeepHydroInterface (ABC)
¶
An abstract class used to handle different configurations of hydrological deep learning models + hyperparams for training, test, and predict functions. This class assumes that data is already split into train, validation, and test at this point.
Source code in torchhydro/trainers/deep_hydro.py
class DeepHydroInterface(ABC):
"""
An abstract class used to handle different configurations
of hydrological deep learning models + hyperparams for training, test, and predict functions.
This class assumes that data is already split into train, validation, and test at this point.
"""
def __init__(self, cfgs: Dict):
"""
Parameters
----------
cfgs
configs for initializing DeepHydro
"""
self._cfgs = cfgs
@property
def cfgs(self):
"""all configs"""
return self._cfgs
@property
def weight_path(self):
"""weight path"""
return self._cfgs["model_cfgs"]["weight_path"]
@weight_path.setter
def weight_path(self, weight_path):
self._cfgs["model_cfgs"]["weight_path"] = weight_path
@abstractmethod
def load_model(self, mode="train") -> object:
"""Get a Hydro DL model"""
raise NotImplementedError
@abstractmethod
def make_dataset(self, is_tra_val_te: str) -> object:
"""
Initializes a pytorch dataset.
Parameters
----------
is_tra_val_te
train or valid or test
Returns
-------
object
a dataset class loading data from data source
"""
raise NotImplementedError
@abstractmethod
def model_train(self):
"""
Train the model
"""
raise NotImplementedError
@abstractmethod
def model_evaluate(self):
"""
Evaluate the model
"""
raise NotImplementedError
cfgs
property
readonly
¶
all configs
weight_path
property
writable
¶
weight path
__init__(self, cfgs)
special
¶
Parameters¶
cfgs configs for initializing DeepHydro
Source code in torchhydro/trainers/deep_hydro.py
def __init__(self, cfgs: Dict):
"""
Parameters
----------
cfgs
configs for initializing DeepHydro
"""
self._cfgs = cfgs
load_model(self, mode='train')
¶
Get a Hydro DL model
Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def load_model(self, mode="train") -> object:
"""Get a Hydro DL model"""
raise NotImplementedError
make_dataset(self, is_tra_val_te)
¶
Initializes a pytorch dataset.
Parameters¶
is_tra_val_te train or valid or test
Returns¶
object a dataset class loading data from data source
Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def make_dataset(self, is_tra_val_te: str) -> object:
"""
Initializes a pytorch dataset.
Parameters
----------
is_tra_val_te
train or valid or test
Returns
-------
object
a dataset class loading data from data source
"""
raise NotImplementedError
model_evaluate(self)
¶
Evaluate the model
Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def model_evaluate(self):
"""
Evaluate the model
"""
raise NotImplementedError
model_train(self)
¶
Train the model
Source code in torchhydro/trainers/deep_hydro.py
@abstractmethod
def model_train(self):
"""
Train the model
"""
raise NotImplementedError
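A skeleton sketch of implementing the interface; the bodies are placeholders that only show which methods a concrete trainer must override:

from torchhydro.trainers.deep_hydro import DeepHydroInterface

class MyTrainer(DeepHydroInterface):
    """Illustrative skeleton; a real trainer should mirror DeepHydro's behavior."""

    def load_model(self, mode="train") -> object:
        # build and return a torch.nn.Module from self.cfgs["model_cfgs"]
        ...

    def make_dataset(self, is_tra_val_te: str) -> object:
        # return a dataset object for "train", "valid" or "test"
        ...

    def model_train(self):
        # run the training loop
        ...

    def model_evaluate(self):
        # run inference and evaluation
        ...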
FedLearnHydro (DeepHydro)
¶
Federated Learning Hydrological DL model
Source code in torchhydro/trainers/deep_hydro.py
class FedLearnHydro(DeepHydro):
"""Federated Learning Hydrological DL model"""
def __init__(self, cfgs: Dict):
super().__init__(cfgs)
# a user group which is a dict where the keys are the user index
# and the values are the corresponding data for each of those users
train_dataset = self.traindataset
fl_hyperparam = self.cfgs["model_cfgs"]["fl_hyperparam"]
# sample training data amongst users
if fl_hyperparam["fl_sample"] == "basin":
# Sample a basin for a user
user_groups = fl_sample_basin(train_dataset)
elif fl_hyperparam["fl_sample"] == "region":
# Sample a region for a user
user_groups = fl_sample_region(train_dataset)
else:
raise NotImplementedError()
self.user_groups = user_groups
@property
def num_users(self):
"""number of users in federated learning"""
return len(self.user_groups)
def model_train(self) -> None:
# BUILD MODEL
global_model = self.model
# copy weights
global_weights = global_model.state_dict()
# Training
train_loss, train_accuracy = [], []
print_every = 2
training_cfgs = self.cfgs["training_cfgs"]
model_cfgs = self.cfgs["model_cfgs"]
max_epochs = training_cfgs["epochs"]
start_epoch = training_cfgs["start_epoch"]
fl_hyperparam = model_cfgs["fl_hyperparam"]
# total rounds in a FL system is max_epochs
for epoch in tqdm(range(start_epoch, max_epochs + 1)):
local_weights, local_losses = [], []
print(f"\n | Global Training Round : {epoch} |\n")
global_model.train()
m = max(int(fl_hyperparam["fl_frac"] * self.num_users), 1)
# randomly select m users, they will be the clients in this round
idxs_users = np.random.choice(range(self.num_users), m, replace=False)
for idx in idxs_users:
# each user will be used to train the model locally
# user_groups[idx] gives the dataset indices belonging to this user
user_cfgs = self._get_a_user_cfgs(idx)
local_model = DeepHydro(
user_cfgs,
pre_model=copy.deepcopy(global_model),
)
w, loss = local_model.model_train()
local_weights.append(copy.deepcopy(w))
local_losses.append(copy.deepcopy(loss))
# update global weights
global_weights = average_weights(local_weights)
# update global weights
global_model.load_state_dict(global_weights)
loss_avg = sum(local_losses) / len(local_losses)
train_loss.append(loss_avg)
# Calculate avg training accuracy over all users at every epoch
list_acc = []
global_model.eval()
for c in range(self.num_users):
one_user_cfg = self._get_a_user_cfgs(c)
local_model = DeepHydro(
one_user_cfg,
pre_model=global_model,
)
acc, _, _ = local_model.model_evaluate()
list_acc.append(acc)
values = [list(d.values())[0][0] for d in list_acc]
filtered_values = [v for v in values if not np.isnan(v)]
train_accuracy.append(sum(filtered_values) / len(filtered_values))
# print global training loss after every 'i' rounds
if (epoch + 1) % print_every == 0:
print(f" \nAvg Training Stats after {epoch+1} global rounds:")
print(f"Training Loss : {np.mean(np.array(train_loss))}")
print("Train Accuracy: {:.2f}% \n".format(100 * train_accuracy[-1]))
def _get_a_user_cfgs(self, idx):
"""To get a user's configs for local training"""
user = self.user_groups[idx]
# update data_cfgs
# Use defaultdict to collect dates for each basin
basin_dates = defaultdict(list)
for _, (basin, time) in user.items():
basin_dates[basin].append(time)
# Initialize a list to store distinct basins
basins = []
# for each basin, we can find its date range
date_ranges = {}
for basin, times in basin_dates.items():
basins.append(basin)
date_ranges[basin] = (np.min(times), np.max(times))
# get the longest date range
longest_date_range = max(date_ranges.values(), key=lambda x: x[1] - x[0])
# transform the date range of numpy data into string
longest_date_range = [
np.datetime_as_string(dt, unit="D") for dt in longest_date_range
]
user_cfgs = copy.deepcopy(self.cfgs)
# update data_cfgs
update_nested_dict(
user_cfgs, ["data_cfgs", "t_range_train"], longest_date_range
)
# for local training in FL, we don't need a validation set
update_nested_dict(user_cfgs, ["data_cfgs", "t_range_valid"], None)
# for local training in FL, we don't need a test set, but we should set one to avoid error
update_nested_dict(user_cfgs, ["data_cfgs", "t_range_test"], longest_date_range)
update_nested_dict(user_cfgs, ["data_cfgs", "object_ids"], basins)
# update training_cfgs
# we also need to update some training params for local training from FL settings
update_nested_dict(
user_cfgs,
["training_cfgs", "epochs"],
user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_ep"],
)
update_nested_dict(
user_cfgs,
["evaluation_cfgs", "test_epoch"],
user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_ep"],
)
# don't need to save model weights for local training
update_nested_dict(
user_cfgs,
["training_cfgs", "save_epoch"],
None,
)
# there are two settings for batch size in configs, we need to update both of them
update_nested_dict(
user_cfgs,
["training_cfgs", "batch_size"],
user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_bs"],
)
update_nested_dict(
user_cfgs,
["data_cfgs", "batch_size"],
user_cfgs["model_cfgs"]["fl_hyperparam"]["fl_local_bs"],
)
# update model_cfgs finally
# For local model, its model_type is Normal
update_nested_dict(user_cfgs, ["model_cfgs", "model_type"], "Normal")
update_nested_dict(
user_cfgs,
["model_cfgs", "fl_hyperparam"],
None,
)
return user_cfgs
num_users
property
readonly
¶
number of users in federated learning
model_train(self)
¶
train a hydrological DL model
Source code in torchhydro/trainers/deep_hydro.py
def model_train(self) -> None:
# BUILD MODEL
global_model = self.model
# copy weights
global_weights = global_model.state_dict()
# Training
train_loss, train_accuracy = [], []
print_every = 2
training_cfgs = self.cfgs["training_cfgs"]
model_cfgs = self.cfgs["model_cfgs"]
max_epochs = training_cfgs["epochs"]
start_epoch = training_cfgs["start_epoch"]
fl_hyperparam = model_cfgs["fl_hyperparam"]
# total rounds in a FL system is max_epochs
for epoch in tqdm(range(start_epoch, max_epochs + 1)):
local_weights, local_losses = [], []
print(f"\n | Global Training Round : {epoch} |\n")
global_model.train()
m = max(int(fl_hyperparam["fl_frac"] * self.num_users), 1)
# randomly select m users, they will be the clients in this round
idxs_users = np.random.choice(range(self.num_users), m, replace=False)
for idx in idxs_users:
# each user will be used to train the model locally
# user_groups[idx] gives the dataset indices belonging to this user
user_cfgs = self._get_a_user_cfgs(idx)
local_model = DeepHydro(
user_cfgs,
pre_model=copy.deepcopy(global_model),
)
w, loss = local_model.model_train()
local_weights.append(copy.deepcopy(w))
local_losses.append(copy.deepcopy(loss))
# update global weights
global_weights = average_weights(local_weights)
# update global weights
global_model.load_state_dict(global_weights)
loss_avg = sum(local_losses) / len(local_losses)
train_loss.append(loss_avg)
# Calculate avg training accuracy over all users at every epoch
list_acc = []
global_model.eval()
for c in range(self.num_users):
one_user_cfg = self._get_a_user_cfgs(c)
local_model = DeepHydro(
one_user_cfg,
pre_model=global_model,
)
acc, _, _ = local_model.model_evaluate()
list_acc.append(acc)
values = [list(d.values())[0][0] for d in list_acc]
filtered_values = [v for v in values if not np.isnan(v)]
train_accuracy.append(sum(filtered_values) / len(filtered_values))
# print global training loss after every 'i' rounds
if (epoch + 1) % print_every == 0:
print(f" \nAvg Training Stats after {epoch+1} global rounds:")
print(f"Training Loss : {np.mean(np.array(train_loss))}")
print("Train Accuracy: {:.2f}% \n".format(100 * train_accuracy[-1]))
TransLearnHydro (DeepHydro)
¶
Source code in torchhydro/trainers/deep_hydro.py
class TransLearnHydro(DeepHydro):
def __init__(self, cfgs: Dict, pre_model=None):
super().__init__(cfgs, pre_model)
def load_model(self, mode="train"):
"""Load model for transfer learning"""
model_cfgs = self.cfgs["model_cfgs"]
if self.weight_path is None and self.pre_model is None:
raise NotImplementedError(
"For transfer learning, we need a pre-trained model"
)
if mode == "train":
model = super().load_model(mode)
elif mode == "infer":
self.weight_path = self._get_trained_model()
model = self._load_model_from_pth()
model.to(self.device)
if (
"weight_path_add" in model_cfgs
and "freeze_params" in model_cfgs["weight_path_add"]
):
freeze_params = model_cfgs["weight_path_add"]["freeze_params"]
for param in freeze_params:
exec(f"model.{param}.requires_grad = False")
return model
def _load_model_from_pth(self):
weight_path = self.weight_path
model_cfgs = self.cfgs["model_cfgs"]
model_name = model_cfgs["model_name"]
model = pytorch_model_dict[model_name](**model_cfgs["model_hyperparam"])
checkpoint = torch.load(weight_path, map_location=self.device)
if "weight_path_add" in model_cfgs:
if "excluded_layers" in model_cfgs["weight_path_add"]:
# delete some layers from source model if we don't need them
excluded_layers = model_cfgs["weight_path_add"]["excluded_layers"]
for layer in excluded_layers:
del checkpoint[layer]
print("sucessfully deleted layers")
else:
print("directly loading identically-named layers of source model")
model.load_state_dict(checkpoint, strict=False)
print("Weights sucessfully loaded")
return model
load_model(self, mode='train')
¶
Load model for transfer learning
Source code in torchhydro/trainers/deep_hydro.py
def load_model(self, mode="train"):
"""Load model for transfer learning"""
model_cfgs = self.cfgs["model_cfgs"]
if self.weight_path is None and self.pre_model is None:
raise NotImplementedError(
"For transfer learning, we need a pre-trained model"
)
if mode == "train":
model = super().load_model(mode)
elif mode == "infer":
self.weight_path = self._get_trained_model()
model = self._load_model_from_pth()
model.to(self.device)
if (
"weight_path_add" in model_cfgs
and "freeze_params" in model_cfgs["weight_path_add"]
):
freeze_params = model_cfgs["weight_path_add"]["freeze_params"]
for param in freeze_params:
exec(f"model.{param}.requires_grad = False")
return model
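A sketch of the weight_path_add options that TransLearnHydro inspects when loading a source model; the keys come from the source above, but the layer and parameter names are purely illustrative:

# Illustrative layer/parameter names; adapt to the actual model architecture
model_cfgs["weight_path"] = "path/to/source_model.pth"
model_cfgs["weight_path_add"] = {
    "excluded_layers": ["linearOut.weight", "linearOut.bias"],  # removed from the checkpoint before loading
    "freeze_params": ["lstm.weight_ih_l0"],  # these parameters get requires_grad = False after loading
}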
fabric_wrapper
¶
Author: Wenyu Ouyang Date: 2023-07-25 16:47:19 LastEditTime: 2025-06-17 10:39:32 LastEditors: Wenyu Ouyang Description: Lightning Fabric wrapper for debugging and distributed training FilePath: torchhydro\torchhydro\trainers\fabric_wrapper.py Copyright (c) 2025-2026 Wenyu Ouyang. All rights reserved.
FabricWrapper
¶
A wrapper class that can switch between Lightning Fabric and normal PyTorch operations based on configuration settings.
TODO: the fabric wrapper is not fully used for parallel training yet
Source code in torchhydro/trainers/fabric_wrapper.py
class FabricWrapper:
"""
A wrapper class that can switch between Lightning Fabric and normal PyTorch operations
based on configuration settings.
TODO: the fabric wrapper is not fully used for parallel training yet
"""
def __init__(self, use_fabric: bool = True, fabric_config: Optional[Dict] = None):
"""
Initialize the Fabric wrapper.
Parameters
----------
use_fabric : bool
Whether to use Lightning Fabric or normal PyTorch operations
fabric_config : Optional[Dict]
Configuration for Fabric (devices, strategy, etc.)
"""
self.use_fabric = use_fabric
self.fabric_config = fabric_config or {}
self._fabric: Optional[Any] = None
self._device: Optional[torch.device] = None
if self.use_fabric:
self._init_fabric()
else:
self._init_pytorch()
def _init_fabric(self) -> None:
"""Initialize Lightning Fabric"""
try:
import lightning as L
# Default fabric configuration
default_config = {
"accelerator": "auto",
"devices": "auto",
"strategy": "auto",
"precision": "32-true",
}
# Update with user config
default_config.update(self.fabric_config)
self._fabric = L.Fabric(**default_config)
print("✅ Lightning Fabric initialized successfully")
except ImportError:
print("❌ Lightning not found, falling back to normal PyTorch")
self.use_fabric = False
self._init_pytorch()
def _init_pytorch(self) -> None:
"""Initialize normal PyTorch setup"""
self.device_num = self.fabric_config["devices"]
#self.device_num = [0]
self._device = get_the_device(self.device_num)
print(f"✅ Normal PyTorch initialized, using device: {self._device}")
def setup_module(self, model: torch.nn.Module) -> torch.nn.Module:
"""Setup model for training"""
if self.use_fabric:
return self._fabric.setup_module(model)
else:
return model.to(self._device)
def setup_optimizers(
self, optimizer: torch.optim.Optimizer
) -> torch.optim.Optimizer:
"""Setup optimizer"""
if self.use_fabric:
return self._fabric.setup_optimizers(optimizer)
else:
return optimizer
def setup_dataloaders(
self, *dataloaders: torch.utils.data.DataLoader
) -> Tuple[torch.utils.data.DataLoader, ...]:
"""Setup dataloaders"""
if self.use_fabric:
return self._fabric.setup_dataloaders(*dataloaders)
else:
return dataloaders
def save(self, path: str, state_dict: Dict[str, Any]) -> None:
"""Save model state"""
if self.use_fabric:
self._fabric.save(path, state_dict)
else:
torch.save(state_dict, path)
def load(self, path: str, model: Optional[torch.nn.Module] = None) -> Any:
"""Load model state"""
if self.use_fabric:
return self._fabric.load(path, model)
else:
return torch.load(path, map_location=self._device)
def load_raw(self, path: str, model: torch.nn.Module) -> None:
"""Load raw model weights"""
if self.use_fabric:
checkpoint = self._fabric.load(path)
model.load_state_dict(checkpoint)
else:
checkpoint = torch.load(path, map_location=self._device)
model.load_state_dict(checkpoint)
def launch(self, fn: Optional[Any] = None, *args: Any, **kwargs: Any) -> Any:
"""Launch training function"""
if self.use_fabric:
if fn is None:
# This is called without a function, just launch fabric
return self._fabric.launch()
else:
return self._fabric.launch(fn, *args, **kwargs)
else:
# Normal PyTorch, just call the function directly
if fn is not None:
return fn(*args, **kwargs)
else:
return None
def backward(self, loss: torch.Tensor) -> None:
"""Backward pass"""
if self.use_fabric:
self._fabric.backward(loss)
else:
loss.backward()
def clip_gradients(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
max_norm: float = 1.0,
) -> None:
"""Clip gradients"""
if self.use_fabric:
self._fabric.clip_gradients(model, optimizer, max_norm=max_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
@property
def device(self) -> torch.device:
"""Get current device"""
if self.use_fabric:
return self._fabric.device
else:
return self._device
@property
def local_rank(self) -> int:
"""Get local rank"""
if self.use_fabric:
return self._fabric.local_rank
else:
return 0
@property
def global_rank(self) -> int:
"""Get global rank"""
if self.use_fabric:
return self._fabric.global_rank
else:
return 0
@property
def world_size(self) -> int:
"""Get world size"""
if self.use_fabric:
return self._fabric.world_size
else:
return 1
def barrier(self) -> None:
"""Synchronization barrier"""
if self.use_fabric:
self._fabric.barrier()
else:
pass # No barrier needed for single process
def print(self, *args: Any, **kwargs: Any) -> None:
"""Print only on rank 0"""
if self.use_fabric:
self._fabric.print(*args, **kwargs)
else:
print(*args, **kwargs)
device: device
property
readonly
¶
Get current device
global_rank: int
property
readonly
¶
Get global rank
local_rank: int
property
readonly
¶
Get local rank
world_size: int
property
readonly
¶
Get world size
__init__(self, use_fabric=True, fabric_config=None)
special
¶
Initialize the Fabric wrapper.
Parameters¶
use_fabric : bool
    Whether to use Lightning Fabric or normal PyTorch operations
fabric_config : Optional[Dict]
    Configuration for Fabric (devices, strategy, etc.)
Source code in torchhydro/trainers/fabric_wrapper.py
def __init__(self, use_fabric: bool = True, fabric_config: Optional[Dict] = None):
"""
Initialize the Fabric wrapper.
Parameters
----------
use_fabric : bool
Whether to use Lightning Fabric or normal PyTorch operations
fabric_config : Optional[Dict]
Configuration for Fabric (devices, strategy, etc.)
"""
self.use_fabric = use_fabric
self.fabric_config = fabric_config or {}
self._fabric: Optional[Any] = None
self._device: Optional[torch.device] = None
if self.use_fabric:
self._init_fabric()
else:
self._init_pytorch()
backward(self, loss)
¶
Backward pass
Source code in torchhydro/trainers/fabric_wrapper.py
def backward(self, loss: torch.Tensor) -> None:
"""Backward pass"""
if self.use_fabric:
self._fabric.backward(loss)
else:
loss.backward()
barrier(self)
¶
Synchronization barrier
Source code in torchhydro/trainers/fabric_wrapper.py
def barrier(self) -> None:
"""Synchronization barrier"""
if self.use_fabric:
self._fabric.barrier()
else:
pass # No barrier needed for single process
clip_gradients(self, model, optimizer, max_norm=1.0)
¶
Clip gradients
Source code in torchhydro/trainers/fabric_wrapper.py
def clip_gradients(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
max_norm: float = 1.0,
) -> None:
"""Clip gradients"""
if self.use_fabric:
self._fabric.clip_gradients(model, optimizer, max_norm=max_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
launch(self, fn=None, *args, **kwargs)
¶
Launch training function
Source code in torchhydro/trainers/fabric_wrapper.py
def launch(self, fn: Optional[Any] = None, *args: Any, **kwargs: Any) -> Any:
"""Launch training function"""
if self.use_fabric:
if fn is None:
# This is called without a function, just launch fabric
return self._fabric.launch()
else:
return self._fabric.launch(fn, *args, **kwargs)
else:
# Normal PyTorch, just call the function directly
if fn is not None:
return fn(*args, **kwargs)
else:
return None
load(self, path, model=None)
¶
Load model state
Source code in torchhydro/trainers/fabric_wrapper.py
def load(self, path: str, model: Optional[torch.nn.Module] = None) -> Any:
"""Load model state"""
if self.use_fabric:
return self._fabric.load(path, model)
else:
return torch.load(path, map_location=self._device)
load_raw(self, path, model)
¶
Load raw model weights
Source code in torchhydro/trainers/fabric_wrapper.py
def load_raw(self, path: str, model: torch.nn.Module) -> None:
"""Load raw model weights"""
if self.use_fabric:
checkpoint = self._fabric.load(path)
model.load_state_dict(checkpoint)
else:
checkpoint = torch.load(path, map_location=self._device)
model.load_state_dict(checkpoint)
print(self, *args, **kwargs)
¶
Print only on rank 0
Source code in torchhydro/trainers/fabric_wrapper.py
def print(self, *args: Any, **kwargs: Any) -> None:
"""Print only on rank 0"""
if self.use_fabric:
self._fabric.print(*args, **kwargs)
else:
print(*args, **kwargs)
save(self, path, state_dict)
¶
Save model state
Source code in torchhydro/trainers/fabric_wrapper.py
def save(self, path: str, state_dict: Dict[str, Any]) -> None:
"""Save model state"""
if self.use_fabric:
self._fabric.save(path, state_dict)
else:
torch.save(state_dict, path)
setup_dataloaders(self, *dataloaders)
¶
Setup dataloaders
Source code in torchhydro/trainers/fabric_wrapper.py
def setup_dataloaders(
self, *dataloaders: torch.utils.data.DataLoader
) -> Tuple[torch.utils.data.DataLoader, ...]:
"""Setup dataloaders"""
if self.use_fabric:
return self._fabric.setup_dataloaders(*dataloaders)
else:
return dataloaders
setup_module(self, model)
¶
Setup model for training
Source code in torchhydro/trainers/fabric_wrapper.py
def setup_module(self, model: torch.nn.Module) -> torch.nn.Module:
"""Setup model for training"""
if self.use_fabric:
return self._fabric.setup_module(model)
else:
return model.to(self._device)
setup_optimizers(self, optimizer)
¶
Setup optimizer
Source code in torchhydro/trainers/fabric_wrapper.py
def setup_optimizers(
self, optimizer: torch.optim.Optimizer
) -> torch.optim.Optimizer:
"""Setup optimizer"""
if self.use_fabric:
return self._fabric.setup_optimizers(optimizer)
else:
return optimizer
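A minimal training-step sketch driven through the wrapper, so the same loop runs with or without Lightning Fabric; the toy model, optimizer, and loader here are only placeholders for real training objects:

import torch
from torchhydro.trainers.fabric_wrapper import FabricWrapper

# toy objects just to show the wrapper calls
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

device_cfg = [0] if torch.cuda.is_available() else [-1]
fabric = FabricWrapper(use_fabric=False, fabric_config={"devices": device_cfg})
model = fabric.setup_module(model)             # moved to the resolved device (or wrapped by Fabric)
optimizer = fabric.setup_optimizers(optimizer)
for x, y in loader:
    x, y = x.to(fabric.device), y.to(fabric.device)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    fabric.backward(loss)                      # plain loss.backward() in the non-Fabric path
    fabric.clip_gradients(model, optimizer, max_norm=1.0)
    optimizer.step()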
create_fabric_wrapper(training_cfgs)
¶
Create a fabric wrapper based on training configuration.
Parameters¶
training_cfgs : Dict Training configuration dictionary
Returns¶
FabricWrapper Initialized fabric wrapper
Source code in torchhydro/trainers/fabric_wrapper.py
def create_fabric_wrapper(training_cfgs: Dict) -> FabricWrapper:
"""
Create a fabric wrapper based on training configuration.
Parameters
----------
training_cfgs : Dict
Training configuration dictionary
Returns
-------
FabricWrapper
Initialized fabric wrapper
"""
# Check if we should use fabric
fabric_strategy = training_cfgs.get("fabric_strategy")
use_fabric = fabric_strategy is not None
# Check if we have multiple devices
devices = training_cfgs.get("device", [0])
if isinstance(devices, list) and len(devices) == 1 and use_fabric:
print("📱 Single device detected - we can disable Fabric")
use_fabric = False
# Fabric configuration
fabric_config = {
"devices": devices if isinstance(devices, list) else [devices],
"strategy": fabric_strategy,
"precision": training_cfgs.get("precision", "32-true"),
"accelerator": training_cfgs.get("accelerator", "auto"),
}
return FabricWrapper(use_fabric=use_fabric, fabric_config=fabric_config)
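An illustrative set of the training_cfgs keys that create_fabric_wrapper inspects; the values are examples, not defaults:

from torchhydro.trainers.fabric_wrapper import create_fabric_wrapper

training_cfgs = {
    "fabric_strategy": "ddp",  # None (or absent) disables Fabric entirely
    "device": [0, 1],          # a single-device list also falls back to plain PyTorch
    "precision": "32-true",
    "accelerator": "auto",
}
fabric = create_fabric_wrapper(training_cfgs)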
resulter
¶
Resulter
¶
Source code in torchhydro/trainers/resulter.py
class Resulter:
def __init__(self, cfgs) -> None:
self.cfgs = cfgs
self.result_dir = cfgs["data_cfgs"]["case_dir"]
if not os.path.exists(self.result_dir):
os.makedirs(self.result_dir)
@property
def pred_name(self):
return f"epoch{str(self.chosen_trained_epoch)}flow_pred"
@property
def obs_name(self):
return f"epoch{str(self.chosen_trained_epoch)}flow_obs"
@property
def chosen_trained_epoch(self):
model_loader = self.cfgs["evaluation_cfgs"]["model_loader"]
if model_loader["load_way"] == "specified":
epoch_name = str(model_loader["test_epoch"])
elif model_loader["load_way"] == "best":
# NOTE: to make it consistent with the name in case of model_loader["load_way"] == "pth", the name has to be "best_model.pth"
epoch_name = "best_model.pth"
elif model_loader["load_way"] == "latest":
epoch_name = str(self.cfgs["training_cfgs"]["epochs"])
elif model_loader["load_way"] == "pth":
epoch_name = model_loader["pth_path"].split(os.sep)[-1]
else:
raise ValueError("Invalid load_way")
return epoch_name
def save_cfg(self, cfgs):
# save the cfgs after training
# update the cfgs with the latest one
self.cfgs = cfgs
param_file_exist = any(
(
fnmatch.fnmatch(file, "*.json")
and "_stat" not in file # statistics json file
and "_dict" not in file # data cache json file
)
for file in os.listdir(self.result_dir)
)
if not param_file_exist:
# although we save a params log during training, sometimes we directly evaluate a model
# so here we still save params log if param file does not exist
# no param file was saved yet, here we save data and params setting
save_model_params_log(cfgs, self.result_dir)
def save_result(self, pred, obs):
"""
save the pred value of testing period and obs value
Parameters
----------
pred
predictions
obs
observations
pred_name
the file name of predictions
obs_name
the file name of observations
Returns
-------
None
"""
save_dir = self.result_dir
flow_pred_file = os.path.join(save_dir, self.pred_name)
flow_obs_file = os.path.join(save_dir, self.obs_name)
max_len = max(len(basin) for basin in pred.basin.values)
encoding = {"basin": {"dtype": f"U{max_len}"}}
pred.to_netcdf(flow_pred_file + ".nc", encoding=encoding)
obs.to_netcdf(flow_obs_file + ".nc", encoding=encoding)
def eval_result(self, preds_xr, obss_xr):
# types of observations
target_col = self.cfgs["data_cfgs"]["target_cols"]
evaluation_metrics = self.cfgs["evaluation_cfgs"]["metrics"]
basin_ids = self.cfgs["data_cfgs"]["object_ids"]
test_path = self.cfgs["data_cfgs"]["case_dir"]
# Assume object_ids like ['changdian_61561']
# fill_nan: "no" means ignoring the NaN value;
# "sum" means calculate the sum of the following values in the NaN locations.
# For example, observations are [1, nan, nan, 2], and predictions are [0.3, 0.3, 0.3, 1.5].
# Then, "no" means [1, 2] v.s. [0.3, 1.5] while "sum" means [1, 2] v.s. [0.3 + 0.3 + 0.3, 1.5].
# If it is a str, then all target vars use same fill_nan method;
# elif it is a list, each for a var
fill_nan = self.cfgs["evaluation_cfgs"]["fill_nan"]
# Then evaluate the model metrics
if type(fill_nan) is list and len(fill_nan) != len(target_col):
raise ValueError("length of fill_nan must be equal to target_col's")
for i, col in enumerate(target_col):
eval_log = {}
obs = obss_xr[col].to_numpy()
pred = preds_xr[col].to_numpy()
eval_log = calculate_and_record_metrics(
obs,
pred,
evaluation_metrics,
col,
fill_nan[i] if isinstance(fill_nan, list) else fill_nan,
eval_log,
)
# Create pandas DataFrames from eval_log for each target variable (e.g., streamflow)
# Create a dictionary to hold the data for the DataFrame
data = {}
# Iterate over metrics in eval_log
for metric, values in eval_log.items():
# Remove 'of streamflow' (or similar) from the metric name
clean_metric = metric.replace(f"of {col}", "").strip()
# Add the cleaned metric to the data dictionary
data[clean_metric] = values
# Create a DataFrame using object_ids as the index and metrics as columns
df = pd.DataFrame(data, index=basin_ids)
# Define the output file name based on the target variable
output_file = os.path.join(test_path, f"metric_{col}.csv")
# Save the DataFrame to a CSV file
df.to_csv(output_file, index_label="basin_id")
# Finally, try to explain model behaviour using shap
is_shap = self.cfgs["evaluation_cfgs"]["explainer"] == "shap"
if is_shap:
shap_summary_plot(self.model, self.traindataset, self.testdataset)
# deep_explain_model_summary_plot(self.model, test_data)
# deep_explain_model_heatmap(self.model, test_data)
def _convert_streamflow_units(self, ds):
"""convert the streamflow units to m^3/s
Parameters
----------
pred : np.array
predictions
Returns
-------
"""
data_cfgs = self.cfgs["data_cfgs"]
source_name = data_cfgs["source_cfgs"]["source_name"]
source_path = data_cfgs["source_cfgs"]["source_path"]
other_settings = data_cfgs["source_cfgs"].get("other_settings", {})
data_source = data_sources_dict[source_name](source_path, **other_settings)
basin_id = data_cfgs["object_ids"]
# NOTE: all datasource should have read_area method
basin_area = data_source.read_area(basin_id)
target_unit = "m^3/s"
# Get the flow variable name dynamically from config instead of hardcoding "streamflow"
# NOTE: the first target variable must be the flow variable
var_flow = self.cfgs["data_cfgs"]["target_cols"][0]
streamflow_ds = ds[[var_flow]]
ds_ = streamflow_unit_conv(
streamflow_ds, basin_area, target_unit=target_unit, inverse=True
)
new_ds = ds.copy(deep=True)
new_ds[var_flow] = ds_[var_flow]
return new_ds
def load_result(self, convert_flow_unit=False) -> Tuple[np.array, np.array]:
"""load the pred value of testing period and obs value"""
save_dir = self.result_dir
pred_file = os.path.join(save_dir, self.pred_name + ".nc")
obs_file = os.path.join(save_dir, self.obs_name + ".nc")
pred = xr.open_dataset(pred_file)
obs = xr.open_dataset(obs_file)
if convert_flow_unit:
pred = self._convert_streamflow_units(pred)
obs = self._convert_streamflow_units(obs)
return pred, obs
def save_intermediate_results(self, **kwargs):
"""Load model weights and deal with some intermediate results"""
is_cell_states = kwargs.get("is_cell_states", False)
is_pbm_params = kwargs.get("is_pbm_params", False)
cfgs = self.cfgs
cfgs["training_cfgs"]["train_mode"] = False
training_cfgs = cfgs["training_cfgs"]
seq_first = training_cfgs["which_first_tensor"] == "sequence"
if is_cell_states:
# TODO: return_cell_states is not supported yet
return cellstates_when_inference(seq_first, data_cfgs, pred)
if is_pbm_params:
self._save_pbm_params(cfgs, seq_first)
def _save_pbm_params(self, cfgs, seq_first):
training_cfgs = cfgs["training_cfgs"]
model_loader = cfgs["evaluation_cfgs"]["model_loader"]
model_pth_dir = cfgs["data_cfgs"]["case_dir"]
weight_path = read_pth_from_model_loader(model_loader, model_pth_dir)
cfgs["model_cfgs"]["weight_path"] = weight_path
cfgs["training_cfgs"]["device"] = [0] if torch.cuda.is_available() else [-1]
deephydro = DeepHydro(cfgs)
device = deephydro.device
dl_model = deephydro.model.dl_model
pb_model = deephydro.model.pb_model
param_func = deephydro.model.param_func
# TODO: check for dplnnmodule model
param_test_way = deephydro.model.param_test_way
test_dataloader = DataLoader(
deephydro.testdataset,
batch_size=training_cfgs["batch_size"],
shuffle=False,
sampler=None,
batch_sampler=None,
drop_last=False,
timeout=0,
worker_init_fn=None,
)
deephydro.model.eval()
# here the batch is just an index of lookup table, so any batch size could be chosen
params_lst = []
with torch.no_grad():
for batch in test_dataloader:
ys, gen = model_infer(seq_first, device, dl_model, batch)
# we set all params' values in [0, 1] and will scale them when forwarding
if param_func == "clamp":
params_ = torch.clamp(gen, min=0.0, max=1.0)
elif param_func == "sigmoid":
params_ = F.sigmoid(gen)
else:
raise NotImplementedError(
"We don't provide this way to limit parameters' range!! Please choose sigmoid or clamp"
)
# just get one-period values, here we use the final period's values
params = params_[:, -1, :]
params_lst.append(params)
pb_params = reduce(lambda a, b: torch.cat((a, b), dim=0), params_lst)
# trans tensor to pandas dataframe
sites = deephydro.cfgs["data_cfgs"]["object_ids"]
params_names = pb_model.params_names
params_df = pd.DataFrame(
pb_params.cpu().numpy(), columns=params_names, index=sites
)
save_param_file = os.path.join(
model_pth_dir, f"pb_params_{int(time.time())}.csv"
)
params_df.to_csv(save_param_file, index_label="GAGE_ID")
def read_tensorboard_log(self, **kwargs):
"""read tensorboard log files"""
is_scalar = kwargs.get("is_scalar", False)
is_histogram = kwargs.get("is_histogram", False)
log_dir = self.cfgs["data_cfgs"]["case_dir"]
if is_scalar:
scalar_file = os.path.join(log_dir, "tb_scalars.csv")
if not os.path.exists(scalar_file):
reader = SummaryReader(log_dir)
df_scalar = reader.scalars
df_scalar.to_csv(scalar_file, index=False)
else:
df_scalar = pd.read_csv(scalar_file)
if is_histogram:
histogram_file = os.path.join(log_dir, "tb_histograms.csv")
if not os.path.exists(histogram_file):
reader = SummaryReader(log_dir)
df_histogram = reader.histograms
df_histogram.to_csv(histogram_file, index=False)
else:
df_histogram = pd.read_csv(histogram_file)
if is_scalar and is_histogram:
return df_scalar, df_histogram
elif is_scalar:
return df_scalar
elif is_histogram:
return df_histogram
# TODO: the following code is not finished yet
def load_ensemble_result(
self, save_dirs, test_epoch, flow_unit="m3/s", basin_areas=None
) -> Tuple[np.array, np.array]:
"""
load ensemble mean value
Parameters
----------
save_dirs
test_epoch
flow_unit
default is m3/s, if it is not m3/s, transform the results
basin_areas
if unit is mm/day it will be used, default is None
Returns
-------
"""
preds = []
obss = []
for save_dir in save_dirs:
pred_i, obs_i = self.load_result(save_dir, test_epoch)
if pred_i.ndim == 3 and pred_i.shape[-1] == 1:
pred_i = pred_i.reshape(pred_i.shape[0], pred_i.shape[1])
obs_i = obs_i.reshape(obs_i.shape[0], obs_i.shape[1])
preds.append(pred_i)
obss.append(obs_i)
preds_np = np.array(preds)
obss_np = np.array(obss)
pred_mean = np.mean(preds_np, axis=0)
obs_mean = np.mean(obss_np, axis=0)
if flow_unit == "mm/day":
if basin_areas is None:
raise ArithmeticError("No basin areas we cannot calculate")
basin_areas = np.repeat(basin_areas, obs_mean.shape[1], axis=0).reshape(
obs_mean.shape
)
obs_mean = obs_mean * basin_areas * 1e-3 * 1e6 / 86400
pred_mean = pred_mean * basin_areas * 1e-3 * 1e6 / 86400
elif flow_unit == "m3/s":
pass
elif flow_unit == "ft3/s":
obs_mean = obs_mean / 35.314666721489
pred_mean = pred_mean / 35.314666721489
return pred_mean, obs_mean
def eval_ensemble_result(
self,
save_dirs,
test_epoch,
return_value=False,
flow_unit="m3/s",
basin_areas=None,
) -> Tuple[np.array, np.array]:
"""calculate statistics for ensemble results
Parameters
----------
save_dirs : _type_
where the results save
test_epoch : _type_
we name the results files with the test_epoch
return_value : bool, optional
if True, return (inds_df, pred_mean, obs_mean), by default False
flow_unit : str, optional
arg for load_ensemble_result, by default "m3/s"
basin_areas : _type_, optional
arg for load_ensemble_result, by default None
Returns
-------
Tuple[np.array, np.array]
inds_df or (inds_df, pred_mean, obs_mean)
"""
pred_mean, obs_mean = self.load_ensemble_result(
save_dirs, test_epoch, flow_unit=flow_unit, basin_areas=basin_areas
)
inds = stat_error(obs_mean, pred_mean)
inds_df = pd.DataFrame(inds)
return (inds_df, pred_mean, obs_mean) if return_value else inds_df
eval_ensemble_result(self, save_dirs, test_epoch, return_value=False, flow_unit='m3/s', basin_areas=None)
¶
calculate statistics for ensemble results
Parameters¶
save_dirs : type where the results save test_epoch : type we name the results files with the test_epoch return_value : bool, optional if True, return (inds_df, pred_mean, obs_mean), by default False flow_unit : str, optional arg for load_ensemble_result, by default "m3/s" basin_areas : type, optional arg for load_ensemble_result, by default None
Returns¶
Tuple[np.array, np.array] inds_df or (inds_df, pred_mean, obs_mean)
Source code in torchhydro/trainers/resulter.py
def eval_ensemble_result(
self,
save_dirs,
test_epoch,
return_value=False,
flow_unit="m3/s",
basin_areas=None,
) -> Tuple[np.array, np.array]:
"""calculate statistics for ensemble results
Parameters
----------
save_dirs : _type_
where the results save
test_epoch : _type_
we name the results files with the test_epoch
return_value : bool, optional
if True, return (inds_df, pred_mean, obs_mean), by default False
flow_unit : str, optional
arg for load_ensemble_result, by default "m3/s"
basin_areas : _type_, optional
arg for load_ensemble_result, by default None
Returns
-------
Tuple[np.array, np.array]
inds_df or (inds_df, pred_mean, obs_mean)
"""
pred_mean, obs_mean = self.load_ensemble_result(
save_dirs, test_epoch, flow_unit=flow_unit, basin_areas=basin_areas
)
inds = stat_error(obs_mean, pred_mean)
inds_df = pd.DataFrame(inds)
return (inds_df, pred_mean, obs_mean) if return_value else inds_df
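A minimal usage sketch, assuming a Resulter instance named resulter and a list of result directories from runs trained with different random seeds (all names and values below are placeholders):
save_dirs = ["results/seed_1111", "results/seed_1234", "results/seed_4321"]  # assumed paths
# metrics only
inds_df = resulter.eval_ensemble_result(save_dirs, test_epoch=100)
# metrics plus the ensemble-mean series
inds_df, pred_mean, obs_mean = resulter.eval_ensemble_result(
    save_dirs, test_epoch=100, return_value=True
)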
load_ensemble_result(self, save_dirs, test_epoch, flow_unit='m3/s', basin_areas=None)
¶
load ensemble mean value
Parameters¶
save_dirs test_epoch flow_unit default is m3/s, if it is not m3/s, transform the results basin_areas if unit is mm/day it will be used, default is None
Returns¶
Tuple[np.array, np.array] the ensemble-mean predictions and observations (pred_mean, obs_mean)
Source code in torchhydro/trainers/resulter.py
def load_ensemble_result(
self, save_dirs, test_epoch, flow_unit="m3/s", basin_areas=None
) -> Tuple[np.array, np.array]:
"""
load ensemble mean value
Parameters
----------
save_dirs
test_epoch
flow_unit
default is m3/s, if it is not m3/s, transform the results
basin_areas
if unit is mm/day it will be used, default is None
Returns
-------
"""
preds = []
obss = []
for save_dir in save_dirs:
pred_i, obs_i = self.load_result(save_dir, test_epoch)
if pred_i.ndim == 3 and pred_i.shape[-1] == 1:
pred_i = pred_i.reshape(pred_i.shape[0], pred_i.shape[1])
obs_i = obs_i.reshape(obs_i.shape[0], obs_i.shape[1])
preds.append(pred_i)
obss.append(obs_i)
preds_np = np.array(preds)
obss_np = np.array(obss)
pred_mean = np.mean(preds_np, axis=0)
obs_mean = np.mean(obss_np, axis=0)
if flow_unit == "mm/day":
if basin_areas is None:
raise ArithmeticError("Basin areas are required to convert from mm/day")
basin_areas = np.repeat(basin_areas, obs_mean.shape[1], axis=0).reshape(
obs_mean.shape
)
obs_mean = obs_mean * basin_areas * 1e-3 * 1e6 / 86400
pred_mean = pred_mean * basin_areas * 1e-3 * 1e6 / 86400
elif flow_unit == "m3/s":
pass
elif flow_unit == "ft3/s":
obs_mean = obs_mean / 35.314666721489
pred_mean = pred_mean / 35.314666721489
return pred_mean, obs_mean
load_result(self, convert_flow_unit=False)
¶
load the pred value of testing period and obs value
Source code in torchhydro/trainers/resulter.py
def load_result(self, convert_flow_unit=False) -> Tuple[np.array, np.array]:
"""load the pred value of testing period and obs value"""
save_dir = self.result_dir
pred_file = os.path.join(save_dir, self.pred_name + ".nc")
obs_file = os.path.join(save_dir, self.obs_name + ".nc")
pred = xr.open_dataset(pred_file)
obs = xr.open_dataset(obs_file)
if convert_flow_unit:
pred = self._convert_streamflow_units(pred)
obs = self._convert_streamflow_units(obs)
return pred, obs
read_tensorboard_log(self, **kwargs)
¶
read tensorboard log files
Source code in torchhydro/trainers/resulter.py
def read_tensorboard_log(self, **kwargs):
"""read tensorboard log files"""
is_scalar = kwargs.get("is_scalar", False)
is_histogram = kwargs.get("is_histogram", False)
log_dir = self.cfgs["data_cfgs"]["case_dir"]
if is_scalar:
scalar_file = os.path.join(log_dir, "tb_scalars.csv")
if not os.path.exists(scalar_file):
reader = SummaryReader(log_dir)
df_scalar = reader.scalars
df_scalar.to_csv(scalar_file, index=False)
else:
df_scalar = pd.read_csv(scalar_file)
if is_histogram:
histogram_file = os.path.join(log_dir, "tb_histograms.csv")
if not os.path.exists(histogram_file):
reader = SummaryReader(log_dir)
df_histogram = reader.histograms
df_histogram.to_csv(histogram_file, index=False)
else:
df_histogram = pd.read_csv(histogram_file)
if is_scalar and is_histogram:
return df_scalar, df_histogram
elif is_scalar:
return df_scalar
elif is_histogram:
return df_histogram
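A short sketch of reading the cached TensorBoard logs; the resulter instance is assumed to point at the case directory of a finished run:
# scalars only (a tb_scalars.csv cache is written on the first call)
df_scalar = resulter.read_tensorboard_log(is_scalar=True)
# scalars and histograms together
df_scalar, df_histogram = resulter.read_tensorboard_log(is_scalar=True, is_histogram=True)
print(df_scalar.head())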
save_intermediate_results(self, **kwargs)
¶
Load model weights and deal with some intermediate results
Source code in torchhydro/trainers/resulter.py
def save_intermediate_results(self, **kwargs):
"""Load model weights and deal with some intermediate results"""
is_cell_states = kwargs.get("is_cell_states", False)
is_pbm_params = kwargs.get("is_pbm_params", False)
cfgs = self.cfgs
cfgs["training_cfgs"]["train_mode"] = False
training_cfgs = cfgs["training_cfgs"]
seq_first = training_cfgs["which_first_tensor"] == "sequence"
if is_cell_states:
# TODO: not support return_cell_states yet
return cellstates_when_inference(seq_first, data_cfgs, pred)
if is_pbm_params:
self._save_pbm_params(cfgs, seq_first)
save_result(self, pred, obs)
¶
save the pred value of testing period and obs value
Parameters¶
pred predictions obs observations pred_name the file name of predictions obs_name the file name of observations
Returns¶
None
Source code in torchhydro/trainers/resulter.py
def save_result(self, pred, obs):
"""
save the pred value of testing period and obs value
Parameters
----------
pred
predictions
obs
observations
pred_name
the file name of predictions
obs_name
the file name of observations
Returns
-------
None
"""
save_dir = self.result_dir
flow_pred_file = os.path.join(save_dir, self.pred_name)
flow_obs_file = os.path.join(save_dir, self.obs_name)
max_len = max(len(basin) for basin in pred.basin.values)
encoding = {"basin": {"dtype": f"U{max_len}"}}
pred.to_netcdf(flow_pred_file + ".nc", encoding=encoding)
obs.to_netcdf(flow_obs_file + ".nc", encoding=encoding)
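A hedged sketch of the save/load round trip; preds and obss are assumed to be the xarray datasets produced by model evaluation:
resulter.save_result(preds, obss)   # writes <pred_name>.nc and <obs_name>.nc to result_dir
pred, obs = resulter.load_result()  # reads them back as xarray datasets
pred_converted, obs_converted = resulter.load_result(convert_flow_unit=True)  # optional unit conversion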
train_logger
¶
Author: Wenyu Ouyang Date: 2021-12-31 11:08:29 LastEditTime: 2025-11-07 08:36:50 LastEditors: Wenyu Ouyang Description: Training function for DL models FilePath: \torchhydro\torchhydro\trainers\train_logger.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
TrainLogger
¶
Source code in torchhydro/trainers/train_logger.py
class TrainLogger:
def __init__(self, model_filepath, params, opt):
self.training_cfgs = params["training_cfgs"]
self.data_cfgs = params["data_cfgs"]
self.evaluation_cfgs = params["evaluation_cfgs"]
self.model_cfgs = params["model_cfgs"]
self.opt = opt
self.training_save_dir = model_filepath
self.tb = SummaryWriter(self.training_save_dir)
self.session_params = []
self.train_time = []
# log loss for each epoch
self.epoch_loss = []
# reload previous logs if continue_train is True and weight_path is not None
if (
self.model_cfgs["continue_train"]
and self.model_cfgs["weight_path"] is not None
):
the_logger_file = get_lastest_logger_file_in_a_dir(self.training_save_dir)
if the_logger_file is not None:
with open(the_logger_file, "r") as f:
logs = json.load(f)
start_epoch = self.training_cfgs["start_epoch"]
# read the logs before start_epoch and load them to session_params, train_time, epoch_loss
for log in logs["run"]:
if log["epoch"] < start_epoch:
self.session_params.append(log)
self.train_time.append(log["train_time"])
self.epoch_loss.append(float(log["train_loss"]))
def save_session_param(
self, epoch, total_loss, n_iter_ep, valid_loss=None, valid_metrics=None
):
if valid_metrics is None:
if valid_loss is None:
epoch_params = {
"epoch": epoch,
"train_loss": str(total_loss),
"iter_num": n_iter_ep,
}
else:
epoch_params = {
"epoch": epoch,
"train_loss": str(total_loss),
"validation_loss": str(valid_loss),
"iter_num": n_iter_ep,
}
else:
epoch_params = {
"epoch": epoch,
"train_loss": str(total_loss),
"validation_loss": str(valid_loss),
"validation_metric": valid_metrics,
"iter_num": n_iter_ep,
}
epoch_params["train_time"] = self.train_time[epoch - 1]
self.session_params.append(epoch_params)
@contextmanager
def log_epoch_train(self, epoch):
start_time = time.time()
logs = {}
# the content in the 'with' block is executed after the yield
yield logs
total_loss = logs["train_loss"]
elapsed_time = time.time() - start_time
lr = self.opt.param_groups[0]["lr"]
log_str = "Epoch {} Loss {:.4f} time {:.2f} lr {}".format(
epoch, total_loss, elapsed_time, lr
)
print(log_str)
model = logs["model"]
print(model)
self.tb.add_scalar("Loss", total_loss, epoch)
# self.plot_hist_img(model, epoch)
self.train_time.append(log_str)
self.epoch_loss.append(total_loss)
@contextmanager
def log_epoch_valid(self, epoch):
logs = {}
yield logs
valid_loss = logs["valid_loss"]
if (
self.training_cfgs["valid_batch_mode"] == "test"
and self.training_cfgs["calc_metrics"]
):
# NOTE: Now we only evaluate the metrics for test-mode validation
valid_metrics = logs["valid_metrics"]
val_log = "Epoch {} Valid Loss {:.4f} Valid Metric {}".format(
epoch, valid_loss, valid_metrics
)
print(val_log)
self.tb.add_scalar("ValidLoss", valid_loss, epoch)
target_col = self.data_cfgs["target_cols"]
evaluation_metrics = self.evaluation_cfgs["metrics"]
for i in range(len(target_col)):
for evaluation_metric in evaluation_metrics:
self.tb.add_scalar(
f"Valid{target_col[i]}{evaluation_metric}mean",
np.nanmean(
valid_metrics[f"{evaluation_metric} of {target_col[i]}"]
),
epoch,
)
self.tb.add_scalar(
f"Valid{target_col[i]}{evaluation_metric}median",
np.nanmedian(
valid_metrics[f"{evaluation_metric} of {target_col[i]}"]
),
epoch,
)
else:
val_log = "Epoch {} Valid Loss {:.4f} ".format(epoch, valid_loss)
print(val_log)
self.tb.add_scalar("ValidLoss", valid_loss, epoch)
def save_model_and_params(self, model, epoch, params):
final_epoch = params["training_cfgs"]["epochs"]
save_epoch = params["training_cfgs"]["save_epoch"]
if save_epoch is None or save_epoch == 0 and epoch != final_epoch:
return
if (save_epoch > 0 and epoch % save_epoch == 0) or epoch == final_epoch:
# save for save_epoch
model_file = os.path.join(
self.training_save_dir, f"model_Ep{str(epoch)}.pth"
)
save_model(model, model_file)
if epoch == final_epoch:
self._save_final_epoch(params, model)
def _save_final_epoch(self, params, model):
# In final epoch, we save the model and params in case_dir
final_path = params["data_cfgs"]["case_dir"]
params["run"] = self.session_params
time_stamp = datetime.now().strftime("%d_%B_%Y%I_%M%p")
model_save_path = os.path.join(final_path, f"{time_stamp}_model.pth")
save_model(model, model_save_path)
save_model_params_log(params, final_path)
# also save one for a training directory for one hyperparameter setting
save_model_params_log(params, self.training_save_dir)
def plot_hist_img(self, model, global_step):
for tag, parm in model.named_parameters():
self.tb.add_histogram(
f"{tag}_hist", parm.detach().cpu().numpy(), global_step
)
if len(parm.shape) == 2:
img_format = "HW"
if parm.shape[0] > parm.shape[1]:
img_format = "WH"
self.tb.add_image(
f"{tag}_img",
parm.detach().cpu().numpy(),
global_step,
dataformats=img_format,
)
def plot_model_structure(self, model):
"""plot model structure in tensorboard
TODO: This function is not working as expected. It should be rewritten.
Parameters
----------
model :
torch model
"""
# input4modelplot = torch.randn(
# self.training_cfgs["batch_size"],
# self.training_cfgs["hindcast_length"],
# # self.model_cfgs["model_hyperparam"]["n_input_features"],
# self.model_cfgs["model_hyperparam"]["input_size"],
# )
if self.data_cfgs["model_mode"] == "single":
input4modelplot = [
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["input_features"] - 1,
),
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["cnn_size"],
),
torch.rand(
self.training_cfgs["batch_size"],
1,
self.data_cfgs["output_features"],
),
]
else:
input4modelplot = [
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["input_features"],
),
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["input_size_encoder2"],
),
torch.rand(
self.training_cfgs["batch_size"],
1,
self.data_cfgs["output_features"],
),
]
self.tb.add_graph(model, input4modelplot)
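A sketch of how the logger's context managers are typically wired into a training loop (the model, optimizer, criterion, and loader are placeholders; torch_single_train is documented below):
train_logger = TrainLogger(model_filepath, params, opt)
for epoch in range(1, n_epochs + 1):
    with train_logger.log_epoch_train(epoch) as train_logs:
        total_loss, n_iter_ep = torch_single_train(
            model, opt, criterion, train_loader, device=device, which_first_tensor="batch"
        )
        train_logs["train_loss"] = total_loss
        train_logs["model"] = model
    train_logger.save_session_param(epoch, total_loss, n_iter_ep)
    train_logger.save_model_and_params(model, epoch, params)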
plot_model_structure(self, model)
¶
plot model structure in tensorboard
TODO: This function is not working as expected. It should be rewritten.
Parameters¶
model : torch model
Source code in torchhydro/trainers/train_logger.py
def plot_model_structure(self, model):
"""plot model structure in tensorboard
TODO: This function is not working as expected. It should be rewritten.
Parameters
----------
model :
torch model
"""
# input4modelplot = torch.randn(
# self.training_cfgs["batch_size"],
# self.training_cfgs["hindcast_length"],
# # self.model_cfgs["model_hyperparam"]["n_input_features"],
# self.model_cfgs["model_hyperparam"]["input_size"],
# )
if self.data_cfgs["model_mode"] == "single":
input4modelplot = [
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["input_features"] - 1,
),
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["cnn_size"],
),
torch.rand(
self.training_cfgs["batch_size"],
1,
self.data_cfgs["output_features"],
),
]
else:
input4modelplot = [
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["input_features"],
),
torch.randn(
self.training_cfgs["batch_size"],
self.training_cfgs["hindcast_length"],
self.data_cfgs["input_size_encoder2"],
),
torch.rand(
self.training_cfgs["batch_size"],
1,
self.data_cfgs["output_features"],
),
]
self.tb.add_graph(model, input4modelplot)
train_utils
¶
Author: Wenyu Ouyang Date: 2024-04-08 18:16:26 LastEditTime: 2025-11-06 13:48:48 LastEditors: Wenyu Ouyang Description: Some basic functions for training FilePath: \torchhydro\torchhydro\trainers\train_utils.py Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
EarlyStopper
¶
Source code in torchhydro/trainers/train_utils.py
class EarlyStopper(object):
def __init__(
self,
patience: int,
min_delta: float = 0.0,
cumulative_delta: bool = False,
):
"""
EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
Parameters
----------
patience
Number of events to wait if no improvement and then stop the training.
min_delta
A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
cumulative_delta
If True, `min_delta` defines an increase since the last `patience` reset; otherwise,
it defines an increase after the last event. Default value is False.
"""
if patience < 1:
raise ValueError("Argument patience should be positive integer.")
if min_delta < 0.0:
raise ValueError("Argument min_delta should not be a negative number.")
self.patience = patience
self.min_delta = min_delta
self.cumulative_delta = cumulative_delta
self.counter = 0
self.best_score = None
def check_loss(self, model, validation_loss, save_dir) -> bool:
score = validation_loss
if self.best_score is None:
self.save_model_checkpoint(model, save_dir)
self.best_score = score
elif score + self.min_delta >= self.best_score:
self.counter += 1
print("Epochs without Model Update:", self.counter)
if self.counter >= self.patience:
return False
else:
self.save_model_checkpoint(model, save_dir)
print("Model Update")
self.best_score = score
self.counter = 0
return True
def save_model_checkpoint(self, model, save_dir):
torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
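A minimal sketch of the handler inside a validation loop; the helper functions and the save directory are assumptions for illustration:
early_stopper = EarlyStopper(patience=5)
for epoch in range(1, n_epochs + 1):
    train_one_epoch(model)                  # hypothetical helper
    valid_loss = validate_one_epoch(model)  # hypothetical helper
    if not early_stopper.check_loss(model, valid_loss, save_dir):
        print(f"Stopping early at epoch {epoch}")
        break
# the best weights were checkpointed to os.path.join(save_dir, "best_model.pth")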
__init__(self, patience, min_delta=0.0, cumulative_delta=False)
special
¶
EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
Parameters¶
patience
Number of events to wait if no improvement and then stop the training.
min_delta
A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to min_delta, will count as no improvement.
cumulative_delta
If True, min_delta defines an increase since the last patience reset; otherwise,
it defines an increase after the last event. Default value is False.
Source code in torchhydro/trainers/train_utils.py
def __init__(
self,
patience: int,
min_delta: float = 0.0,
cumulative_delta: bool = False,
):
"""
EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
Parameters
----------
patience
Number of events to wait if no improvement and then stop the training.
min_delta
A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
cumulative_delta
If True, `min_delta` defines an increase since the last `patience` reset; otherwise,
it defines an increase after the last event. Default value is False.
"""
if patience < 1:
raise ValueError("Argument patience should be positive integer.")
if min_delta < 0.0:
raise ValueError("Argument min_delta should not be a negative number.")
self.patience = patience
self.min_delta = min_delta
self.cumulative_delta = cumulative_delta
self.counter = 0
self.best_score = None
average_weights(w)
¶
Returns the average of the weights.
Source code in torchhydro/trainers/train_utils.py
def average_weights(w):
"""
Returns the average of the weights.
"""
w_avg = copy.deepcopy(w[0])
for key in w_avg.keys():
for i in range(1, len(w)):
w_avg[key] += w[i][key]
w_avg[key] = torch.div(w_avg[key], len(w))
return w_avg
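For example, averaging the state dicts of several identically structured trained models (a sketch; the model objects are placeholders):
state_dicts = [m.state_dict() for m in trained_models]  # same architecture for every member
w_avg = average_weights(state_dicts)
ensemble_model.load_state_dict(w_avg)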
cellstates_when_inference(seq_first, data_cfgs, pred)
¶
get cell states when inference
Source code in torchhydro/trainers/train_utils.py
def cellstates_when_inference(seq_first, data_cfgs, pred):
"""get cell states when inference"""
cs_out = (
cs_cat_lst.detach().cpu().numpy().swapaxes(0, 1)
if seq_first
else cs_cat_lst.detach().cpu().numpy()
)
cs_out_lst = [cs_out]
cell_state = reduce(lambda a, b: np.vstack((a, b)), cs_out_lst)
np.save(os.path.join(data_cfgs["case_dir"], "cell_states.npy"), cell_state)
# model.zero_grad()
torch.cuda.empty_cache()
return pred, cell_state
compute_loss(labels, output, criterion, **kwargs)
¶
Function for computing the loss
Parameters¶
labels The real values for the target. Shape can be variable but should follow (batch_size, time) output The output of the model criterion loss function validation_dataset Only passed when unscaling of data is needed. m defaults to 1
Returns¶
torch.Tensor the computed loss
Source code in torchhydro/trainers/train_utils.py
def compute_loss(
labels: torch.Tensor, output: torch.Tensor, criterion, **kwargs
) -> torch.Tensor:
"""
Function for computing the loss
Parameters
----------
labels
The real values for the target. Shape can be variable but should follow (batch_size, time)
output
The output of the model
criterion
loss function
validation_dataset
Only passed when unscaling of data is needed.
m
defaults to 1
Returns
-------
torch.Tensor
the computed loss
"""
# a = np.sum(output.cpu().detach().numpy(),axis=1)/len(output)
# b=[]
# for i in a:
# b.append([i.tolist()])
# output = torch.tensor(b, requires_grad=True).to(torch.device("cuda"))
if isinstance(criterion, GaussianLoss):
if len(output[0].shape) > 2:
g_loss = GaussianLoss(output[0][:, :, 0], output[1][:, :, 0])
else:
g_loss = GaussianLoss(output[0][:, 0], output[1][:, 0])
return g_loss(labels)
if isinstance(criterion, FloodBaseLoss):
# labels has one more column than output, which is the flood mask
# so we need to remove the last column of labels to get targets
flood_mask = labels[:, :, -1:] # Extract flood mask from last column
targets = labels[:, :, :-1] # Extract targets (remove last column)
return criterion(output, targets, flood_mask)
if (
isinstance(output, torch.Tensor)
and len(labels.shape) != len(output.shape)
and len(labels.shape) > 1
):
if labels.shape[1] == output.shape[1]:
labels = labels.unsqueeze(2)
else:
labels = labels.unsqueeze(0)
assert labels.shape == output.shape
return criterion(output, labels.float())
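A small sketch with plain tensors; the shapes are chosen only for illustration:
criterion = torch.nn.MSELoss()
output = torch.randn(16, 30, 1)   # (batch, time, n_targets)
labels = torch.randn(16, 30, 1)
loss = compute_loss(labels, output, criterion)
# 2-D labels are unsqueezed internally to match the 3-D output
loss_2d = compute_loss(torch.randn(16, 30), output, criterion)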
compute_validation(model, criterion, data_loader, device=None, **kwargs)
¶
Function to compute the validation loss metrics
Parameters¶
model the trained model criterion torch.nn.modules.loss dataloader The data-loader of either validation or test-data device torch.device
Returns¶
tuple validation observations (numpy array), predictions (numpy array) and the loss of validation
Source code in torchhydro/trainers/train_utils.py
def compute_validation(
model,
criterion,
data_loader: DataLoader,
device: torch.device = None,
**kwargs,
):
"""
Function to compute the validation loss metrics
Parameters
----------
model
the trained model
criterion
torch.nn.modules.loss
dataloader
The data-loader of either validation or test-data
device
torch.device
Returns
-------
tuple
validation observations (numpy array), predictions (numpy array) and the loss of validation
"""
model.eval()
seq_first = kwargs["which_first_tensor"] != "batch"
obs = []
preds = []
valid_loss = 0.0
obs_final = None
pred_final = None
with torch.no_grad():
iter_num = 0
for batch in tqdm(data_loader, desc="Evaluating", total=len(data_loader)):
trg, output = model_infer(seq_first, device, model, batch)
obs.append(trg)
preds.append(output)
valid_loss_ = compute_loss(trg, output, criterion)
if torch.isnan(valid_loss_):
# for not-train mode, we may get all nan data for trg
# so we skip this batch
print("NaN loss detected, skipping this batch")
continue
valid_loss = valid_loss + valid_loss_.item()
iter_num = iter_num + 1
# For flood datasets, remove the flood_mask column from observations
# to match the prediction dimensions for evaluation
trg_for_eval = (
trg[:, :, :-1] if isinstance(criterion, FloodBaseLoss) else trg
)
# clear memory to save GPU memory
if obs_final is None:
obs_final = trg_for_eval.detach().cpu()
pred_final = output.detach().cpu()
else:
obs_final = torch.cat([obs_final, trg_for_eval.detach().cpu()], dim=0)
pred_final = torch.cat([pred_final, output.detach().cpu()], dim=0)
del trg, output
torch.cuda.empty_cache()
valid_loss = valid_loss / iter_num
y_obs = obs_final.numpy()
y_pred = pred_final.numpy()
return y_obs, y_pred, valid_loss
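A sketch of a validation call; the model, criterion, and data loader are assumed to come from an existing training setup, and which_first_tensor must match the data layout used during training:
y_obs, y_pred, valid_loss = compute_validation(
    model,
    criterion,
    valid_dataloader,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    which_first_tensor="batch",
)
print(valid_loss, y_obs.shape, y_pred.shape)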
evaluate_validation(validation_data_loader, output, labels, evaluation_cfgs, target_col)
¶
calculate metrics for validation
Parameters¶
output model output labels model target evaluation_cfgs evaluation configs target_col target columns
Returns¶
tuple metrics
Source code in torchhydro/trainers/train_utils.py
def evaluate_validation(
validation_data_loader,
output,
labels,
evaluation_cfgs,
target_col,
):
"""
calculate metrics for validation
Parameters
----------
output
model output
labels
model target
evaluation_cfgs
evaluation configs
target_col
target columns
Returns
-------
tuple
metrics
"""
fill_nan = evaluation_cfgs["fill_nan"]
if isinstance(fill_nan, list) and len(fill_nan) != len(target_col):
raise ValueError("Length of fill_nan must be equal to length of target_col.")
eval_log = {}
evaluation_metrics = evaluation_cfgs["metrics"]
obss_xr, preds_xr = get_preds_to_be_eval(
validation_data_loader,
evaluation_cfgs,
output,
labels,
)
# obss_xr_list
# preds_xr_list
# if type()
# for i in range(obs.shape[0]):  # index of the forecast horizon
## obs_ = obs[i]
if isinstance(obss_xr, list):
obss_xr_list = obss_xr
preds_xr_list = preds_xr
for horizon_idx in range(len(obss_xr_list)):
obss_xr = obss_xr_list[horizon_idx]
preds_xr = preds_xr_list[horizon_idx]
for i, col in enumerate(target_col):
obs = obss_xr[col].to_numpy()
pred = preds_xr[col].to_numpy()
# eval_log will be updated rather than completely replaced, no need to use eval_log["key"]
eval_log = calculate_and_record_metrics(
obs,
pred,
evaluation_metrics,
col,
fill_nan[i] if isinstance(fill_nan, list) else fill_nan,
eval_log,
horizon_idx + 1,
)
return eval_log
for i, col in enumerate(target_col):
obs = obss_xr[col].to_numpy()
pred = preds_xr[col].to_numpy()
# eval_log will be updated rather than completely replaced, no need to use eval_log["key"]
eval_log = calculate_and_record_metrics(
obs,
pred,
evaluation_metrics,
col,
fill_nan[i] if isinstance(fill_nan, list) else fill_nan,
eval_log,
)
return eval_log
get_lastest_logger_file_in_a_dir(dir_path)
¶
Get the last logger file in a directory
Parameters¶
dir_path : str the directory
Returns¶
str the path of the logger file
Source code in torchhydro/trainers/train_utils.py
def get_lastest_logger_file_in_a_dir(dir_path):
"""Get the last logger file in a directory
Parameters
----------
dir_path : str
the directory
Returns
-------
str
the path of the logger file
"""
pattern = r"^\d{1,2}_[A-Za-z]+_\d{6}_\d{2}(AM|PM)\.json$"
pth_files_lst = [
os.path.join(dir_path, file)
for file in os.listdir(dir_path)
if re.match(pattern, file)
]
return get_latest_file_in_a_lst(pth_files_lst)
get_latest_pbm_param_file(param_dir)
¶
Get the latest parameter file of physics-based models in the current directory.
Parameters¶
param_dir : str The directory of parameter files.
Returns¶
str The latest parameter file.
Source code in torchhydro/trainers/train_utils.py
def get_latest_pbm_param_file(param_dir):
"""Get the latest parameter file of physics-based models in the current directory.
Parameters
----------
param_dir : str
The directory of parameter files.
Returns
-------
str
The latest parameter file.
"""
param_file_lst = [
os.path.join(param_dir, f)
for f in os.listdir(param_dir)
if f.startswith("pb_params") and f.endswith(".csv")
]
param_files = [Path(f) for f in param_file_lst]
param_file_names_lst = [param_file.stem.split("_") for param_file in param_files]
ctimes = [
int(param_file_names[param_file_names.index("params") + 1])
for param_file_names in param_file_names_lst
]
return param_files[ctimes.index(max(ctimes))] if ctimes else None
get_latest_tensorboard_event_file(log_dir)
¶
Get the latest event file in the log_dir directory.
Parameters¶
log_dir : str The directory where the event files are stored.
Returns¶
str The latest event file.
Source code in torchhydro/trainers/train_utils.py
def get_latest_tensorboard_event_file(log_dir):
"""Get the latest event file in the log_dir directory.
Parameters
----------
log_dir : str
The directory where the event files are stored.
Returns
-------
str
The latest event file.
"""
event_file_lst = [
os.path.join(log_dir, f) for f in os.listdir(log_dir) if f.startswith("events")
]
event_files = [Path(f) for f in event_file_lst]
event_file_names_lst = [event_file.stem.split(".") for event_file in event_files]
ctimes = [
int(event_file_names[event_file_names.index("tfevents") + 1])
for event_file_names in event_file_names_lst
]
return event_files[ctimes.index(max(ctimes))]
get_masked_tensors(variable_length_cfgs, batch, seq_first)
¶
Get the mask for the data
Parameters¶
variable_length_cfgs : dict The variable length configuration batch : tuple or list The batch data from collate_fn or dataset seq_first : bool Whether the data is in sequence first format
Returns¶
tuple For standard datasets: (xs, ys, xs_mask, ys_mask, xs_lens, ys_lens) For GNN datasets: (xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens) For GNN with batch vector: (xs, ys, edge_index, edge_weight, batch_vector, xs_mask, ys_mask, xs_lens, ys_lens)
Source code in torchhydro/trainers/train_utils.py
def get_masked_tensors(variable_length_cfgs, batch, seq_first):
"""Get the mask for the data
Parameters
----------
variable_length_cfgs : dict
The variable length configuration
batch : tuple or list
The batch data from collate_fn or dataset
seq_first : bool
Whether the data is in sequence first format
Returns
-------
tuple
For standard datasets: (xs, ys, xs_mask, ys_mask, xs_lens, ys_lens)
For GNN datasets: (xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens)
For GNN with batch vector: (xs, ys, edge_index, edge_weight, batch_vector, xs_mask, ys_mask, xs_lens, ys_lens)
"""
xs_mask = None
ys_mask = None
xs_lens = None
ys_lens = None
edge_index = None
edge_weight = None
if variable_length_cfgs is None:
# Check batch length to determine format
if len(batch) == 5:
# GNN batch with batch_vector: [sxc, y, edge_index, edge_weight, batch_vector]
xs, ys, edge_index, edge_weight, batch_vector = batch
return (
xs,
ys,
edge_index,
edge_weight,
batch_vector,
xs_mask,
ys_mask,
xs_lens,
ys_lens,
)
elif len(batch) == 4:
# GNN batch: [sxc, y, edge_index, edge_weight]
xs, ys, edge_index, edge_weight = batch
return xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens
else:
# Standard batch: [xs, ys]
xs, ys = batch[0], batch[1]
return xs, ys, xs_mask, ys_mask, xs_lens, ys_lens
if variable_length_cfgs.get("use_variable_length", False):
# When using variable length training, batch comes from varied_length_collate_fn
# which returns [xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask]
if len(batch) >= 6:
xs, ys, xs_lens, ys_lens, xs_mask_bool, ys_mask_bool = batch[:6]
else:
# Fallback: treat as regular batch with first two elements
xs, ys = batch[0], batch[1]
xs_lens = ys_lens = xs_mask_bool = ys_mask_bool = None
if xs_mask_bool is None and ys_mask_bool is None:
# sometime even you choose to use variable length training, the batch data may still be fixed length
# so we need to return the batch data directly
return xs, ys, xs_mask_bool, ys_mask_bool, xs_lens, ys_lens
# Convert masks to the format expected by model (float tensor with shape [..., 1])
xs_mask = xs_mask_bool.unsqueeze(-1).float() # [batch, seq, 1]
ys_mask = ys_mask_bool.unsqueeze(-1).float() # [batch, seq, 1]
# Convert to appropriate format for model if needed
if seq_first:
xs_mask = xs_mask.transpose(0, 1) # [seq, batch, 1]
ys_mask = ys_mask.transpose(0, 1) # [seq, batch, 1]
else:
# Check batch length to determine format
if len(batch) == 5:
# GNN batch with batch_vector: [sxc, y, edge_index, edge_weight, batch_vector]
xs, ys, edge_index, edge_weight, batch_vector = batch
elif len(batch) == 4:
# GNN batch: [sxc, y, edge_index, edge_weight]
xs, ys, edge_index, edge_weight = batch
else:
# Standard batch: [xs, ys]
xs, ys = batch[0], batch[1]
# Return appropriate format based on what we have
if edge_index is not None and edge_weight is not None:
return (
xs,
ys,
edge_index,
edge_weight,
batch_vector,
xs_mask,
ys_mask,
xs_lens,
ys_lens,
)
else:
return xs, ys, xs_mask, ys_mask, xs_lens, ys_lens
get_preds_to_be_eval(valorte_data_loader, evaluation_cfgs, output, labels)
¶
Get prediction results prepared for evaluation: the denormalized data without metrics by different eval ways
Parameters¶
valorte_data_loader : DataLoader validation or test data loader evaluation_cfgs : dict evaluation configs output : np.ndarray model output labels : np.ndarray model target
Returns¶
tuple the denormalized observations and predictions (obss_xr, preds_xr), or lists of them for multi-horizon evaluation
Source code in torchhydro/trainers/train_utils.py
def get_preds_to_be_eval(
valorte_data_loader,
evaluation_cfgs,
output,
labels,
):
"""
Get prediction results prepared for evaluation:
the denormalized data without metrics by different eval ways
Parameters
----------
valorte_data_loader : DataLoader
validation or test data loader
evaluation_cfgs : dict
evaluation configs
output : np.ndarray
model output
labels : np.ndarray
model target
Returns
-------
tuple
_description_
"""
evaluator = evaluation_cfgs["evaluator"]
# this test_rolling means how we perform prediction during testing
test_rolling = evaluation_cfgs["rolling"]
batch_size = valorte_data_loader.batch_size
target_scaler = valorte_data_loader.dataset.target_scaler
target_data = target_scaler.data_target
rho = valorte_data_loader.dataset.rho
horizon = valorte_data_loader.dataset.horizon
warmup_length = valorte_data_loader.dataset.warmup_length
hindcast_output_window = target_scaler.data_cfgs["hindcast_output_window"]
nf = valorte_data_loader.dataset.noutputvar # number of features
# number of time steps after warmup as outputs typically don't include warmup period
nt = valorte_data_loader.dataset.nt - warmup_length
basin_num = len(target_data.basin)
data_shape = (basin_num, nt, nf)
if evaluator["eval_way"] == "once":
stride = evaluator["stride"]
if stride > 0:
if horizon != stride:
raise NotImplementedError(
"horizon should be equal to stride in evaluator if you chose eval_way to be once, or else you need to change the eval_way to be 1pace or rolling"
)
obs = _rolling_preds_for_once_eval(
(basin_num, horizon, nf),
rho,
evaluation_cfgs["forecast_length"],
stride,
hindcast_output_window,
target_data.reshape(basin_num, horizon, nf),
)
pred = _rolling_preds_for_once_eval(
(basin_num, horizon, nf),
rho,
evaluation_cfgs["forecast_length"],
stride,
hindcast_output_window,
output.reshape(batch_size, horizon, nf),
)
else:
if test_rolling > 0:
raise RuntimeError(
"please set rolling to 0 when you chose eval way as once and stride=0"
)
obs = labels.reshape(basin_num, -1, nf)
pred = output.reshape(basin_num, -1, nf)
elif evaluator["eval_way"] == "1pace":
if test_rolling < 1:
raise NotImplementedError(
"rolling should be larger than 0 if you chose eval_way to be 1pace"
)
pace_idx = evaluator["pace_idx"]
# stride = evaluator.get("stride", 1)
# for 1pace with pace_idx meaning which value of output was chosen to show
# 1st, we need to transpose data to 4-dim to show the whole data
# TODO:check should we select which def
pred = _recover_samples_to_basin(output, valorte_data_loader, pace_idx)
obs = _recover_samples_to_basin(labels, valorte_data_loader, pace_idx)
elif evaluator["eval_way"] == "rolling":
# get the parameters needed for rolling prediction
stride = evaluator.get("stride", 1)
if stride != 1:
raise NotImplementedError(
"if stride is not equal to 1, we think it is meaningless"
)
# reassemble the predictions and observations
basin_num = len(target_data.basin)
# choose the data organization method according to the configuration
recover_mode = evaluator.get("recover_mode", "bybasins")
stride = evaluator.get("stride", 1)
data_shape = (basin_num, nt, nf)
if recover_mode == "bybasins":
pred = _recover_samples_to_4d_by_basins(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
output,
)
obs = _recover_samples_to_4d_by_basins(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
labels,
)
elif recover_mode == "byforecast":
pred = _recover_samples_to_4d_by_forecast(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
output, # samples, seq_length, nf
)
obs = _recover_samples_to_4d_by_forecast(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
labels,
)
elif recover_mode == "byensembles":
pred = _recover_samples_to_3d_by_4d_ensembles(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
output,
)
obs = _recover_samples_to_3d_by_4d_ensembles(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
labels,
)
else:
raise ValueError(
f"Unsupported recover_mode: {recover_mode}, must be 'bybasins' or 'byforecast' or 'byensembles'"
)
elif evaluator["eval_way"] == "floodevent":
# For flood event evaluation, stride is not typically used, but we set it to 1 for consistency
stride = evaluator.get("stride", 1)
pred = _recover_samples_to_continuous_by_floodevent(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
output,
)
obs = _recover_samples_to_continuous_by_floodevent(
data_shape,
valorte_data_loader,
stride,
hindcast_output_window,
labels,
)
else:
raise ValueError("eval_way should be rolling or 1pace")
# pace_idx = np.nan
recover_mode = evaluator.get("recover_mode")
valte_dataset = valorte_data_loader.dataset
# check the data dimensionality and handle it appropriately
if pred.ndim == 4:
# for 4-D data, choose the handling method according to the eval way
if evaluator["eval_way"] == "1pace" and "pace_idx" in evaluator:
# for 1pace mode, select a specific forecast lead
pace_idx = evaluator["pace_idx"]
# select the data for that forecast lead
pred_3d = pred[:, :, pace_idx, :]
obs_3d = obs[:, :, pace_idx, :]
preds_xr = valte_dataset.denormalize(pred_3d, pace_idx)
obss_xr = valte_dataset.denormalize(obs_3d, pace_idx)
elif evaluator["eval_way"] == "rolling" and recover_mode == "byforecast":
# byforecast mode needs special handling
# create lists to store the results for each forecast lead
preds_xr_list = []
obss_xr_list = []
for i in range(pred.shape[2]):
pred_3d = pred[:, :, i, :]
obs_3d = obs[:, :, i, :]
the_array_pred_ = np.full(target_data.shape, np.nan)
the_array_obs_ = np.full(target_data.shape, np.nan)
start = rho + i # TODO:need check
end = start + pred_3d.shape[1]
assert end <= the_array_pred_.shape[1]
the_array_pred_[:, start:end, :] = pred_3d
the_array_obs_[:, start:end, :] = obs_3d
preds_xr_list.append(valte_dataset.denormalize(the_array_pred_, i))
obss_xr_list.append(valte_dataset.denormalize(the_array_obs_, i))
# merge the results
# preds_xr = xr.concat(preds_xr_list, dim="horizon")
# obss_xr = xr.concat(obss_xr_list, dim="horizon")
return obss_xr_list, preds_xr_list
elif evaluator["eval_way"] == "rolling" and recover_mode == "bybasins":
# for other cases, the 4-D data can be reduced to 3-D,
# e.g., by taking the last forecast lead
preds_xr_list = []
obss_xr_list = []
for i in range(pred.shape[0]):
pred_3d = pred[i, :, :, :]
obs_3d = obs[i, :, :, :]
selected_data = target_scaler.data_target
the_array_pred_ = np.full(selected_data.shape, np.nan)
the_array_obs_ = np.full(selected_data.shape, np.nan)
start = rho # TODO:need check
end = start + pred_3d.shape[1]  # the end position of the fill is computed automatically
# optional out-of-bounds check
assert end <= the_array_pred_.shape[1]  # fill range must not exceed the target array bounds
# perform the fill
the_array_pred_[:, start:end, :] = pred_3d
the_array_obs_[:, start:end, :] = obs_3d
preds_xr = valte_dataset.denormalize(the_array_pred_, -1)
obss_xr = valte_dataset.denormalize(the_array_obs_, -1)
else:
# for 3d data, directly process
# TODO: maybe need more test for the pace_idx case
preds_xr = valte_dataset.denormalize(pred)
obss_xr = valte_dataset.denormalize(obs)
def _align_and_order(_obs, _pred):
# align to the common (basin, time, variable) intersection to avoid NaNs introduced by an outer join
_obs, _pred = xr.align(_obs, _pred, join="inner")
# raise immediately if the time dimension is empty (no overlap) instead of falling through to nanmean
if _obs.sizes.get("time", 0) == 0:
raise ValueError(
"No overlapping timestamps between observations and predictions "
f"(obs.time len={_obs.sizes.get('time',0)}, pred.time len={_pred.sizes.get('time',0)})."
)
# sort by time (as a safeguard)
if "time" in _obs.dims:
_obs = _obs.sortby("time")
if "time" in _pred.dims:
_pred = _pred.sortby("time")
# normalize the dimension order (when the dimensions are present)
wanted = [d for d in ("basin", "time", "variable") if d in _obs.dims]
_obs = _obs.transpose(*wanted, missing_dims="ignore")
_pred = _pred.transpose(*wanted, missing_dims="ignore")
return _obs, _pred
# handle a single object vs. a list of objects separately
if preds_xr is not None and obss_xr is not None:
obss_xr, preds_xr = _align_and_order(obss_xr, preds_xr)
return obss_xr, preds_xr
elif preds_xr_list is not None and obss_xr_list is not None:
obss_aligned, preds_aligned = [], []
for _o, _p in zip(obss_xr_list, preds_xr_list):
_o2, _p2 = _align_and_order(_o, _p)
obss_aligned.append(_o2)
preds_aligned.append(_p2)
return obss_aligned, preds_aligned
else:
# should be unreachable in theory
raise RuntimeError("Failed to build preds_xr / obss_xr for evaluation.")
gnn_collate_fn(batch)
¶
Custom collate function for GNN datasets that handles variable-sized graphs.
Parameters¶
batch
A list of samples, where each sample is a tuple of (sxc, y, edge_index, edge_weight). Required.
Returns¶
list
A list containing batched tensors: [batched_sxc, batched_y, batched_edge_index, batched_edge_weight, batch_vector]
Source code in torchhydro/trainers/train_utils.py
def gnn_collate_fn(batch):
"""Custom collate function for GNN datasets that handles variable-sized graphs.
Args:
batch: A list of samples, where each sample is a tuple of
(sxc, y, edge_index, edge_weight).
Returns:
A list containing batched tensors:
[batched_sxc, batched_y, batched_edge_index, batched_edge_weight, batch_vector]
"""
import torch
if len(batch) == 0:
return []
# Unpack the batch
sxc_list, y_list, edge_index_list, edge_weight_list = zip(*batch)
# Batch the target values (y) - these should have the same shape
batched_y = torch.stack(y_list, dim=0) # [batch_size, forecast_length, output_dim]
# Find the maximum number of nodes in this batch
max_num_nodes = max(sxc.shape[0] for sxc in sxc_list)
# Get dimensions
batch_size = len(sxc_list)
seq_length = sxc_list[0].shape[1]
feature_dim = sxc_list[0].shape[2]
# Create padded tensor for node features
batched_sxc = torch.zeros(batch_size, max_num_nodes, seq_length, feature_dim)
# Create batched edge indices and weights
# For each graph in the batch, we need to offset node indices
batched_edge_index = []
batched_edge_weight = []
batch_vector = []
node_offset = 0
for i, (sxc, edge_index, edge_weight) in enumerate(
zip(sxc_list, edge_index_list, edge_weight_list)
):
num_nodes = sxc.shape[0]
# Fill the padded tensor with actual node features
batched_sxc[i, :num_nodes] = sxc
# For edge indices, we need to offset by node_offset to make them unique across batch
if edge_index.numel() > 0:
if edge_index.max() >= num_nodes:
print(
f"Warning: Graph {i} has edge indices {edge_index.max().item()} >= num_nodes {num_nodes}"
)
valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
edge_index = edge_index[:, valid_mask]
edge_weight = edge_weight[valid_mask]
if edge_index.numel() > 0:
offset_edge_index = edge_index + node_offset
batched_edge_index.append(offset_edge_index)
batched_edge_weight.append(edge_weight)
# batch_vector: for each node in this graph, assign batch index i
batch_vector.append(torch.full((num_nodes,), i, dtype=torch.long))
node_offset += num_nodes
# Concatenate edge indices and weights if they exist
if batched_edge_index:
batched_edge_index = torch.cat(batched_edge_index, dim=1) # [2, total_edges]
batched_edge_weight = torch.cat(batched_edge_weight, dim=0) # [total_edges]
else:
batched_edge_index = torch.empty((2, 0), dtype=torch.long)
batched_edge_weight = torch.empty(0)
batch_vector = torch.cat(batch_vector, dim=0) # [total_nodes]
return [
batched_sxc,
batched_y,
batched_edge_index,
batched_edge_weight,
batch_vector,
]
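A sketch of plugging the collate function into a DataLoader; the GNN dataset object is an assumption and only needs its __getitem__ to return (sxc, y, edge_index, edge_weight):
from torch.utils.data import DataLoader

gnn_loader = DataLoader(
    gnn_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=gnn_collate_fn,
)
sxc, y, edge_index, edge_weight, batch_vector = next(iter(gnn_loader))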
model_infer(seq_first, device, model, batch, variable_length_cfgs=None, return_key=None)
¶
Unified model inference function with variable length support
Parameters¶
seq_first : bool if True, the input data is sequence first device : torch.device cpu or gpu model : torch.nn.Module the model batch : tuple or list batch data from collate_fn or dataset variable_length_cfgs : dict, optional variable length configuration containing mask settings return_key : str, optional when model returns a dict, choose which key (e.g., "f2") to return. if None, defaults to the last frequency (max key).
Source code in torchhydro/trainers/train_utils.py
def model_infer(
seq_first, device, model, batch, variable_length_cfgs=None, return_key=None
):
"""
Unified model inference function with variable length support
Parameters
----------
seq_first : bool
if True, the input data is sequence first
device : torch.device
cpu or gpu
model : torch.nn.Module
the model
batch : tuple or list
batch data from collate_fn or dataset
variable_length_cfgs : dict, optional
variable length configuration containing mask settings
return_key : str, optional
when model returns a dict, choose which key (e.g., "f2") to return.
if None, defaults to the last frequency (max key).
"""
result = get_masked_tensors(variable_length_cfgs, batch, seq_first)
# --- unpack inputs ---
if len(result) == 9:
(
xs,
ys,
edge_index,
edge_weight,
batch_vector,
xs_mask,
ys_mask,
xs_lens,
ys_lens,
) = result
elif len(result) == 8:
xs, ys, edge_index, edge_weight, xs_mask, ys_mask, xs_lens, ys_lens = result
batch_vector = None
else:
xs, ys, xs_mask, ys_mask, xs_lens, ys_lens = result
edge_index = edge_weight = batch_vector = None
# --- move xs to device ---
if isinstance(xs, list):
xs = [
(
x.permute(1, 0, 2).to(device)
if seq_first and x.ndim == 3
else x.to(device)
)
for x in xs
]
else:
xs = [
(
xs.permute(1, 0, 2).to(device)
if seq_first and xs.ndim == 3
else xs.to(device)
)
]
# --- move ys to device ---
if ys is not None:
ys = (
ys.permute(1, 0, 2).to(device)
if seq_first and ys.ndim == 3
else ys.to(device)
)
# --- move graph data ---
if edge_index is not None:
edge_index = edge_index.to(device)
if edge_weight is not None:
edge_weight = edge_weight.to(device)
if batch_vector is not None:
batch_vector = batch_vector.to(device)
# --- forward ---
if xs_mask is not None and ys_mask is not None:
if edge_index is not None and edge_weight is not None:
output = model(
*xs,
edge_index=edge_index,
edge_weight=edge_weight,
batch_vector=batch_vector,
mask=xs_mask,
seq_lengths=xs_lens,
)
else:
output = model(*xs, mask=xs_mask, seq_lengths=xs_lens)
else:
if edge_index is not None and edge_weight is not None:
output = model(
*xs,
edge_index=edge_index,
edge_weight=edge_weight,
batch_vector=batch_vector,
)
else:
output = model(*xs)
# --- handle model outputs ---
if isinstance(output, tuple):
output = output[0]
if isinstance(output, dict):
# by default, take the highest-frequency output
if return_key is None:
return_key = sorted(output.keys())[-1] # e.g., "f2"
if return_key not in output:
raise KeyError(
f"Model returned keys {list(output.keys())}, but return_key='{return_key}' not found"
)
output = output[return_key]
if ys_mask is not None:
ys = ys.masked_fill(ys_mask == 0, torch.nan)
# --- seq_first transpose back ---
if seq_first:
output = output.transpose(0, 1)
if ys is not None:
ys = ys.transpose(0, 1)
return ys, output
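A sketch of running inference over a test loader with model_infer, mirroring how the trainer uses it (model, device, and loader are placeholders; seq_first must match the training configuration):
model.eval()
obss, preds = [], []
with torch.no_grad():
    for batch in test_dataloader:
        ys, output = model_infer(seq_first=False, device=device, model=model, batch=batch)
        obss.append(ys.detach().cpu())
        preds.append(output.detach().cpu())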
torch_single_train(model, opt, criterion, data_loader, device=None, **kwargs)
¶
Training function for one epoch
Parameters¶
model a PyTorch model inherit from nn.Module opt optimizer function from PyTorch optim.Optimizer criterion loss function data_loader object for loading data to the model device where we put the tensors and models
Returns¶
tuple(torch.Tensor, int) loss of this epoch and number of all iterations
Raises¶
ValueError if nan exits, raise a ValueError
Source code in torchhydro/trainers/train_utils.py
def torch_single_train(
model,
opt: optim.Optimizer,
criterion,
data_loader: DataLoader,
device=None,
**kwargs,
):
"""
Training function for one epoch
Parameters
----------
model
a PyTorch model inherit from nn.Module
opt
optimizer function from PyTorch optim.Optimizer
criterion
loss function
data_loader
object for loading data to the model
device
where we put the tensors and models
Returns
-------
tuple(torch.Tensor, int)
loss of this epoch and number of all iterations
Raises
--------
ValueError
if nan exits, raise a ValueError
"""
# we will set model.eval() in the validation function so here we should set model.train()
model.train()
n_iter_ep = 0
running_loss = 0.0
which_first_tensor = kwargs["which_first_tensor"]
seq_first = which_first_tensor != "batch"
variable_length_cfgs = data_loader.dataset.training_cfgs.get(
"variable_length_cfgs", None
)
pbar = tqdm(data_loader)
for _, batch in enumerate(pbar):
# mask handling is already done inside model_infer function
trg, output = model_infer(seq_first, device, model, batch, variable_length_cfgs)
loss = compute_loss(trg, output, criterion, **kwargs)
if loss > 100:
print("Warning: high loss detected")
if torch.isnan(loss):
raise ValueError("nan loss detected")
# continue
loss.backward() # Backpropagate to compute the current gradient
opt.step() # Update network parameters based on gradients
model.zero_grad() # clear gradient
if loss == float("inf"):
raise ValueError(
"Error infinite loss detected. Try normalizing data or performing interpolation"
)
running_loss += loss.item()
n_iter_ep += 1
if n_iter_ep == 0:
raise ValueError(
"All batch computations of loss result in NAN. Please check the data."
)
total_loss = running_loss / float(n_iter_ep)
return total_loss, n_iter_ep
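A sketch of driving one epoch with this function; note that the dataset behind the loader is expected to expose a training_cfgs attribute, and which_first_tensor must match the batch layout:
opt = optim.Adam(model.parameters(), lr=1e-3)
total_loss, n_iter_ep = torch_single_train(
    model,
    opt,
    criterion,
    train_dataloader,
    device=device,
    which_first_tensor="batch",
)
print(f"epoch mean loss: {total_loss:.4f} over {n_iter_ep} iterations")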
varied_length_collate_fn(batch)
¶
Collate function for variable length training
This function is automatically used by DataLoader when variable_length_cfgs["use_variable_length"] is True. It pads sequences to the same length and generates corresponding masks.
Parameters¶
batch : list of tuples The batch data after the dataset getitem method
Returns¶
list [xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask] - xs_pad: padded input sequences [batch, max_seq_len, input_dim] - ys_pad: padded output sequences [batch, max_seq_len, output_dim] - xs_lens: original sequence lengths for input - ys_lens: original sequence lengths for output - xs_mask: valid position mask for input [batch, max_seq_len] - ys_mask: valid position mask for output [batch, max_seq_len]
Source code in torchhydro/trainers/train_utils.py
def varied_length_collate_fn(batch):
"""Collate function for variable length training
This function is automatically used by DataLoader when variable_length_cfgs["use_variable_length"] is True.
It pads sequences to the same length and generates corresponding masks.
Parameters
----------
batch : list of tuples
The batch data after the dataset __getitem__ method
Returns
-------
list
[xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask]
- xs_pad: padded input sequences [batch, max_seq_len, input_dim]
- ys_pad: padded output sequences [batch, max_seq_len, output_dim]
- xs_lens: original sequence lengths for input
- ys_lens: original sequence lengths for output
- xs_mask: valid position mask for input [batch, max_seq_len]
- ys_mask: valid position mask for output [batch, max_seq_len]
"""
xs, ys = zip(*batch)
# sometimes x is a tuple like in dpl dataset, then we can get the shape of the first element as the length
xs_lens = [x[0].shape[0] if type(x) in [tuple, list] else x.shape[0] for x in xs]
ys_lens = [y[0].shape[0] if type(y) in [tuple, list] else y.shape[0] for y in ys]
# if all ys_lens are the same, use default collate_fn to create tensors
if len(set(ys_lens)) == 1 and len(set(xs_lens)) == 1:
xs_tensor = default_collate(xs)
ys_tensor = default_collate(ys)
return [xs_tensor, ys_tensor, None, None, None, None]
# pad the batch data with padding value 0
xs_pad = pad_sequence(xs, batch_first=True, padding_value=0)
ys_pad = pad_sequence(ys, batch_first=True, padding_value=0)
# generate the mask for the batch data
# xs_mask: [batch_size, max_seq_len] or [batch_size, max_seq_len, 1]
batch_size = len(xs_lens)
max_xs_len = max(xs_lens)
max_ys_len = max(ys_lens)
# create the mask for the input sequence (True for valid positions, False for padding positions)
xs_mask = torch.zeros(batch_size, max_xs_len, dtype=torch.bool)
for i, length in enumerate(xs_lens):
xs_mask[i, :length] = True
# create the mask for the output sequence
ys_mask = torch.zeros(batch_size, max_ys_len, dtype=torch.bool)
for i, length in enumerate(ys_lens):
ys_mask[i, :length] = True
return [
xs_pad,
ys_pad,
xs_lens,
ys_lens,
xs_mask,
ys_mask,
]
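A sketch of enabling variable-length batching through the DataLoader; whether individual samples actually differ in length depends on the dataset configuration:
from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,                      # samples may have different sequence lengths
    batch_size=32,
    shuffle=True,
    collate_fn=varied_length_collate_fn,
)
xs_pad, ys_pad, xs_lens, ys_lens, xs_mask, ys_mask = next(iter(loader))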
trainer
¶
Author: Wenyu Ouyang Date: 2021-12-05 11:21:58 LastEditTime: 2025-11-08 16:02:09 LastEditors: Wenyu Ouyang Description: Main function for training and testing FilePath: \torchhydro\torchhydro\trainers\trainer.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
ensemble_train_and_evaluate(cfgs)
¶
Function to train and test for ensemble models
Parameters¶
cfgs Dictionary containing all configs needed to run the model
Returns¶
None
Source code in torchhydro/trainers/trainer.py
def ensemble_train_and_evaluate(cfgs: Dict):
"""
Function to train and test for ensemble models
Parameters
----------
cfgs
Dictionary containing all configs needed to run the model
Returns
-------
None
"""
# for basins and models
ensemble = cfgs["training_cfgs"]["ensemble"]
if not ensemble:
raise ValueError(
"ensemble should be True, otherwise should use train_and_evaluate rather than ensemble_train_and_evaluate"
)
ensemble_items = cfgs["training_cfgs"]["ensemble_items"]
number_of_items = len(ensemble_items)
if number_of_items == 0:
raise ValueError("ensemble_items should not be empty")
keys_list = list(ensemble_items.keys())
if "kfold" in keys_list:
_trans_kfold_to_periods(cfgs, ensemble_items, "kfold")
_nested_loop_train_and_evaluate(keys_list, 0, ensemble_items, cfgs)
set_random_seed(seed)
¶
Set a random seed to guarantee the reproducibility
Parameters¶
seed a number
Returns¶
None
Source code in torchhydro/trainers/trainer.py
def set_random_seed(seed):
"""
Set a random seed to guarantee the reproducibility
Parameters
----------
seed
a number
Returns
-------
None
"""
# print("Random seed:", seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
train_and_evaluate(cfgs)
¶
Function to train and test a Model
Parameters¶
cfgs Dictionary containing all configs needed to run the model
Returns¶
None
Source code in torchhydro/trainers/trainer.py
def train_and_evaluate(cfgs: Dict):
"""
Function to train and test a Model
Parameters
----------
cfgs
Dictionary containing all configs needed to run the model
Returns
-------
None
"""
random_seed = cfgs["training_cfgs"]["random_seed"]
set_random_seed(random_seed)
resulter = Resulter(cfgs)
deephydro = _get_deep_hydro(cfgs)
# if train_mode is False, we only evaluate the model
train_mode = deephydro.cfgs["training_cfgs"]["train_mode"]
# but if train_mode is True, we still need some conditions to train the model
continue_train = deephydro.cfgs["model_cfgs"]["continue_train"]
is_transfer_learning = deephydro.cfgs["model_cfgs"]["model_type"] == "TransLearn"
is_train = train_mode and (
(deephydro.weight_path is not None and (continue_train or is_transfer_learning))
or (deephydro.weight_path is None)
)
if is_train:
deephydro.model_train()
preds, obss = deephydro.model_evaluate()
resulter.save_cfg(deephydro.cfgs)
resulter.save_result(preds, obss)
resulter.eval_result(preds, obss)
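A minimal end-to-end sketch; cfgs is assumed to be a complete torchhydro configuration dictionary with data_cfgs, model_cfgs, training_cfgs, and evaluation_cfgs already filled in:
train_and_evaluate(cfgs)
# the model is trained (when train_mode and the weight settings allow it), evaluated on
# the test period, and predictions, observations, and metrics are saved by the Resulter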