class Estreams(HydroDataset):
"""EStreams dataset class extending HydroDataset.
This class uses a custom data reading implementation to support a newer
dataset version than the one supported by the underlying aquafetch library.
It overrides the download URLs and provides updated methods.
"""
def __init__(
self, data_path: str, region: Optional[str] = None, download: bool = False
) -> None:
"""Initialize EStreams dataset.
Args:
data_path: Path to the EStreams data directory
region: Geographic region identifier (optional)
download: Whether to download data automatically (default: False)
"""
super().__init__(data_path)
self.region = region
self.download = download
# Instantiate EStreams from aqua_fetch
# The _read_stn_dyn method and path2 fix have been added directly to aqua_fetch
self.aqua_fetch = EStreams(data_path)

    @property
def _attributes_cache_filename(self):
return "estreams_attributes.nc"
@property
def _timeseries_cache_filename(self):
return "estreams_timeseries.nc"
@property
def default_t_range(self):
return ["1950-01-01", "2023-06-30"]

    # Feature information from "https://www.nature.com/articles/s41597-024-03706-1/tables/6"
# Define standardized static variable mappings
_subclass_static_definitions = {
"p_mean": {"specific_name": "p_mean", "unit": "mm/day"},
"area": {"specific_name": "area_km2", "unit": "km^2"},
}

    # Define standardized dynamic variable mappings
_dynamic_variable_mapping = {
StandardVariable.PRECIPITATION: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "pcp_mm", "unit": "mm/day"}},
},
StandardVariable.TEMPERATURE_MEAN: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "airtemp_c_mean", "unit": "°C"}},
},
StandardVariable.TEMPERATURE_MIN: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "airtemp_c_min", "unit": "°C"}},
},
StandardVariable.TEMPERATURE_MAX: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "airtemp_c_max", "unit": "°C"}},
},
StandardVariable.SURFACE_PRESSURE: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "sp_mean", "unit": "hPa"}},
},
StandardVariable.RELATIVE_HUMIDITY: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "rh_", "unit": "%"}},
},
StandardVariable.WIND_SPEED: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "windspeed_mps", "unit": "m/s"}},
},
StandardVariable.SOLAR_RADIATION: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "solrad_wm2", "unit": "W/m^2"}},
},
StandardVariable.POTENTIAL_EVAPOTRANSPIRATION: {
"default_source": "estreams",
"sources": {"estreams": {"specific_name": "pet_mm", "unit": "mm/day"}},
},
}
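    # read_ts_xrdataset resolves each StandardVariable key above to its
    # EStreams-specific column name (e.g. "pcp_mm") before reading from the
    # batch cache files, then renames the result back to the standard name.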

    def cache_timeseries_xrdataset(self, batch_size=100):
        """Cache timeseries data to NetCDF files in batches.

        Args:
            batch_size: Number of stations to process per batch (default: 100)
        """
if not hasattr(self, "aqua_fetch"):
raise NotImplementedError("aqua_fetch attribute is required")
# Build mapping from variable names to units
unit_lookup = {}
if hasattr(self, "_dynamic_variable_mapping"):
for std_name, mapping_info in self._dynamic_variable_mapping.items():
for source, source_info in mapping_info["sources"].items():
unit_lookup[source_info["specific_name"]] = source_info["unit"]
# Get all station IDs
gage_id_lst = self.read_object_ids().tolist()
total_stations = len(gage_id_lst)
# Get original variable list and clean
original_var_lst = self.aqua_fetch.dynamic_features
cleaned_var_lst = self._clean_feature_names(original_var_lst)
var_name_mapping = dict(zip(original_var_lst, cleaned_var_lst))
        print(
            f"Starting batch processing of {total_stations} stations, {batch_size} per batch"
        )
print(
f"Total number of batches: {(total_stations + batch_size - 1)//batch_size}"
)
# Ensure cache directory exists
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Process stations in batches and save independently
batch_num = 1
for batch_idx in range(0, total_stations, batch_size):
batch_end = min(batch_idx + batch_size, total_stations)
batch_stations = gage_id_lst[batch_idx:batch_end]
print(
f"\nProcessing batch {batch_num}/{(total_stations + batch_size - 1)//batch_size}"
)
print(
f"Station range: {batch_idx} - {batch_end-1} (total {len(batch_stations)} stations)"
)
try:
# Get data for this batch
batch_data = self.aqua_fetch.fetch_stations_features(
stations=batch_stations,
dynamic_features=original_var_lst,
static_features=None,
st=self.default_t_range[0],
en=self.default_t_range[1],
as_dataframe=False,
)
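                # fetch_stations_features may return a tuple; element 1 is taken
                # as the dynamic (time series) data below.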
dynamic_data = (
batch_data[1] if isinstance(batch_data, tuple) else batch_data
)
# Process variables
new_data_vars = {}
time_coord = dynamic_data.coords["time"]
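                # The fetched dataset holds one data variable per station, each with
                # a "dynamic_features" dimension; the loop below reshapes this into
                # one variable per feature stacked along a new "basin" dimension.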
for original_var in tqdm(
original_var_lst,
desc=f"Processing variables (batch {batch_num})",
total=len(original_var_lst),
):
cleaned_var = var_name_mapping[original_var]
var_data = []
for station in batch_stations:
if station in dynamic_data.data_vars:
station_data = dynamic_data[station].sel(
dynamic_features=original_var
)
if "dynamic_features" in station_data.coords:
                                station_data = station_data.drop_vars("dynamic_features")
var_data.append(station_data)
if var_data:
combined = xr.concat(var_data, dim="basin")
combined["basin"] = batch_stations
combined.attrs["units"] = unit_lookup.get(
cleaned_var, "unknown"
)
new_data_vars[cleaned_var] = combined
# Create Dataset for this batch
batch_ds = xr.Dataset(
data_vars=new_data_vars,
coords={
"basin": batch_stations,
"time": time_coord,
},
)
# Save this batch to independent file
batch_filename = f"batch{batch_num:03d}_estreams_timeseries.nc"
batch_filepath = self.cache_dir.joinpath(batch_filename)
print(f"Saving batch {batch_num} to: {batch_filepath}")
batch_ds.to_netcdf(batch_filepath)
print(f"Batch {batch_num} saved successfully")
except Exception as e:
print(f"Batch {batch_num} processing failed: {e}")
import traceback
traceback.print_exc()
continue
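            # The batch counter is only advanced after a successful save, so
            # file numbering stays contiguous even if a batch fails.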
batch_num += 1
print(f"\nAll batches processed! Total {batch_num - 1} batch files saved")

    def read_ts_xrdataset(
self,
gage_id_lst: list = None,
t_range: list = None,
var_lst: list = None,
sources: dict = None,
**kwargs,
) -> xr.Dataset:
"""
Read timeseries data from batch-saved cache files
Args:
gage_id_lst: List of station IDs
t_range: Time range [start, end]
var_lst: List of standard variable names
            sources: Mapping from standard variable to one or more data source names
Returns:
xr.Dataset: xarray dataset containing requested data
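        Example:
            A minimal sketch; the instance name and station ID are placeholders::

                ds = dataset.read_ts_xrdataset(
                    gage_id_lst=["DEXXXX0001"],
                    t_range=["1990-01-01", "2000-12-31"],
                    var_lst=[StandardVariable.PRECIPITATION],
                    sources={StandardVariable.PRECIPITATION: "estreams"},
                )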
"""
if (
not hasattr(self, "_dynamic_variable_mapping")
or not self._dynamic_variable_mapping
):
raise NotImplementedError(
"This dataset does not support the standardized variable mapping."
)
if var_lst is None:
var_lst = list(self._dynamic_variable_mapping.keys())
if t_range is None:
t_range = self.default_t_range
target_vars_to_fetch = []
rename_map = {}
# Process variable name mapping and data source selection
for std_name in var_lst:
if std_name not in self._dynamic_variable_mapping:
raise ValueError(
f"'{std_name}' is not a recognized standard variable for this dataset."
)
mapping_info = self._dynamic_variable_mapping[std_name]
# Determine which data source(s) to use
is_explicit_source = sources and std_name in sources
sources_to_use = []
if is_explicit_source:
provided_sources = sources[std_name]
if isinstance(provided_sources, list):
sources_to_use.extend(provided_sources)
else:
sources_to_use.append(provided_sources)
else:
sources_to_use.append(mapping_info["default_source"])
# Only need suffix when user explicitly requests multiple data sources
needs_suffix = is_explicit_source and len(sources_to_use) > 1
for source in sources_to_use:
if source not in mapping_info["sources"]:
raise ValueError(
f"Source '{source}' is not available for variable '{std_name}'."
)
actual_var_name = mapping_info["sources"][source]["specific_name"]
target_vars_to_fetch.append(actual_var_name)
output_name = f"{std_name}_{source}" if needs_suffix else std_name
rename_map[actual_var_name] = output_name
# Find all batch files
batch_pattern = str(self.cache_dir / "batch*_estreams_timeseries.nc")
batch_files = sorted(glob.glob(batch_pattern))
if not batch_files:
print("No batch cache files found, starting cache creation...")
self.cache_timeseries_xrdataset()
batch_files = sorted(glob.glob(batch_pattern))
if not batch_files:
raise FileNotFoundError("Cache creation failed, no batch files found")
print(f"Found {len(batch_files)} batch files")
# If no stations specified, read all stations
if gage_id_lst is None:
print("No station list specified, will read all stations...")
gage_id_lst = self.read_object_ids().tolist()
# Convert station IDs to strings (ensure consistency)
gage_id_lst = [str(gid) for gid in gage_id_lst]
# Iterate through batch files to find batches containing required stations
relevant_datasets = []
for batch_file in batch_files:
try:
                # Open the batch lazily; variable data are not loaded until selected
ds_batch = xr.open_dataset(batch_file)
batch_basins = [str(b) for b in ds_batch.basin.values]
# Check if this batch contains required stations
common_basins = list(set(gage_id_lst) & set(batch_basins))
if common_basins:
print(
f"Batch {os.path.basename(batch_file)}: contains {len(common_basins)} required stations"
)
# Check if variables exist
missing_vars = [
v for v in target_vars_to_fetch if v not in ds_batch.data_vars
]
if missing_vars:
ds_batch.close()
raise ValueError(
f"Batch {os.path.basename(batch_file)} missing variables: {missing_vars}"
)
# Select variables and stations
ds_subset = ds_batch[target_vars_to_fetch]
ds_selected = ds_subset.sel(
basin=common_basins, time=slice(t_range[0], t_range[1])
)
relevant_datasets.append(ds_selected)
ds_batch.close()
else:
ds_batch.close()
except Exception as e:
print(f"Failed to read batch file {batch_file}: {e}")
continue
if not relevant_datasets:
raise ValueError(
f"Specified stations not found in any batch files: {gage_id_lst}"
)
print(f"Reading data from {len(relevant_datasets)} batches...")
# Merge data from all relevant batches
if len(relevant_datasets) == 1:
final_ds = relevant_datasets[0]
else:
final_ds = xr.concat(relevant_datasets, dim="basin")
# Rename to standard variable names
final_ds = final_ds.rename(rename_map)
# Ensure stations are arranged in input order
if len(gage_id_lst) > 0:
# Only select actually existing stations
existing_basins = [b for b in gage_id_lst if b in final_ds.basin.values]
if existing_basins:
final_ds = final_ds.sel(basin=existing_basins)
return final_ds
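

# Minimal usage sketch (not part of the class): the data path below is a
# placeholder, and the names used here (StandardVariable, the aqua_fetch
# EStreams reader, etc.) are assumed to be imported at the top of this module.
if __name__ == "__main__":
    estreams_ds = Estreams(data_path="/path/to/estreams")
    # Build the per-batch NetCDF cache once (100 stations per file by default).
    estreams_ds.cache_timeseries_xrdataset(batch_size=100)
    # Read two standard variables for all stations over a sub-period.
    ds = estreams_ds.read_ts_xrdataset(
        gage_id_lst=None,
        t_range=["1990-01-01", "2000-12-31"],
        var_lst=[
            StandardVariable.PRECIPITATION,
            StandardVariable.TEMPERATURE_MEAN,
        ],
    )
    print(ds)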
|