# -*- coding: utf-8 -*-
"""
Created on Thu 16 Dec 2021
@author: Alexandre Kenshilik Coche
@contact: alexandre.co@hotmail.fr
This module is a collection of tools for manipulating hydrological space-time
data, especially netCDF data. It was originally developed to provide
preprocessing tools for CWatM (https://cwatm.iiasa.ac.at/) and HydroModPy
(https://gitlab.com/Alex-Gauvain/HydroModPy), but most functions have been
designed to be of general use.
"""
#%% IMPORTS
import logging
logging.basicConfig(level=logging.ERROR) # DEBUG < INFO < WARNING < ERROR < CRITICAL
logger = logging.getLogger(__name__)
import xarray as xr
xr.set_options(keep_attrs = True)
# import rioxarray as rio # Not necessary, the rio module from xarray is enough
import json
import pandas as pd
from pandas.errors import (ParserError as pd_ParserError)
import geopandas as gpd
import numpy as np
import rasterio
import rasterio.features
from affine import Affine
from shapely.geometry import (mapping, Point, Polygon, MultiPolygon)
import shapely.geometry as sg
import os
import re
import sys
from functools import partial
import gc # garbage collector
from pathlib import Path
from typing import Union, Optional, List, Dict, Tuple, Any
import datetime
import importlib.metadata
# import matplotlib.pyplot as plt
from pysheds.grid import Grid
from pysheds.view import Raster, ViewFinder
# ========= change since version 0.5 ==========================================
# from pysheds.pgrid import Grid as pGrid
# =============================================================================
from geop4th import internal
# ========== see reproject() §Rasterize ======================================
# import geocube
# from geocube.api.core import make_geocube
# from geocube.rasterize import rasterize_points_griddata, rasterize_points_radial
# import functools
# =============================================================================
# import whitebox
# wbt = whitebox.WhiteboxTools()
# wbt.verbose = False
#%% LEGEND:
# ---- ° = to keep but to update
# ---- * = to include in another function or to remove
#%% LIST ALL FUNCTIONALITIES
def available():
"""List the functionalities available in this module."""
# unfinished functions
excluded_funcs = {
'cell_area',
'xr_to_pd',
'tss_to_dataframe',
'compute_Erefs_from_Epan',
'compute_wind_speed',
'compute_relative_humidity',
'convert_downwards_radiation',
'transform_tif',
'transform_nc',
'agg_func',
'dummy_input',
'hourly_to_daily_old',
}
# aliases and partials
add_funcs = {
'reproject',
'convert',
'clip',
'rasterize'}
internal.available(__name__,
ignore = excluded_funcs,
add = add_funcs)
#%% NETCDF METADATA AND DIRECTORY UTILITIES
###############################################################################
#%% LOADING & INITIALIZING DATASETS
###############################################################################
def load_any(data,
*, name = None,
decode_coords = 'all',
decode_times = True,
rebuild_time_val = True,
**kwargs):
r"""
This function loads any common spatio-temporal file or variable into a
standard Python variable, without needing to worry about the file or variable type.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame, pandas.DataFrame or numpy.array)
``data`` will be loaded into a standard *GEOP4TH* variable:
- all vector data (GeoPackage, shapefile, GeoJSON) will be loaded as a geopandas.GeoDataFrame
- all raster data (ASCII, GeoTIFF) and netCDF will be loaded as a xarray.Dataset
- other data will be loaded either as a pandas.DataFrame (CSV and JSON) or as a numpy.array (TIFF)
If ``data`` is already a variable, no operation will be executed.
name : str, optional, default None
Name of the main variable for TIFF, GeoTIFF or ASCII files.
decode_coords : bool or {"coordinates", "all"}, default "all"
Controls which variables are set as coordinate variables:
- "coordinates" or True: Set variables referred to in the
``'coordinates'`` attribute of the datasets or individual variables
as coordinate variables.
- "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and
other attributes as coordinate variables.
Only existing variables can be set as coordinates. Missing variables
will be silently ignored.
Although this is an argument of `xarray.open_dataset`, it is passed explicitly outside of the
following ``**kwargs`` arguments because its default value in *GEOP4TH* differs from the default
value in `xarray.open_dataset`.
decode_times : bool, default True
If True, decode times encoded in the standard NetCDF datetime format
into datetime objects.
Although this is an argument of `xarray.open_dataset`, it is passed explicitly outside of the
following ``**kwargs`` arguments because its default value in *GEOP4TH* differs from the default
value in `xarray.open_dataset`.
rebuild_time_val : bool, default True
If True, infer the time coordinate as a datetime object, from available information.
**kwargs
Optional other arguments passed to ``xarray.open_dataset``, ``pandas.DataFrame.read_csv``,
``pandas.DataFrame.to_csv``, ``pandas.DataFrame.read_json`` or
``pandas.DataFrame.to_json`` function calls.
May contain:
- decode_cf
- sep
- encoding
- force_ascii
- ...
>>> help(xarray.open_dataset)
>>> help(pandas.read_csv)
>>> help(pandas.to_csv)
>>> ...
Returns
-------
data_ds : geopandas.GeoDataFrame, xarray.Dataset, pandas.DataFrame or numpy.array
Data is loaded into a GEOP4TH variable. The type of this variable depends on
the type of the input data:
- all vector data will be loaded as a geopandas.GeoDataFrame
- all raster data and netCDF will be loaded as a xarray.Dataset
- other data will be loaded either as a pandas.DataFrame (CSV and JSON) or as a numpy.array (TIFF)
Examples
--------
>>> data_ds = geo.load_any(r'D:\data.nc')
>>> data_ds.head()
...
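Loading a CSV file while passing the column separator explicitly (a minimal sketch with a hypothetical path; see ``help(pandas.read_csv)`` for the available keyword arguments):
>>> data_df = geo.load_any(r'D:\data.csv', sep = ';')  # hypothetical file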
"""
# initialization
kwargs['decode_coords'] = decode_coords
kwargs['decode_times'] = decode_times
# If data is already a variable, the variable will be copied
if isinstance(data, (xr.Dataset, xr.DataArray,
gpd.GeoDataFrame, pd.DataFrame,
gpd.GeoSeries, pd.Series,
np.ndarray)):
data_ds = data.copy()
# If data is a string/path, this file will be loaded into a variable
elif isinstance(data, (str, Path)):
print("\nLoading data...")
if not os.path.isfile(data):
print(f" Err: the path provided is not a file: {data}")
return
else:
extension_src = os.path.splitext(data)[-1]
# Adapt load kwargs:
xarray_args = ['chunks', 'cache', 'decode_cf',
'mask_and_scale', 'decode_times', 'decode_timedelta',
'use_cftime', 'concat_characters', 'decode_coords',
'drop_variables', 'create_default_indexes',
'inline_array', 'chunked_array_type',
'from_array_kwargs', 'backend_kwargs',
# 'engine',
]
pandas_args = ['sep', 'delimiter', 'header', 'names', 'index_col',
'usecols', 'dtype', 'converters',
'true_values', 'false_values', 'skipinitialspace',
'skiprows', 'skipfooter', 'nrows', 'na_values',
'keep_default_na', 'na_filter', 'verbose',
'skip_blank_lines', 'parse_dates',
'infer_datetime_format', 'keep_date_col',
'date_parser', 'date_format', 'dayfirst',
'cache_dates', 'iterator', 'chunksize', 'compression',
'thousands', 'decimal', 'lineterminator', 'quotechar',
'quoting', 'doublequote', 'escapechar', 'comment',
'encoding', 'encoding_errors', 'dialect',
'on_bad_lines', 'delim_whitespace', 'low_memory',
'memory_map', 'float_precision', 'storage_options',
'dtype_backend',
# 'engine',
]
geopandas_args = ['bbox', 'mask', 'columns', 'rows',
# 'engine',
]
# These arguments are only used in pandas.DataFrame.to_csv():
if extension_src != '.csv':
for arg in pandas_args:
if arg in kwargs: kwargs.pop(arg)
# These arguments are only used in pandas.DataFrame.to_json():
if extension_src != '.json':
for arg in ['force_ascii']:
if arg in kwargs: kwargs.pop(arg)
# These arguments are only used in xarray.open_dataset():
if extension_src != '.nc':
for arg in xarray_args:
if arg in kwargs: kwargs.pop(arg)
# These arguments are only used in xarray.open_dataset():
if extension_src not in ['.nc', '.tif', '.asc']:
for arg in ['decode_times']:
if arg in kwargs: kwargs.pop(arg)
# These arguments are only used in geopandas.read_file():
if extension_src not in ['.shp', '.json', '.gpkg']:
for arg in geopandas_args:
if arg in kwargs: kwargs.pop(arg)
if extension_src in ['.shp', '.json', '.gpkg']:
try:
data_ds = gpd.read_file(data, **kwargs)
except: # DataSourceError
try:
data_ds = pd.read_json(data, **kwargs)
except:
data_ds = json.load(open(data, "r"))
print(" Warning: The JSON file could not be loaded as a pandas.DataFrame and was loaded as a dict")
elif os.path.splitext(data)[-1] in ['.csv']:
# Auto-detect separator if not specified
if 'sep' not in kwargs:
print(" _ Warning: loading is quicker when passing a `sep` argument")
for sep_char in ['\t', ',', ';', ' ', '|']:
try:
data_ds = pd.read_csv(data, sep=sep_char, **kwargs)
print(f" _ Info: data was read with sep = {repr(sep_char)}")
break
except pd_ParserError:
continue
else:
print(" _ Error: the `sep` argument to read the CSV file could not be inferred")
logger.exception("Failed to auto-detect CSV separator")
print("\nTry to pass the `sep` argument explicitly to `geobricks.load_any()` (see `help(pandas.read_csv)`)\n")
return
else:
try:
data_ds = pd.read_csv(data, **kwargs)
except pd_ParserError:
logger.exception("")
print("\nTry to pass additional arguments to `geobricks.load_any()` such as column separator `sep` (see `help(pandas.read_csv)`)\n")
return
elif extension_src == '.nc':
try:
with xr.open_dataset(data, **kwargs) as data_ds:
data_ds.load() # to unlock the resource
except:
kwargs['decode_times'] = False
print(" _ decode_times = False")
try:
with xr.open_dataset(data, **kwargs) as data_ds:
data_ds.load() # to unlock the resource
except:
kwargs['decode_coords'] = False
print(" _ decode_coords = False")
with xr.open_dataset(data, **kwargs) as data_ds:
data_ds.load() # to unlock the resource
if rebuild_time_val:
time_coord = main_time_dims(data_ds, all_coords = True, all_vars = True)[0]
if data_ds[time_coord].dtype == float:
print(" _ inferring time axis...")
print(f" . inferred time coordinate is {time_coord}")
units, reference_date = data_ds[time_coord].attrs['units'].split('since')
units = units.replace(' ', '').casefold()
reference_date = reference_date.replace(' ', '').casefold()
if units in ['month', 'months', 'M']:
freq = 'M'
elif units in ['day', 'days', 'D']:
freq = 'D'
start_date = pd.date_range(start = reference_date,
periods = int(data_ds[time_coord][0].values)+1,
freq = freq)[-1]
try:
data_ds[time_coord] = pd.date_range(start = start_date,
periods = data_ds.sizes[time_coord],
freq = freq)
except: # SPECIAL CASE to handle truncated output files (from failed CWatM simulations)
print(' . info: truncated time on data')
data_ds = data_ds.where(data_ds[time_coord]<1e5, drop = True)
data_ds[time_coord] = pd.date_range(start = start_date,
periods = data_ds.sizes[time_coord],
freq = freq)
print(f" . initial time = {pd.to_datetime(data_ds[time_coord])[0].strftime('%Y-%m-%d')} | final time = {pd.to_datetime(data_ds[time_coord])[-1].strftime('%Y-%m-%d')} | units = {units}")
elif extension_src in ['.tif', '.asc']:
with xr.open_dataset(data, **kwargs) as data_ds:
data_ds.load() # to unlock the resource
# ======== alternative option =================================================
# data_ds = rioxarray.open_rasterio(data)
# =============================================================================
if 'band' in data_ds.dims:
if data_ds.sizes['band'] == 1:
data_ds = data_ds.squeeze('band')
data_ds = data_ds.drop('band')
if name is not None:
data_ds = data_ds.rename(band_data = name)
else:
print("Err: `data` input does not exist")
return
# Safeguard
if isinstance(data, (xr.Dataset, xr.DataArray,
gpd.GeoDataFrame, pd.DataFrame,
gpd.GeoSeries, pd.Series,
np.ndarray)):
if len(data_ds) == 0:
print("Err: `data` is empty")
return
# Return
return data_ds
###############################################################################
def main_vars(data):
"""
Infer the main data variables in a dataset, or ask the user (in the case of vector datasets).
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame or pandas.DataFrame)
Data whose main variable names will be retrieved.
Returns
-------
list of str
List of the inferred main data variables.
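Examples
--------
A minimal usage sketch; the file path and the returned variable name are purely illustrative:
>>> geo.main_vars(r'D:\data.nc')  # hypothetical file
['precipitation']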
"""
data_ds = load_any(data)
if isinstance(data_ds, xr.Dataset): # raster
var = list(set(list(data_ds.data_vars)) - set(['x', 'y', 'X','Y', 'i', 'j',
'lat', 'lon',
'spatial_ref',
'LambertParisII',
'bnds', 'time_bnds',
'valid_time', 't', 'time',
'date',
'forecast_reference_time',
'forecast_period']))
# =============================================================================
# if len(var) == 1:
# var = var[0]
# =============================================================================
elif isinstance(data_ds, (xr.DataArray, gpd.GeoSeries, pd.Series)):
var = data_ds.name
if (var is None) | (var == ''):
var = input("Name of the main variable: ")
elif isinstance(data_ds, (gpd.GeoDataFrame, pd.DataFrame)): # vector
if len(data_ds.columns) == 1:
var = [data_ds.columns[0]]
else:
# =============================================================================
# var = data_ds.loc[:, data_ds.columns != 'geometry']
# =============================================================================
print("Name or id of the main data variable: ")
i = 1
for c in data_ds.columns:
print(f" {i}. {c}")
i += 1
col = input("")
if col in data_ds.columns: var = col # selection by name
else: var = data_ds.columns[int(col)-1] # selection by id
elif isinstance(data_ds, pd.Series):
var = data_ds.name
if (var is None) | (var == ''):
var = input("Name of the main variable: ")
# in case var is a single variable, it is still encapsulated into a list,
# for coherence
if not isinstance(var, list):
var = [var]
return var
###############################################################################
def main_space_dims(data):
"""
Infer the spatial dimension names in a dataset.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame or pandas.DataFrame)
Data whose spatial dimensions will be detected.
Returns
-------
list of tuple of str
List of (x_name, y_name) pairs of detected spatial dimension names.
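Examples
--------
A minimal usage sketch; the file path and the returned dimension names are purely illustrative:
>>> geo.main_space_dims(r'D:\data.nc')  # hypothetical file
[('lon', 'lat')]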
"""
data_ds = load_any(data)
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
# This first line is intended to get around casefold differences
data_dims_dict = {d.casefold() : d for d in data_ds.dims if isinstance(d, str)}
x_var = list(set(data_dims_dict).intersection(set(['x', 'lon', 'longitude'])))
y_var = list(set(data_dims_dict).intersection(set(['y', 'lat', 'latitude'])))
data_dict = data_dims_dict
elif isinstance(data_ds, (gpd.GeoDataFrame, pd.DataFrame)):
data_col_dict = {d.casefold() : d for d in data_ds.columns if isinstance(d, str)}
x_var = list(set(data_col_dict).intersection(set(['x', 'lon', 'longitude'])))
y_var = list(set(data_col_dict).intersection(set(['y', 'lat', 'latitude'])))
data_dict = data_col_dict
elif isinstance(data_ds, (gpd.GeoSeries, pd.Series)):
# Look through index(es)
try:
data_col_dict = {d.casefold() : d for d in data_ds.index.names if isinstance(d, str)} # MultiIndex
x_var = list(set(data_col_dict).intersection(set(['x', 'lon', 'longitude'])))
y_var = list(set(data_col_dict).intersection(set(['y', 'lat', 'latitude'])))
except:
data_col_dict = {d.casefold() : d for d in [data_ds.index.name] if isinstance(d, str)} # single Index
x_var = list(set(data_col_dict).intersection(set(['x', 'lon', 'longitude'])))
y_var = list(set(data_col_dict).intersection(set(['y', 'lat', 'latitude'])))
data_dict = data_col_dict
if (len(x_var) == 1):
print("Info: x dimension is encoded in data index")
if (len(y_var) == 1):
print("Info: y dimension is encoded in data index")
if len(x_var) == 0:
x_var = [None]
### Try to infer other candidates
# . based on 2 similar strings with a switch between a 'x' and a 'y'
x_pattern = re.compile("(.*)(x)(.*)")
for d in data_dict:
if len(x_pattern.findall(d)) > 0:
prefix, _, suffix = x_pattern.findall(d)[0]
y_pattern = re.compile(f"({prefix})(y)({suffix})")
for dy in data_dict:
if len(y_pattern.findall(dy)) > 0:
x_var = [prefix + 'x' + suffix]
y_var = [prefix + 'y' + suffix]
if (x_var[0] is None) & (not isinstance(data_ds, gpd.GeoDataFrame)):
print("Warning: no x variable has been detected")
if isinstance(data, (gpd.GeoDataFrame, pd.DataFrame)):
# Look through index(es)
try:
data_col_dict = {d.casefold() : d for d in data_ds.index.names if isinstance(d, str)} # MultiIndex
x_var = list(set(data_col_dict).intersection(set(['x', 'lon', 'longitude'])))
y_var = list(set(data_col_dict).intersection(set(['y', 'lat', 'latitude'])))
except:
data_col_dict = {d.casefold() : d for d in [data_ds.index.name] if isinstance(d, str)} # single Index
x_var = list(set(data_col_dict).intersection(set(['x', 'lon', 'longitude'])))
y_var = list(set(data_col_dict).intersection(set(['y', 'lat', 'latitude'])))
data_dict = data_col_dict
if (len(x_var) == 1):
print("Info: x dimension is encoded in data index")
if (len(y_var) == 1):
print("Info: y dimension is encoded in data index")
elif (len(x_var) == 1) & (isinstance(data_ds, gpd.GeoDataFrame)):
print("Info: a x column has been detected, besides 'geometry'")
elif len(x_var) > 1:
print("Warning: several x variables have been detected")
# safeguard
if len(x_var) == 0:
x_var = [None]
if len(y_var) == 0:
if not isinstance(data_ds, gpd.GeoDataFrame):
print("Warning: no y variable has been detected")
y_var = [None]
elif (len(y_var) == 1) & (isinstance(data_ds, gpd.GeoDataFrame)):
print("Info: a y column has been detected, besides 'geometry'")
elif len(y_var) > 1:
print("Warning: several y variables have been detected")
data_dict.update({None: None})
x_var = [data_dict[v] for v in x_var]
y_var = [data_dict[v] for v in y_var]
return [(x_var[i], y_var[i]) for i in range(0, len(x_var))]
###############################################################################
def main_time_dims(data_ds,
all_coords = False,
all_vars = False):
"""
Infer the time dimension and the other main time variables from a dataset.
Parameters
----------
data_ds : xarray.Dataset or geopandas.GeoDataFrame
Data whose time variable(s) will be retrieved.
all_coords : bool, default False
Only used if ``data_ds`` is a xarray variable.
If False, only dimensions are considered as potential time coordinates.
If True, coordinates that are not associated with any dimension will also
be considered as potential time coordinates (in addition to ``dims``).
all_vars : bool, default False
Only used if ``data_ds`` is a xarray variable.
If True, data variables (``data_vars``) will also be considered as
potential time coordinates (in addition to ``dims``).
Returns
-------
var : list of str
List of potential time coordinate names, the first one being the most relevant.
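Examples
--------
A minimal usage sketch on an already loaded dataset; the file path and the returned coordinate name are purely illustrative:
>>> data_ds = geo.load_any(r'D:\data.nc')  # hypothetical file
>>> geo.main_time_dims(data_ds, all_coords = True, all_vars = True)
['time']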
"""
time_coord_avatars = ['time', 't', 'valid_time',
'forecast_period', 'date',
'time0',
# 'time_bnds',
# 'forecast_reference_time',
]
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
# Convert xr.DataArray to xr.Dataset because
# next operations are made consistent for xr.Datasets
if isinstance(data_ds, xr.DataArray):
data_ds = data_ds.to_dataset(name = 'data')
data_dims_dict = {d.casefold() : d for d in data_ds.dims}
data_coords_dict = {d.casefold() : d for d in data_ds.coords}
data_vars_dict = {d.casefold() : d for d in data_ds.data_vars}
data_dict = data_dims_dict
data_dict.update(data_coords_dict)
data_dict.update(data_vars_dict)
var = list(set(data_dims_dict).intersection(set(time_coord_avatars)))
if all_coords: # in this case, even non-dim coordinates will be considered as potential time coordinates
var = list(set(var).union(set(data_coords_dict).intersection(set(time_coord_avatars))))
if all_vars: # in this case, even data variables will be considered as potential time coordinates
if isinstance(data_ds, xr.Dataset):
var = list(set(var).union(set(data_vars_dict).intersection(set(time_coord_avatars))))
elif isinstance(data_ds, xr.DataArray):
print("Note: `all_vars` argument is unnecessary with xarray.DataArrays")
var = [data_dict[v] for v in var]
elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
data_col_dict = {d.casefold() : d for d in data_ds.columns if isinstance(d, str)}
var = list(set(data_col_dict).intersection(set(time_coord_avatars)))
var = [data_col_dict[v] for v in var]
elif isinstance(data_ds, (pd.Series, gpd.GeoSeries)):
try:
data_col_dict = {d.casefold() : d for d in data_ds.index.names} # MultiIndex
except:
data_col_dict = {d.casefold() : d for d in [data_ds.index.name] if isinstance(d, str)} # single Index
var = list(set(data_col_dict).intersection(set(time_coord_avatars)))
var = [data_col_dict[v] for v in var]
# =============================================================================
# if len(var) == 1:
# var = var[0]
# =============================================================================
if len(var) > 1:
# If there are several time coordinate candidates, the best option will
# be put in first position. The best option is determined via a series
# of rules:
candidates = []
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
# Only 1D datetime variables will be considered
for v in var:
if np.issubdtype(data_ds[v], np.datetime64):
if len(data_ds[v].dims) == 1:
candidates.append(v)
# The first remaining candidate with the largest number of values will
# be selected
coords_length = {data_ds[v].size:v for v in candidates}
first_var = coords_length[max(coords_length.keys())]
elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
# Only datetime variables will be considered
for v in var:
if np.issubdtype(data_ds[v], np.datetime64):
candidates.append(v)
# The first remaining candidate will be selected
first_var = candidates[0]
var.pop(var.index(first_var))
var.insert(0, first_var)
if len(var) == 0:
var = [None]
return var
###############################################################################
def get_filelist(data,
*, extension = None,
tag = ''):
"""
This function extracts a list of relevant files from a folder (or from a single file or an iterable of filepaths).
Parameters
----------
data: path (str or pathlib.Path) or list of paths (str or pathlib.Path)
Folder, filepath or iterable of filepaths
extension: str, optional
Only the files with this extension will be retrieved.
tag: str, optional
Only the files containing this tag in their names will be retrieved.
Returns
-------
data_folder : str
Root folder of the selected files.
filelist : list of str
List of selected file names.
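Examples
--------
A minimal usage sketch; the folder path and the tag are purely illustrative:
>>> data_folder, filelist = geo.get_filelist(r'D:\results', extension = '.nc', tag = '2020')  # hypothetical folder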
"""
# if extension[0] == '.': extension = extension[1:]
if isinstance(extension, str):
if extension[0] != '.': extension = '.' + extension
# ---- Data is a single element
# if data is a single string/path
if isinstance(data, (str, Path)):
# if this string points to a folder
if os.path.isdir(data):
data_folder = data
if extension is not None:
filelist = [f for f in os.listdir(data_folder)
if ( (os.path.isfile(os.path.join(data_folder, f))) \
& (os.path.splitext(os.path.join(data_folder, f))[-1] == extension) \
& (len(re.compile(f".*({tag}).*").findall(f)) > 0) )]
else:
filelist = [f for f in os.listdir(data_folder)
if ( (os.path.isfile(os.path.join(data_folder, f))) \
& (len(re.compile(f".*({tag}).*").findall(f)) > 0) )]
# if this string points to a file
elif os.path.isfile(data):
data_folder = os.path.split(data)[0] # root of the file
filelist = [data]
# ---- Data is an iterable
elif isinstance(data, (list, tuple)):
# [Safeguard] It is assumed that data contains an iterable of files
if not os.path.isfile(data[0]):
print("Err: Argument should be a folder, a filepath or a list of filepath")
return
data_folder = os.path.split(data[0])[0] # root of the first element of the list
filelist = list(data)
return data_folder, filelist
###############################################################################
#%%% ° pick_dates_fields
def pick_dates_fields(*, input_file, output_format = 'NetCDF', **kwargs):
"""
% DESCRIPTION:
This function extracts the specified dates or fields from NetCDF files that
contain multiple dates or fields, and exports them as a single file.
% EXAMPLE:
import geoconvert as gc
gc.pick_dates_fields(input_file = r"D:/path/test.nc",
dates = ['2020-10-15', '2021-10-15'])
% OPTIONAL ARGUMENTS:
> output_format = 'NetCDF' (default) | 'GeoTIFF'
> kwargs:
> dates = ['2021-10-15', '2021-10-19']
> fields = ['T2M', 'PRECIP', ...]
"""
with xr.open_dataset(input_file) as _dataset:
_dataset.load() # to unlock the resource
#% Get arguments (and build output_name):
# ---------------------------------------
_basename = os.path.splitext(input_file)[0]
# Get fields:
if 'fields' in kwargs:
fields = kwargs['fields']
if isinstance(fields, str): fields = [fields]
else: fields = list(fields) # in case fields are string or tuple
else:
fields = list(_dataset.data_vars) # if not input_arg, fields = all
# Get dates:
if 'dates' in kwargs:
dates = kwargs['dates']
if isinstance(dates, str):
output_file = '_'.join([_basename, dates, '_'.join(fields)])
dates = [dates]
else:
dates = list(dates) # in case dates are tuple
output_file = '_'.join([_basename, dates[0], 'to',
dates[-1], '_'.join(fields)])
else:
dates = ['alldates'] # if not input_arg, dates = all
output_file = '_'.join([_basename, '_'.join(fields)])
#% Standardize terms:
# -------------------
if 't' in list(_dataset.dims):
print('Renaming time coordinate')
_dataset = _dataset.rename(t = 'time')
if 'lon' in list(_dataset.dims) or 'lat' in list(_dataset.dims):
print('Renaming lat/lon coordinates')
_dataset = _dataset.rename(lat = 'latitude', lon = 'longitude')
# Change the order of coordinates to match QGIS standards:
_dataset = _dataset.transpose('time', 'latitude', 'longitude')
# Insert georeferencing metadata to match QGIS standards:
_dataset.rio.write_crs("epsg:4326", inplace = True)
# Insert metadata to match Panoply standards:
_dataset.longitude.attrs = {'units': 'degrees_east',
'long_name': 'longitude'}
_dataset.latitude.attrs = {'units': 'degrees_north',
'long_name': 'latitude'}
if 'X' in list(_dataset.dims) or 'Y' in list(_dataset.dims):
print('Renaming X/Y coordinates')
_dataset = _dataset.rename(X = 'x', Y = 'y')
# Change the order of coordinates to match QGIS standards:
_dataset = _dataset.transpose('time', 'y', 'x')
# Insert metadata to match Panoply standards:
_dataset.x.attrs = {'standard_name': 'projection_x_coordinate',
'long_name': 'x coordinate of projection',
'units': 'Meter'}
_dataset.y.attrs = {'standard_name': 'projection_y_coordinate',
'long_name': 'y coordinate of projection',
'units': 'Meter'}
# =============================================================================
# # Rename coordinates (old version):
# try:
# _dataset.longitude
# except AttributeError:
# _dataset = _dataset.rename({'lon':'longitude'})
# try:
# _dataset.latitude
# except AttributeError:
# _dataset = _dataset.rename({'lat':'latitude'})
# try:
# _dataset.time
# except AttributeError:
# _dataset = _dataset.rename({'t':'time'})
# =============================================================================
#% Extraction and export:
# -----------------------
# Extraction of fields:
_datasubset = _dataset[fields]
# Extraction of dates:
if dates != 'alldates':
_datasubset = _datasubset.sel(time = dates)
if output_format == 'NetCDF':
_datasubset.attrs = {'Conventions': 'CF-1.6'} # I am not sure...
# Export:
_datasubset.to_netcdf(output_file + '.nc')
elif output_format == 'GeoTIFF':
_datasubset.rio.to_raster(output_file + '.tiff')
#%% EXPORT
###############################################################################
def export(data,
output_filepath,
**kwargs):
r"""
Export any geospatial dataset (file or GEOP4TH variable) to a file. Note that
exports implying a rasterization or a vectorization are not handled by this
function; use the :func:`rasterize` function instead (or its related
super-function :func:`transform`). Vectorization is not yet
implemented in GEOP4TH.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame, pandas.DataFrame or numpy.array)
Dataset that will be exported to ``output_filepath``.
Note that ``data`` will be loaded into a standard *GEOP4TH* variable:
- all vector data (GeoPackage, shapefile, GeoJSON) will be loaded as a geopandas.GeoDataFrame
- all raster data (ASCII, GeoTIFF) and netCDF will be loaded as a xarray.Dataset
- other data will be loaded either as a pandas.DataFrame (CSV and JSON) or as a numpy.array (TIFF)
output_filepath : str or Path
Full filepath (must contain the folder, name and extension) of
the file to be exported. For instance: r"D:\results\exportedData.tif"
**kwargs :
Additional arguments that can be passed to geopandas.GeoDataFrame.to_file(),
xarray.Dataset.to_netcdf(), xarray.Dataset.rio.to_raster(),
pandas.DataFrame.to_csv() or pandas.DataFrame.to_json(), depending on
the specified file extension.
Returns
-------
None. The data is exported to the specified file.
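Examples
--------
A minimal usage sketch, reusing the filepath given above as an illustration:
>>> geo.export(data_ds, r"D:\results\exportedData.tif")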
"""
extension_dst = os.path.splitext(output_filepath)[-1]
data_ds = load_any(data, decode_times = True, decode_coords = 'all')
# Safeguards
# These arguments are only used in pandas.DataFrame.to_csv():
if extension_dst != '.csv':
for arg in ['sep', 'encoding']:
if arg in kwargs: kwargs.pop(arg)
# These arguments are only used in pandas.DataFrame.to_json():
if (extension_dst != '.json') & isinstance(data_ds, pd.DataFrame):
for arg in ['force_ascii']:
if arg in kwargs: kwargs.pop(arg)
if isinstance(data_ds, xr.DataArray):
if 'name' in kwargs:
name = kwargs['name']
else:
name = main_vars(data_ds)[0]
data_ds = data_ds.to_dataset(name = name)
print("\nExporting...")
if isinstance(data_ds, (gpd.GeoDataFrame, gpd.GeoSeries)):
if extension_dst in ['.shp', '.json', '.geojson', '.gpkg']:
if data_ds.empty:
print(f" _ Warning: Cannot export empty GeoDataFrame to '{output_filepath}'")
return
data_ds.to_file(output_filepath, **kwargs)
print(f" _ Success: The data has been exported to the file '{output_filepath}'")
elif extension_dst in ['.nc', '.tif']:
print("Err: To convert vector to raster, use geobricks.rasterize() instead")
return
elif extension_dst in ['.csv']:
if data_ds.empty:
print(f" _ Warning: Cannot export empty GeoDataFrame to '{output_filepath}'")
return
data_ds.drop(columns = 'geometry').to_csv(output_filepath, **kwargs)
print(f" _ Success: The data has been exported to the file '{output_filepath}'")
else:
print(f"Err: the extension {extension_dst} is not supported for geopandas data types")
return
elif isinstance(data_ds, xr.Dataset):
# Avoid dtypes incompatibilities
var_list = main_vars(data_ds)
for var in var_list:
if data_ds[var].dtype == int:
if any(pd.isna(val) for val in data_ds[var].encoding.values()) | any(pd.isna(val) for val in data_ds[var].attrs.values()):
data_ds[var] = data_ds[var].astype(float)
print(f"Info: convert '{var}' from `int` to `float` to avoid issues with NaN")
if extension_dst == '.nc':
data_ds.to_netcdf(output_filepath, **kwargs)
elif extension_dst in ['.tif', '.asc']:
if (len(data_ds.dims) + len(data_ds.data_vars)) > 4:
print("Err: the dataset contains 3 dimensions as well as several variables, which is too many to be supported by a GeoTIFF file. Consider exporting to a netCDF file instead.")
return
# If there is a time dimension, time will be expanded as data variables (.tif bands)
if len(data_ds.dims) == 3:
time = main_time_dims(data_ds)[0]
var = var_list[0]
for t in data_ds[time]:
data_ds[t.item().strftime("%Y-%m-%d %H:%M:%S.%f")] = data_ds[var].loc[{'time': t.values}]
data_ds = data_ds.drop_dims(time)
data_ds.rio.to_raster(output_filepath, **kwargs) # recalc_transform = False
else:
print(f"Err: the extension {extension_dst} is not supported for xarray data types")
return
print(f" _ Success: The data has been exported to the file '{output_filepath}'")
elif isinstance(data_ds, pd.DataFrame): # Note: it is important to test this
# condition after gpd.GeoDataFrame because GeoDataFrames are also DataFrames
if extension_dst in ['.json']:
data_ds.to_json(output_filepath, **kwargs)
elif extension_dst in ['.csv']:
data_ds.to_csv(output_filepath, **kwargs)
else:
print(f"Err: the extension {extension_dst} is not supported for pandas data types")
return
print(f" _ Success: The data has been exported to the file '{output_filepath}'")
#%% GEOREFERENCING
###############################################################################
# Georef (ex-decorate_NetCDF_for_QGIS)
def georef(data,
*, crs = None,
to_file = False,
var_list = None,
**time_kwargs):
r"""
Description
-----------
Standardize the metadata required for georeferencing the data:
- standardize spatial dimension names (and attributes for netCDF/rasters)
- standardize the time dimension name and format (and attributes for netCDF/rasters)
- standardize the nodata encoding (for netCDF/rasters): under the key '_FillValue'
in the encodings of the relevant data
- standardize (and include if absent) the CRS: 'grid_mapping' key in the encoding
of the relevant data, and 'spatial_ref' dimensionless coordinate containing CRS info
- transpose the dimensions into the right order
This function corrects minor format defects, according to the Climate and
Forecast Conventions (https://cfconventions.org/conventions.html), thus facilitating
further processing and visualization operations. For most data, these
corrections are enough to solve the issues encountered in visualization software
(such as QGIS). If some data require deeper corrections, this should be
done with ``standardize`` scripts (in the *geop4th/workflows/standardize* folder).
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame, pandas.DataFrame or numpy.array)
Data to georeference.
Note that if ``data`` is not a variable, it will be loaded into a standard *GEOP4TH* variable:
- all vector data (GeoPackage, shapefile, GeoJSON) will be loaded as a geopandas.GeoDataFrame
- all raster data (ASCII, GeoTIFF) and netCDF will be loaded as a xarray.Dataset
- other data will be loaded either as a pandas.DataFrame (CSV and JSON) or as a numpy.array (TIFF)
crs : int or str or rasterio.crs.CRS, optional
Coordinate reference system of the source (``data``), which will be embedded in ``data``.
When passed as an *integer*, ``crs`` refers to the EPSG code.
When passed as a *string*, ``crs`` can be an OGC WKT string or a Proj.4 string.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to the same location as ``data``, while appending '_georef' to its name.
If ``to_file`` is a path, the resulting dataset will be exported to this specified filepath.
var_list : (list of) str, optional
Main variables, in case data variables are too unusual to be inferred automatically.
**time_kwargs :
Arguments for ``standardize_time_coord`` function:
- var : time variable name (str), optional, default None
- infer_from : {'dims', 'coords', 'all'}, optional, default 'dims'
Returns
-------
xarray.Dataset or geopandas.GeoDataFrame with standard georeferencing.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
Example
-------
>>> geo.georef(r"<path/to/my/file>", to_file = True)
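Embedding a known CRS while georeferencing (the EPSG code below is purely illustrative):
>>> data_ds = geo.georef(r"<path/to/my/file>", crs = 2154)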
"""
# ---- Load & initialize
data_ds = load_any(data, decode_times = True, decode_coords = 'all')
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
if var_list is None:
var_list = main_vars(data_ds)
elif isinstance(var_list, str):
var_list = [var_list]
else:
var_list = list(var_list)
x_var, y_var = main_space_dims(data_ds)[0]
time_coord = main_time_dims(data_ds)[0]
print("\nGeoreferencing data...")
# ---- Standardize spatial coords, time coords, grid mapping and _FillValue
# =============================================================================
# if 'X' in data_ds.coords:
# data_ds = data_ds.rename({'X': 'x'})
# if 'Y' in data_ds.coords:
# data_ds = data_ds.rename({'Y': 'y'})
# if 'latitude' in data_ds.coords:
# data_ds = data_ds.rename({'latitude': 'lat'})
# if 'longitude' in data_ds.coords:
# data_ds = data_ds.rename({'longitude': 'lon'})
# =============================================================================
if x_var == 'X':
data_ds = data_ds.rename({'X': 'x'})
if y_var == 'Y':
data_ds = data_ds.rename({'Y': 'y'})
if y_var == 'latitude':
data_ds = data_ds.rename({'latitude': 'lat'})
if x_var == 'longitude':
data_ds = data_ds.rename({'longitude': 'lon'})
# ====== old standard time handling ===========================================
# if len(time_coord) == 1:
# data_ds = data_ds.rename({time_coord: 'time'})
# =============================================================================
data_ds = standardize_time_coord(data_ds, **time_kwargs)
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
data_ds = standardize_grid_mapping(data_ds, var_list = var_list)
data_ds, _ = standardize_fill_value(data_ds, var_list = var_list)
# Reorder dimension organization
if time_coord is None:
data_ds = data_ds.transpose(y_var, x_var)
else:
data_ds = data_ds.transpose(time_coord, y_var, x_var)
data_ds.attrs['Conventions'] = 'CF-1.12 (under test)'
### Operations specific to the data type:
# ---------------------------------------
if isinstance(data_ds, (gpd.GeoDataFrame, gpd.GeoSeries)):
# ---- Add CRS to gpd.GeoDataFrames
if crs is not None:
data_ds.set_crs(crs = crs,
inplace = True,
allow_override = True)
# data_ds = standardize_grid_mapping(data_ds, crs)
print(f' _ Coordinates Reference System (epsg:{data_ds.crs.to_epsg()}) included.')
else:
if data_ds.crs is None:
# Auto-detect CRS for lat/lon coordinates
x_dim, y_dim = main_space_dims(data_ds)[0]
if x_dim == 'lon' and y_dim == 'lat':
# Geographic coordinates detected, assume WGS84
data_ds.set_crs(4326, inplace=True, allow_override=True)
print(" _ Automatically detected geographic coordinates, set CRS to WGS84 (EPSG:4326)")
else:
print(" _ Warning: Data contains no CRS. Consider passing the `crs` argument")
elif isinstance(data_ds, xr.Dataset):
# ---- Add CRS to xr.Datasets
if crs is not None:
data_ds.rio.write_crs(crs, inplace = True)
print(f' _ Coordinates Reference System (epsg:{data_ds.rio.crs.to_epsg()}) included.')
else:
if data_ds.rio.crs is None:
# Auto-detect CRS for lat/lon coordinates
x_dim, y_dim = main_space_dims(data_ds)[0]
if x_dim == 'lon' and y_dim == 'lat':
# Geographic coordinates detected, assume WGS84
data_ds.rio.write_crs(4326, inplace=True)
print(" _ Automatically detected geographic coordinates, set CRS to WGS84 (EPSG:4326)")
else:
print(" _ Warning: Data contains no CRS. Consider passing the `crs` argument")
# ---- Add spatial dims attributes to xr.Datasets
if (x_var == 'x') & (y_var == 'y'):
data_ds.x.attrs = {'standard_name': 'projection_x_coordinate',
'long_name': 'x coordinate of projection',
'units': 'm'}
data_ds.y.attrs = {'standard_name': 'projection_y_coordinate',
'long_name': 'y coordinate of projection',
'units': 'm'}
print(" _ Standard attributes added for coordinates x and y")
elif (x_var == 'lon') & (y_var == 'lat'):
data_ds.lon.attrs = {'standard_name': 'longitude',
'long_name': 'longitude',
'units': 'degree_east'}
data_ds.lat.attrs = {'standard_name': 'latitude',
'long_name': 'latitude',
'units': 'degree_north'}
print(" _ Standard attributes added for coordinates lat and lon")
# ---- Remove statistics
for var in var_list:
for optional_attrs in ['AREA_OR_POINT', 'STATISTICS_MAXIMUM',
'STATISTICS_MEAN', 'STATISTICS_MINIMUM',
'STATISTICS_STDDEV', 'STATISTICS_VALID_PERCENT']:
if isinstance(data_ds, xr.Dataset):
if optional_attrs in data_ds[var].attrs:
data_ds[var].attrs.pop(optional_attrs)
elif isinstance(data_ds, xr.DataArray):
if optional_attrs in data_ds.attrs:
data_ds.attrs.pop(optional_attrs)
# ---- Add CRS to xr.Datasets
if crs is not None:
data_ds.rio.write_crs(crs, inplace = True)
print(f' _ Coordinates Reference System (epsg:{data_ds.rio.crs.to_epsg()}) included.')
else:
if data_ds.rio.crs is None:
print(" _ Warning: Data contains no CRS. Consider passing the `crs` argument")
# =============================================================================
# else:
# data_ds.rio.write_crs(data_ds.rio.crs, inplace = True) # rewrite the crs
# =============================================================================
# ---- Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_georef" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
# =============================================================================
# print(" _ As data input is not a file, the result is exported to a standard directory")
# output_file = os.path.join(os.getcwd(), f"{'_'.join(['data', 'georef', crs_suffix])}.nc")
# =============================================================================
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
# ---- Return variable
return data_ds
# =========================================================================
#% Memos / Bug fixes
# =========================================================================
# If a variable ends up in 'data_vars' instead of 'coords':
# data_ds = data_ds.set_coords('i')
# If there is a 'missing_value' / '_FillValue' incompatibility issue:
# data_ds['lon'] = data_ds.lon.fillna(np.nan)
# data_ds['lat'] = data_ds.lat.fillna(np.nan)
# If a non-essential variable causes problems at export:
# data_ds = data_ds.drop('lon')
# data_ds = data_ds.drop('lat')
# To find the positions of the NaN values:
# np.argwhere(np.isnan(data_ds.lon.values))
# To reconvert the date:
# units, reference_date = ds.time.attrs['units'].split('since')
# ds['time'] = pd.date_range(start = reference_date,
# periods = ds.sizes['time'], freq = 'MS')
# =========================================================================
# Create the 'x' and 'y' coordinates...
# =============================================================================
# # ... from the lon/lat values:
# # ABANDONED, BECAUSE THEIR LATITUDE VALUES ARE WRONG [!]
# coords_xy = rasterio.warp.transform(rasterio.crs.CRS.from_epsg(4326),
# rasterio.crs.CRS.from_epsg(27572),
# np.array(data_ds.lon).reshape((data_ds.lon.size), order = 'C'),
# np.array(data_ds.lat).reshape((data_ds.lat.size), order = 'C'))
#
# # data_ds['x'] = np.round(coords_xy[0][0:data_ds.lon.shape[1]], -1)
#
# # Rounded to the nearest ten because of the initial approximation on lat/lon:
# x = np.round(coords_xy[0], -1).reshape(data_ds.lon.shape[0],
# data_ds.lon.shape[1],
# order = 'C')
# # gives a pattern that repeats as many times as there are latitudes
# y = np.round(coords_xy[1], -1).reshape(data_ds.lon.shape[1],
# data_ds.lon.shape[0],
# order = 'F')
# # gives a pattern that repeats as many times as there are longitudes
#
# # data_ds['x'] = x[0,:]
# # data_ds['y'] = y[0,:]
# data_ds = data_ds.assign_coords(x = ('i', x[0,:]))
# data_ds = data_ds.assign_coords(y = ('j', y[0,:]))
# =============================================================================
###############################################################################
def standardize_time_coord(data,
*, var = None,
infer_from = 'dims',
to_file = False):
"""
Use a standard time variable as the temporal coordinate.
Standardize its name to 'time'. If it is not the main time coordinate, swap
it with the main time coordinate.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame or pandas.DataFrame)
Data whose temporal coordinate should be renamed.
var : str, optional
Variable to rename into 'time'. If not specified, the variable that will
be renamed into 'time' will be inferred from the detected time coordinate(s).
infer_from : {'dims', 'coords', 'all'}, default 'dims'
Only used for xarray variables.
Specifies whether the time coordinate should be inferred from dimensions,
coordinates or all variables (coordinates and data variables).
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_std_time'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
geopandas.GeoDataFrame, xarray.Dataset or pandas.DataFrame
Data with the modified name for the temporal coordinate.
The variable type will correspond to the variable type of the input ``data``.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
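Examples
--------
A minimal usage sketch; the file path is purely illustrative:
>>> data_ds = geo.standardize_time_coord(r"<path/to/my/file.nc>", infer_from = 'all')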
"""
# =============================================================================
# # Rename 'valid_time' into 'time' (if necessary)
# for time_avatar in ['valid_time', 'date']:
# if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
# if ((time_avatar in data_ds.coords) | (time_avatar in data_ds.data_vars)) \
# & ('time' not in data_ds.coords) & ('time' not in data_ds.data_vars):
# data_ds = data_ds.rename({time_avatar: 'time'})
#
# elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
# if (time_avatar in data_ds.columns) & ('time' not in data_ds.columns):
# data_ds = data_ds.rename(columns = {time_avatar: 'time'})
#
#
# =============================================================================
print("Standardizing time dimension...")
data_ds = load_any(data)
if isinstance(data_ds, xr.Dataset):
if ('time' in data_ds.data_vars) | ('time' in data_ds.coords):
data_ds = data_ds.rename(time = 'time0')
print(" _ A variable 'time' was already present and was renammed to 'time0'")
elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)): # Note: gpd.GeoDataFrame are also pd.DataFrames
if 'time' in data_ds.columns:
data_ds = data_ds.rename(columns = {'time': 'time0'})
print(" _ A variable 'time' was already present and was renammed to 'time0'")
elif isinstance(data_ds, xr.DataArray):
if 'time' in data_ds.coords:
data_ds = data_ds.rename(time = 'time0')
print(" _ A variable 'time' was already present and was renammed to 'time0'")
if infer_from == 'dims':
time_dims = main_time_dims(data_ds)
time_coords = main_time_dims(data_ds)
elif infer_from == 'coords':
time_dims = main_time_dims(data_ds)
time_coords = main_time_dims(data_ds, all_coords = True)
elif infer_from == 'all':
time_dims = main_time_dims(data_ds)
time_coords = main_time_dims(data_ds, all_coords = True, all_vars = True)
# ========= Obsolete ==========================================================
# if isinstance(time_dims, str): time_dims = [time_dims]
# if isinstance(time_coords, str): time_coords = [time_coords]
# =============================================================================
if var is not None:
# Rename the var specified by user into 'time'
new_tvar = var
else:
# Rename the time coord into 'time'
if len(time_coords) > 1:
print(" _ Warning: Several time dimension candidates has been identified: {', '.join(time_coords)'}. {time_coords[0]} is taken as the most likely time dimension. Consider using `var` argument.")
new_tvar = time_coords[0]
else:
if time_coords != [None]:
new_tvar = time_coords[0]
else: # safeguard
print(" _ Warning: No time dimension has been identified. Consider using `infer_from = 'coords'` or `infer_from = 'all'` arguments.")
return data_ds
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
data_ds = data_ds.rename({new_tvar: 'time'})
print(f" _ The variable '{new_tvar}' has been renamed into 'time'")
# In the case of xarray variables, if the user-specified var is
# not a dim, the function will try to swap it with the time dim
if new_tvar not in time_dims:
for d in time_dims:
# Swap dims with the first dimension that has the same
# length as 'time'
if data_ds['time'].size == data_ds.sizes[d]:
data_ds = data_ds.swap_dims({d: 'time'})
print(f" _ The new variable 'time' (prev. '{new_tvar}') has been swaped with the dimension '{d}'")
break
else:
print(r" _ Warning: The new variable 'time' (prev. '{new_tvar}') is not a dimension, and no current dimension has been found to match. Consider trying `infer_from = 'coords'` or `infer_from = 'all'` arguments")
elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
data_ds = data_ds.rename(columns = {new_tvar: 'time'})
print(f" _ The variable '{new_tvar}' has been renamed into 'time'")
# =============================================================================
# if not infer:
# if isinstance(time_coord, str):
# if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
# data_ds = data_ds.rename({time_coord: 'time'})
# elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
# data_ds = data_ds.rename(columns = {time_coord: 'time'})
# elif isinstance(time_coord, list):
# print("Warning: Time could not be standardized because there are several time coordinate candidates. Consider passing the argument 'infer = True' in ghc.standardize_time_coord()")
#
# else:
# if isinstance(time_coord, list):
# time_coord = time_coord[0]
# if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
# time_coord_avatars = ['t', 'time', 'valid_time',
# 'forecast_period', 'date',
# ]
# time_vars = list(set(list(data_ds.data_vars)).intersection(set(time_coord_avatars)))
#
# elif isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
# data_ds = data_ds.rename(columns = {time_coord: 'time'})
#
# # Swap the coordinate (if necessary)
# if time_coord != []:
# if time_coord != 'time':
# data_ds = data_ds.swap_dims({time_coord: 'time'})
# =============================================================================
# Make sure the time variable is a datetime
if not np.issubdtype(data_ds.time, (np.datetime64)):
try: data_ds['time'] = pd.to_datetime(data_ds['time'])
except: print(f" _ Warning: New 'time' variable (prev. '{new_tvar}') could not be converted into datetime dtype. Consider using `infer_from = 'coords'` or `infer_from = 'all'` arguments.")
# Standardize attrs
if isinstance(data_ds, (xr.Dataset, xr.DataArray)):
data_ds['time'].attrs['standard_name'] = 'time'
# ======== Necessary or not? How to define the reference datetime? ============
# data_ds['time'].attrs['units'] = 'days since 1970-01-01'
# data_ds['time'].attrs['calendar'] = 'gregorian'
# =============================================================================
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_std_time" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
return data_ds
###############################################################################
def standardize_grid_mapping(data,
crs = None,
to_file = False,
var_list = None):
"""
Some visualization software (such as GIS tools) needs a standard structure
for `grid_mapping` information in netCDF datasets:
- `grid_mapping` info should be in `encodings` and not in `attrs`
- ...
This function standardizes the `grid_mapping` information, so that it is
compatible with software such as QGIS.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
Dataset (netCDF or xarray variable) whose grid_mapping information will be standardized.
Note that ``data`` will be loaded into a xarray.Dataset or xarray.DataArray.
crs : int or str or rasterio.crs.CRS, optional
Coordinate reference system of the source (``data``).
When passed as an *integer*, ``crs`` refers to the EPSG code.
When passed as a *string*, ``crs`` can be an OGC WKT string or a Proj.4 string.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_std_grid_map'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
var_list : (list of) str, optional
List of main variables can be specified by user, to avoid any prompt.
Returns
-------
data_ds : xarray.Dataset
Standard *GEOP4TH* variable (xarray.Dataset) with corrected grid_mapping information.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
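Examples
--------
A minimal usage sketch; the file path and the EPSG code are purely illustrative:
>>> data_ds = geo.standardize_grid_mapping(r"<path/to/my/file.nc>", crs = 4326)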
"""
# ---- Load
data_ds = load_any(data)
if not isinstance(data_ds, (xr.Dataset, xr.DataArray)):
print("Warning: the `standardize_grid_mapping` function is only intended for netCDF datasets")
return data_ds
# Get main variable
if var_list is None:
var_list = main_vars(data_ds)
elif isinstance(var_list, str):
var_list = [var_list]
else:
var_list = list(var_list)
# Handle xr.DataArrays
if isinstance(data_ds, xr.DataArray):
if data_ds.name is not None:
data_ds = data_ds.to_dataset()
else:
data_ds = data_ds.to_dataset(name = var_list[0])
is_da = True
elif isinstance(data_ds, xr.Dataset):
is_da = False
# ---- Get the potential names of grid_mapping variable and clean all
# grid_mapping information
# Remove all the metadata about grid_mapping, and save grid_mapping names
names = set()
for var in var_list[::-1]:
if 'spatial_ref' in data_ds.data_vars:
names.add('spatial_ref')
if 'spatial_ref' in data_ds.coords: # standard case
names.add('spatial_ref')
if 'grid_mapping' in data_ds.attrs:
names.add(data_ds.attrs['grid_mapping'])
data_ds.attrs.pop('grid_mapping')
if "grid_mapping" in data_ds.encoding:
names.add(data_ds.encoding['grid_mapping'])
data_ds.encoding.pop('grid_mapping')
if 'grid_mapping' in data_ds[var].attrs:
names.add(data_ds[var].attrs['grid_mapping'])
data_ds[var].attrs.pop('grid_mapping')
if "grid_mapping" in data_ds[var].encoding:
names.add(data_ds[var].encoding['grid_mapping'])
data_ds[var].encoding.pop('grid_mapping')
# Drop all variables or coords corresponding to the previously found
# grid_mapping names
temp = None
for n in list(names):
if n in data_ds.data_vars:
temp = data_ds[n]
data_ds = data_ds.drop(n)
if n in data_ds.coords:
temp = data_ds[n]
data_ds = data_ds.reset_coords(n, drop = True)
if crs is None:
# Use the last grid_mapping value as the standard spatial_ref
dummy_crs = 2154
data_ds.rio.write_crs(dummy_crs, inplace = True) # creates the spatial_ref structure and mapping
data_ds['spatial_ref'] = temp
else:
data_ds.rio.write_crs(crs, inplace = True)
# ---- Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_std_grid_map" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
if is_da:
data_ds = data_ds[var_list[0]]
return data_ds
###############################################################################
def standardize_fill_value(data, *,
var_list = None,
attrs = None,
encod = None,
to_file = False):
"""
Standardize the way the nodata value (fill value) is encoded in a netCDF dataset.
In netCDF, nodata can be embedded in several ways ('_FillValue'
or 'missing_value', stored in attributes or in encodings), and sometimes several of
these embeddings coexist in the same dataset. In that case, this function
infers the most relevant one and removes the others. In the end, the relevant nodata
value will be stored as a '_FillValue' encoding only.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
Dataset (netCDF or xarray variable) whose nodata information will be standardized.
Note that ``data`` will be loaded into a xarray.Dataset or xarray.DataArray.
var_list : (list of) str, optional
Used to specify if only one data variable has to be standardized. Otherwise,
the nodata information will be standardized for all data variables.
attrs : dict, optional
If the nodata information is present in an `attrs` dict dissociated from
the dataset, it can be passed here.
encod : dict, optional
If the nodata information is present in an `encoding` dict dissociated from
the dataset, it can be passed here.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_std_fill_val'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
data_ds : xarray.Dataset
Standard *GEOP4TH* variable (xarray.Dataset) with corrected nodata information.
nodata : numeric
No-data value.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
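Examples
--------
A minimal usage sketch; the file path is purely illustrative:
>>> data_ds, nodata = geo.standardize_fill_value(r"<path/to/my/file.nc>")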
"""
data_ds = load_any(data)
if not isinstance(data_ds, (xr.Dataset, xr.DataArray)):
print("Warning: the `standardize_fill_value` function is only intended for netCDF datasets")
return data_ds, None
# Initializations
if var_list is None:
var_list = main_vars(data_ds)
elif isinstance(var_list, str):
var_list = [var_list]
elif isinstance(var_list, list):
var_list = var_list
if isinstance(data_ds, xr.Dataset):
for var in var_list:
if attrs is None:
attrs = data_ds[var].attrs
if encod is None:
encod = data_ds[var].encoding
# Clean all fill_value info
if '_FillValue' in data_ds[var].attrs:
data_ds[var].attrs.pop('_FillValue')
if 'missing_value' in data_ds[var].attrs:
data_ds[var].attrs.pop('missing_value')
# Set the fill_value, according to a hierarchical rule
if '_FillValue' in encod:
nodata = encod['_FillValue']
data_ds[var].encoding['_FillValue'] = nodata
elif '_FillValue' in attrs:
nodata = attrs['_FillValue']
data_ds[var].encoding['_FillValue'] = nodata
elif 'missing_value' in encod:
nodata = encod['missing_value']
data_ds[var].encoding['_FillValue'] = nodata
elif 'missing_value' in attrs:
nodata = attrs['missing_value']
data_ds[var].encoding['_FillValue'] = nodata
else:
if data_ds[var].dtype == bool:
nodata = False
else:
nodata = np.nan
data_ds[var].encoding['_FillValue'] = nodata
elif isinstance(data_ds, xr.DataArray):
if attrs is None:
attrs = data_ds.attrs
if encod is None:
encod = data_ds.encoding
# Clean all fill_value info
if '_FillValue' in data_ds.attrs:
data_ds.attrs.pop('_FillValue')
if 'missing_value' in data_ds.attrs:
data_ds.attrs.pop('missing_value')
# Set the fill_value, according to a hierarchical rule
if '_FillValue' in encod:
nodata = encod['_FillValue']
data_ds.encoding['_FillValue'] = nodata
elif '_FillValue' in attrs:
nodata = attrs['_FillValue']
data_ds.encoding['_FillValue'] = nodata
elif 'missing_value' in encod:
nodata = encod['missing_value']
data_ds.encoding['_FillValue'] = nodata
elif 'missing_value' in attrs:
nodata = attrs['missing_value']
data_ds.encoding['_FillValue'] = nodata
else:
nodata = np.nan
data_ds.encoding['_FillValue'] = nodata
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_std_fill_val" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
return data_ds, nodata
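# Usage sketch (added for illustration, not part of the original module): the
# import alias `geo` and the file path are assumptions.
# =============================================================================
# from geop4th import geobricks as geo   # import path assumed
# # Consolidate the nodata information into a single '_FillValue' encoding:
# data_ds, nodata = geo.standardize_fill_value(r"D:\folder\data1.nc",
#                                              to_file = True)
# print(nodata)
# =============================================================================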
#%% FILE MANAGEMENT
###############################################################################
def merge_data(data,
*, extension = None,
tag = '',
flatten = False,
spatial_intersect = False,
**kwargs):
"""
    This function merges several datasets into one (all the elements of an iterable, or all the files inside a folder).
Parameters
----------
data : (list of) str or Path, or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame or pandas.DataFrame)
Iterable of data to merge, or a path to a folder containing data to merge.
In that case, the arguments ``extension`` and ``tag`` can be passed.
When datasets overlap, the last dataset has the highest priority.
extension: str, optional
Only the files with this extension will be retrieved (when ``data`` is a folder path).
tag: str, optional
Only the files containing this tag in their names will be retrieved (when ``data`` is a folder path).
flatten : bool, default False
        If True, data will be flattened over the time axis.
spatial_intersect : bool, default False
        If True: if datasets have different spatial extents, they will be merged
according to the intersection of their spatial coordinates. If False, the
datasets will be merged according to the union of their spatial coordinates.
Note that for all other dimensions (time, ...), the datasets will always
be merged according to the union of their indexes.
**kwargs
Optional other arguments passed to :func:`geo.load_any`
        (arguments for ``xarray.open_dataset``, ``pandas.read_csv``,
        ``pandas.DataFrame.to_csv``, ``pandas.read_json`` or
        ``pandas.DataFrame.to_json`` function calls).
May contain:
- decode_cf
- sep
- encoding
- force_ascii
- ...
>>> help(xarray.open_dataset)
>>> help(pandas.read_csv)
        >>> help(pandas.DataFrame.to_csv)
>>> ...
Returns
-------
geopandas.GeoDataFrame, xarray.Dataset, pandas.DataFrame or numpy.array
        Merged data is stored in a variable whose type corresponds to the type of the input data:
- all vector data will be loaded as a geopandas.GeoDataFrame
- all raster data and netCDF will be loaded as a xarray.Dataset
- other data will be loaded either as a pandas.DataFrame (CSV and JSON) or as a numpy.array (TIFF)
"""
# ---- Load file list
# If data is a list of files:
if isinstance(data, (list, tuple)):
# Check if the list is empty
if len(data) == 0:
print("Warning: Empty list provided to merge_data")
return None
# If the list contains paths
if all([isinstance(data[i], (str, Path))
for i in range(0, len(data))]):
data_folder = os.path.split(data[0])[0]
filelist = data
#filelist = [os.path.split(d)[-1] for d in data]
extension = os.path.splitext(filelist[0])[-1]
# If the list contains xarray or geopandas variables
elif all([isinstance(data[i], (xr.Dataset, xr.DataArray,
gpd.GeoDataFrame, pd.DataFrame))
for i in range(0, len(data))]):
data_folder = None
filelist = data
else:
print("Err: Mixed data types")
return
# If data is a folder:
elif isinstance(data, (str, Path)):
if os.path.isdir(data):
data_folder, filename_list = get_filelist(data, extension = extension, tag = tag)
filelist = [os.path.join(data_folder, f) for f in filename_list]
# If data is a single file:
elif os.path.isfile(data):
filelist = [data]
# If data is a xarray or a geopandas variable
elif isinstance(data, (xr.Dataset, xr.DataArray, gpd.GeoDataFrame, pd.DataFrame)):
data_folder = None
filelist = [data]
# if extension[0] == '.': extension = extension[1:]
if isinstance(extension, str):
if extension[0] != '.': extension = '.' + extension
if len(filelist) > 1:
# ---- Load all datasets
print("Loading files...")
ds_list = []
for f in filelist:
ds_list.append(load_any(f, **kwargs))
# Dimension names (based on the first dataset taken as a reference)
x_ref, y_ref = main_space_dims(ds_list[0])[0]
time_ref = main_time_dims(ds_list[0])[0]
# ---- spatial_intersect option
if spatial_intersect:
if time_ref is None:
print("Info: datasets do not contain any time dimension. If you intend to merge datasets according to the union of their spatial dimension, you should pass `spatial_intersect = False`.")
if (x_ref is not None) & (y_ref is not None):
union_x = set(ds_list[0][x_ref].values)
union_y = set(ds_list[0][y_ref].values)
for ds in ds_list[1:]:
x, y = main_space_dims(ds)[0]
union_x = union_x.union(set(ds[x].values))
union_y = union_y.union(set(ds[y].values))
union_x = sorted(union_x)
union_y = sorted(union_y)[::-1]
for i in range(0, len(ds_list)):
x, y = main_space_dims(ds_list[i])[0]
ds_list[i] = ds_list[i].reindex({x: union_x, y: union_y})
print("Merging files...")
# ---- [case] Rasters
if (extension in ['.nc', '.tif', '.asc']) | all([isinstance(ds_list[i], (xr.Dataset, xr.DataArray))
for i in range(0, len(ds_list))]):
# Backup of attributes and encodings
var_list = main_vars(ds_list[0])
# ========== useless ==========================================================
# attrs = ds_list[0][var].attrs.copy()
# =============================================================================
if isinstance(ds_list[0], xr.DataArray):
encod = ds_list[0].encoding
elif isinstance(ds_list[0], xr.Dataset):
encod = {}
for var in var_list:
encod[var] = ds_list[0][var].encoding.copy()
# Merge
# =============================================================================
# if not update_val:
# merged_ds = xr.merge(ds_list) # Note: works only when doublons are identical
# else:
# =============================================================================
# TODO: explore the possibility of using combine_first
merged_ds = ds_list[0]
for ds in ds_list[1:]:
# =============================================================================
# # First, merged_ds is expanded with ds (merged_ds has priority over ds here)
# merged_ds = merged_ds.merge(ds, compat = 'override')
# # Second, non-null values of ds overwrites merge_ds
# merged_ds.loc[{dim: ds[dim].values for dim in merged_ds.dims}] = merged_ds.loc[{dim: ds[dim].values for dim in merged_ds.dims}].where(ds.isnull(), ds)
# =============================================================================
# =============================================================================
# # Instead of this 2-step process, use xarray combine functions:
# # (https://docs.xarray.dev/en/stable/user-guide/combining.html)
# # combine_first, combine_nested, combine_by_coords
# =============================================================================
merged_ds = ds.combine_first(merged_ds)
# =============================================================================
# merged_ds_aligned, _ = xr.align(merged_ds, ds)
# merged_ds_aligned = merged_ds_aligned.where(ds.isnull(), ds)
# =============================================================================
# ========== wrong ============================================================
# merged_ds = xr.concat(ds_list, dim = 'time')
# =============================================================================
# Order y-axis from max to min (because order is altered with merge)
_, y_var = main_space_dims(merged_ds)[0]
merged_ds = merged_ds.sortby(y_var, ascending = False)
if main_time_dims(merged_ds)[0] is not None:
merged_ds = merged_ds.sortby(main_time_dims(merged_ds)[0]) # In case the files are not loaded in time order
# Transferring encodings (_FillValue, compression...)
if isinstance(merged_ds, xr.DataArray):
merged_ds.encoding = encod
elif isinstance(merged_ds, xr.Dataset):
for var in var_list:
merged_ds[var].encoding = encod[var]
return merged_ds
# ---- [case] Vectors
elif (extension in ['.shp', '.json', '.geojson']) | all([isinstance(ds_list[i], gpd.GeoDataFrame)
for i in range(0, len(ds_list))]):
### Option 1: data is flattened over the time axis
if flatten:
# This variable will store the names of the concatenated columns
global varying_columns
varying_columns = []
def agg_func(arg):
global varying_columns
if len(set(arg.values)) == 1:
return arg.values[0]
else:
varying_columns.append(arg.name)
return ', '.join(str(v) for v in arg.values)
# =========== list of texts are not correctly reloaded in python... ===========
# return list(arg.values)
# =============================================================================
c = 0
# =============================================================================
# gdf_list = []
# # ---- Append all gpd.GeoDataFrame into a list
# for f in filelist:
# gdf_list.append(load_any(f))
# print(f" . {f} ({c}/{len(filelist)})")
# c += 1
#
# merged_gdf = pd.concat(gdf_list)
# =============================================================================
f = filelist[c]
merged_gdf = ds_list[c]
if isinstance(f, (str, Path)):
f_text = f
else:
f_text = type(f)
print(f" . {f_text} ({c+1}/{len(filelist)})")
for c in range(1, len(ds_list)):
f = filelist[c]
merged_gdf = merged_gdf.merge(ds_list[c],
how = 'outer',
# on = merged_df.columns,
)
if isinstance(f, (str, Path)):
f_text = f
else:
f_text = type(f)
print(f" . {f_text} ({c+1}/{len(filelist)})")
# ========== previous method ==================================================
# x_var, y_var = main_space_dims(gdf_list[0])[0]
# merged_gdf = merged_gdf.dissolve(by=[x_var, y_var], aggfunc=agg_func)
# # Convert the new index (code_ouvrage) into a column as at the origin
# merged_gdf.reset_index(inplace = True, drop = False)
# =============================================================================
merged_gdf['geometry2'] = merged_gdf['geometry'].astype(str)
merged_gdf = merged_gdf.dissolve(by='geometry2', aggfunc=agg_func)
                # Drop the temporary index created by the dissolve ('geometry2')
merged_gdf.reset_index(inplace = True, drop = True)
varying_columns = list(set(varying_columns))
# Correct the dtypes of the concatenated columns, because fiona does
# not handle list dtypes
merged_gdf[varying_columns] = merged_gdf[varying_columns].astype(str)
return merged_gdf
else: # No flattening
c = 0
# ========= previous method with concat =======================================
# gdf_list = []
# # ---- Append all gpd.GeoDataFrame into a list
# for f in filelist:
# gdf_list.append(load_any(f))
# gdf_list[c]['annee'] = pd.to_datetime(gdf_list[c]['annee'], format = '%Y')
# print(f" . {f} ({c}/{len(filelist)})")
# c += 1
#
# merged_gdf = pd.concat(gdf_list)
# =============================================================================
f = filelist[c]
merged_gdf = ds_list[c]
if isinstance(f, (str, Path)):
f_text = f
else:
f_text = type(f)
print(f" . {f_text} ({c+1}/{len(filelist)})")
for c in range(1, len(ds_list)):
f = filelist[c]
merged_gdf = merged_gdf.merge(ds_list[c],
how = 'outer',
# on = merged_df.columns,
)
if isinstance(f, (str, Path)):
f_text = f
else:
f_text = type(f)
print(f" . {f_text} ({c+1}/{len(filelist)})")
return merged_gdf
# ---- [case] CSV
elif (extension in ['.csv']) | all([isinstance(ds_list[i], pd.DataFrame)
for i in range(0, len(ds_list))]):
# Create index based on dimensions
dim_list = [x_ref, y_ref, time_ref]
dim_list = [dim for dim in dim_list if dim is not None]
for c in range(0, len(ds_list)):
ds_list[c] = ds_list[c].set_index(dim_list, inplace = False)
# Merge
c = 0
f = filelist[c]
merged_df = ds_list[c]
if isinstance(f, (str, Path)):
f_text = f
else:
f_text = type(f)
print(f" . {f_text} ({c+1}/{len(filelist)})")
for c in range(1, len(ds_list)):
f = filelist[c]
merged_df = ds_list[c].combine_first(merged_df)
if isinstance(f, (str, Path)):
f_text = f
else:
f_text = type(f)
print(f" . {f_text} ({c+1}/{len(filelist)})")
# Reset index to columns
merged_df.reset_index(inplace = True)
return merged_df
elif len(filelist) == 1:
print("Warning: Only one file was found.")
return load_any(filelist[0], **kwargs)
elif len(filelist) == 0:
print("Err: No file was found")
return
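# Usage sketch (added for illustration, not part of the original module): merging
# all the netCDF files of a folder into one dataset. The `geo` alias, folder path,
# tag value and `ds_*` variables are assumptions.
# =============================================================================
# merged_ds = geo.merge_data(r"D:\folder\yearly_maps",
#                            extension = '.nc',
#                            tag = 'precipitation')
# # A list of already-loaded variables can be merged in the same way:
# merged_ds = geo.merge_data([ds_2020, ds_2021, ds_2022])
# =============================================================================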
#%% REPROJECTIONS, CLIP, CONVERSIONS
###############################################################################
###############################################################################
#%%% Other aliases (reproject, convert)
reproject = transform
convert = transform
###############################################################################
#%%% Clip (partial alias)
clip = partial(transform,
base_template = None,
x0 = None,
y0 = None,
rasterize = False,
main_var_list = None,
rasterize_mode = ['mean', 'coverage', 'and'],
# dst_crs = None,
)
clip.__name__ = 'clip(data, *, src_crs=None, bounds=None, mask=None, to_file=False, export_extension=None)'
clip.__doc__ = r"""
Clip space-time data.
:func:`clip` is a **partial alias** of the :func:`transform() <geobricks.transform>` function.
Parameters
----------
data : str, Path, xarray.Dataset, xarray.DataArray or geopandas.GeoDataFrame
Data to transform. Supported file formats are *.tif*, *.asc*, *.nc* and vector
formats supported by geopandas (*.shp*, *.json*, ...).
src_crs : int or str or rasterio.crs.CRS, optional, default None
Coordinate reference system of the source (``data``).
When passed as an *integer*, ``src_crs`` refers to the EPSG code.
When passed as a *string*, ``src_crs`` can be OGC WKT string or Proj.4 string.
bounds : iterable or None, optional, default None
Boundaries of the target domain as a tuple (x_min, y_min, x_max, y_max).
mask : str, Path, shapely.geometry, xarray.DataArray or geopandas.GeoDataFrame, optional, default None
Filepath of mask used to clip the data.
drop : bool, default False
Only applicable for raster/xarray.Dataset types. If True, coordinate labels
that only correspond to NaN values are dropped from the result.
    to_file : bool, default False
        If True and if ``data`` is a file, the resulting dataset will be exported to a
        file with the same name and the suffix '_geop4th'.
export_extension : str, optional
Extension to which the data will be converted and exported. Only used
        when the specified ``data`` is a filepath. If ``data`` is a variable
and not a file, it will not be exported.
Returns
-------
Clipped data : xarray.Dataset or geopandas.GeoDataFrame.
    The type of the resulting variable corresponds to the type of the input data and to
the conversion operations (such as rasterize):
- all vector data will be output as a geopandas.GeoDataFrame
- all raster data and netCDF will be output as a xarray.Dataset
"""
###############################################################################
#%%% Rasterize (partial alias)
rasterize = partial(transform,
rasterize = True)
rasterize.__name__ = "rasterize(data, *, src_crs=None, base_template=None, bounds=None, x0=None, y0=None, mask=None, to_file=False, export_extension='.tif', main_var_list=None, rasterize_mode=['sum', 'dominant', 'and'], **rio_kwargs)"
rasterize.__doc__ = r"""
Rasterize vector space-time data.
:func:`rasterize` is a **partial alias** of the :func:`transform() <geobricks.transform>` function.
Parameters
----------
data : str, Path, xarray.Dataset, xarray.DataArray or geopandas.GeoDataFrame
Data to transform. Supported file formats are *.tif*, *.asc*, *.nc* and vector
formats supported by geopandas (*.shp*, *.json*, ...).
src_crs : int or str or rasterio.crs.CRS, optional, default None
Coordinate reference system of the source (``data``).
When passed as an *integer*, ``src_crs`` refers to the EPSG code.
When passed as a *string*, ``src_crs`` can be OGC WKT string or Proj.4 string.
base_template : str, Path, xarray.DataArray or geopandas.GeoDataFrame, optional, default None
Filepath, used as a template for spatial profile. Supported file formats
are *.tif*, *.nc* and vector formats supported by geopandas (*.shp*, *.json*, ...).
bounds : iterable or None, optional, default None
Boundaries of the target domain as a tuple (x_min, y_min, x_max, y_max).
x0: number, optional, default None
Origin of the X-axis, used to align the reprojection grid.
y0: number, optional, default None
Origin of the Y-axis, used to align the reprojection grid.
mask : str, Path, shapely.geometry, xarray.DataArray or geopandas.GeoDataFrame, optional, default None
Filepath of mask used to clip the data.
drop : bool, default False
Only applicable for raster/xarray.Dataset types. If True, coordinate labels
that only correspond to NaN values are dropped from the result.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_geop4th'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
export_extension : str, default '.tif'
Extension to which the data will be converted and exported. Only used
        when the specified ``data`` is a filepath. If ``data`` is a variable
and not a file, it will not be exported.
main_var_list : iterable, default None
Data variables to rasterize. Only used if ``rasterize`` is ``True``.
If ``None``, all variables in ``data`` are rasterized.
rasterize_mode : str or list of str, or dict, default ['mean', 'coverage', 'and']
Defines the mode to rasterize data:
        - for numeric variables: ``'count'``, ``'sum'`` or ``'mean'`` (default)
- ``'mean'`` refers to:
- the sum of polygon values weighted by their relative coverage on each cell,
when the vector data contains Polygons (appropriate for intensive quantities)
- the average value of points on each cell, when the vector data contains Points
- ``'sum'`` refers to:
- the sum of polygon values downscaled to each cell (appropriate for extensive quantities)
- the sum values of points on each cell, when the vector data contains Points
- ``'count'`` refers to:
- the number of points or polygons intersecting each cell
        - for categorical variables: ``'fraction'``, ``'dominant'`` or ``'coverage'`` (default)
- ``'coverage'`` refers to:
- the area covered by each level on each cell, when the vector data contains Polygons
- the count of points for each level on each cell, when the vector data contains Points
            - ``'dominant'`` retains the most frequent level for each cell
- ``'fraction'`` creates a new variable per level, which stores
the fraction (from 0 to 1) of the coverage of this level compared
to all levels, for each cell.
- for boolean variables: ``'or'`` or ``'and'`` (default)
The modes can be specified for each variable by passing ``rasterize_mode``
        as a dict: ``{'<var1>': 'mean', '<var2>': 'fraction', ...}``. This argument
specification makes it possible to force a numeric variable to be rasterized
as a categorical variable. Unspecified variables will be rasterized with the default mode.
If `data` contains no variable other than 'geometry', the arbitrary name 'data'
can be used to specify a mode for the whole `data`.
force_polygon : bool, default False
Only Polygon geometry types will be kept when rasterizing.
force_point : bool, default False,
Only Point geometry types will be kept when rasterizing.
**rio_kwargs : keyword args, optional, defaults are None
Argument passed to the ``xarray.Dataset.rio.reproject()`` function call.
        **Note**: These arguments take priority over ``base_template`` attributes.
May contain:
- ``dst_crs`` : str
- ``resolution`` : float or tuple
- ``shape`` : tuple (int, int) of (height, width)
- ``transform`` : Affine
- ``nodata`` : float or None
- ``resampling`` :
- see ``help(rasterio.enums.Resampling)``
- most common are: ``5`` (average), ``13`` (sum), ``0`` (nearest),
``9`` (min), ``8`` (max), ``1`` (bilinear), ``2`` (cubic)...
- the functionality ``'std'`` (standard deviation) is also available
- see ``help(xarray.Dataset.rio.reproject)``
Returns
-------
Transformed data : xarray.Dataset or geopandas.GeoDataFrame.
    The type of the resulting variable corresponds to the type of the input data and to
the conversion operations (such as rasterize):
- all vector data will be output as a geopandas.GeoDataFrame
- all raster data and netCDF will be output as a xarray.Dataset
"""
###############################################################################
#%%% Align on the closest value
def nearest(x = None, y = None, x0 = 700012.5, y0 = 6600037.5, res = 75):
"""
Exemple
-------
import geoconvert as gc
gc.nearest(x = 210054)
gc.nearest(y = 6761020)
Parameters
----------
x : float, optional
Valeur de la coordonnée x (ou longitude). The default is None.
y : float, optional
Valeur de la coordonnée y (ou latitude). The default is None.
Returns
-------
Par défault, cette fonction retourne la plus proche valeur (de x ou de y)
alignée sur la grille des cartes topo IGN de la BD ALTI.
Il est possible de changer les valeurs de x0, y0 et res pour aligner sur
d'autres grilles.
"""
    # ---- Alignment parameters:
    # ---------------------------
    # =============================================================================
    # # Lambert-93 documentation
    # print('\n--- Alignment according to the Lambert-93 documentation ---\n')
    # x0 = 700000 # X origin
    # y0 = 6600000 # Y origin
    # res = 75 # resolution
    # =============================================================================
    # Coordinates of the IGN BD ALTI v2 maps
    if (x0 == 700012.5 or y0 == 6600037.5) and res == 75:
        print('\n--- Alignment on the IGN BD ALTI v2 grid (default) ---')
closest = []
if x is not None and y is None:
        # print('closest x = ')
if (x0-x)%res <= res/2:
closest = x0 - (x0-x)//res*res
elif (x0-x)%res > res/2:
closest = x0 - ((x0-x)//res + 1)*res
elif y is not None and x is None:
        # print('closest y = ')
if (y0-y)%res <= res/2:
closest = y0 - (y0-y)//res*res
elif (y0-y)%res > res/2:
closest = y0 - ((y0-y)//res + 1)*res
else:
print('Err: only one of x or y parameter should be passed')
return
return closest
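# Worked example (added for illustration), using the default IGN BD ALTI v2 grid
# (x0 = 700012.5, y0 = 6600037.5, res = 75):
# =============================================================================
# nearest(x = 210054)    # 210054 lies between the nodes 210037.5 and 210112.5
#                        # -> returns 210037.5 (the closer one)
# nearest(y = 6761020)   # -> returns 6760987.5
# =============================================================================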
###############################################################################
#%%% Format x_res and y_res
def format_xy_resolution(*, resolution=None, bounds=None, shape=None):
"""
Format x_res and y_res from a resolution value/tuple/list, or from
bounds and shape.
Parameters
----------
resolution : number | iterable, optional
xy_res or (x_res, y_res). The default is None.
bounds : iterable, optional
(x_min, y_min, x_max, y_max). The default is None.
shape : iterable, optional
(height, width). The default is None.
Returns
-------
x_res and y_res
"""
if (resolution is not None) & ((bounds is not None) | (shape is not None)):
print("Err: resolution cannot be specified alongside with bounds or shape")
return
if resolution is not None:
if isinstance(resolution, (tuple, list)):
x_res = abs(resolution[0])
y_res = -abs(resolution[1])
else:
x_res = abs(resolution)
y_res = -abs(resolution)
if ((bounds is not None) & (shape is None)) | ((bounds is None) & (shape is not None)):
print("Err: both bounds and shape need to be specified")
if (bounds is not None) & (shape is not None):
(height, width) = shape
(x_min, y_min, x_max, y_max) = bounds
x_res = (x_max - x_min) / width
y_res = -(y_max - y_min) / height
return x_res, y_res
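# Worked example (added for illustration): both call styles below return the same
# (x_res, y_res) pair, with y_res negative (north-up convention).
# =============================================================================
# format_xy_resolution(resolution = 75)               # -> (75, -75)
# format_xy_resolution(bounds = (0, 0, 750, 1500),
#                      shape = (20, 10))              # -> (75.0, -75.0)
# =============================================================================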
###############################################################################
#%%% Get shape
def get_shape(x_res, y_res, bounds, x0=0, y0=0):
# bounds should be xmin, ymin, xmax, ymax
    # aligns on the origin, rounds, and so on
(x_min, y_min, x_max, y_max) = bounds
x_min2 = nearest(x = x_min, res = x_res, x0 = x0)
if x_min2 > x_min:
x_min2 = x_min2 - x_res
y_min2 = nearest(y = y_min, res = y_res, y0 = y0)
if y_min2 > y_min:
y_min2 = y_min2 - abs(y_res)
x_max2 = nearest(x = x_max, res = x_res, x0 = x0)
if x_max2 < x_max:
x_max2 = x_max2 + x_res
y_max2 = nearest(y = y_max, res = y_res, y0 = y0)
if y_max2 < y_max:
y_max2 = y_max2 + abs(y_res)
width = (x_max2 - x_min2)/x_res
height = -(y_max2 - y_min2)/y_res
if (int(width) == width) & (int(height) == height):
shape = (int(height), int(width))
else:
print(f"Warning: shape values are not integers: ({width}, {height})")
rel_err = (abs((np.rint(width) - width)/np.rint(width)),
abs((np.rint(height) - height)/np.rint(height)))
print(f". errors: ({rel_err[0]*100} %, {rel_err[1]*100} %)")
# Safeguard
if (rel_err[0] > 1e-8) | (rel_err[1] > 1e-8):
print("Error")
shape = None
else:
shape = (int(np.rint(height)), int(np.rint(width)))
return shape, x_min2, y_max2
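# Usage sketch (added for illustration): combining format_xy_resolution() and
# get_shape() to build an aligned reprojection profile. The bounds values are
# arbitrary, and building the Affine transform this way is an assumption.
# =============================================================================
# x_res, y_res = format_xy_resolution(resolution = 75)
# shape, x_min, y_max = get_shape(x_res, y_res,
#                                 bounds = (330000, 6700000, 360000, 6730000),
#                                 x0 = 700012.5, y0 = 6600037.5)
# dst_transform = Affine(x_res, 0, x_min, 0, y_res, y_max)
# =============================================================================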
#%%% Internal rasterization functions
###############################################################################
def vector_grid(*, height,
width,
x_min,
y_max,
x_res,
y_res,
crs):
x_coords = np.arange(x_min, x_min + width*x_res, x_res, dtype = np.float32) + 0.5*x_res # aligned on centroids
y_coords = np.arange(y_max, y_max + height*y_res, y_res, dtype = np.float32) + 0.5*y_res
xy_coords = {'x_coords': x_coords,
'y_coords': y_coords}
geom_list = []
# for x in np.arange(x_min, x_min + width*x_res, x_res):
for x in x_coords - 0.5*x_res: # aligned on bounds
# for y in np.arange(y_max, y_max + height*y_res, y_res):
for y in y_coords - 0.5*y_res:
geom_list.append(Polygon([(x, y),
(x + x_res, y),
(x + x_res, y + y_res),
(x, y + y_res),
(x, y)]))
grid_gdf = gpd.GeoDataFrame(
{'x_centroid': np.repeat(x_coords, height),
'y_centroid': np.tile(y_coords, width),
'geometry': geom_list},
crs = crs)
# grid_gdf['grid_index'] = grid_gdf.index
return grid_gdf, xy_coords
def rasterize_numeric(data_gdf,
grid_gdf,
var_name,
mode,
*, is_polygon = False,
is_point = False):
"""
Rasterize a numeric column of a geopandas.GeoDataFrame, according to a
vector grid.
Parameters
----------
data_gdf : geopandas.GeoDataFrame
Input frame containing the column to rasterize.
grid_gdf : geopandas.GeoDataFrame
        Vector grid in which one rectangle corresponds to one raster cell.
var_name : str
Numeric variable column to rasterize.
mode : {'mean', 'sum', 'count'}
Rasterization mode:
- 'mean':
- if ``is_polygon = True``: sum of polygon values weighted by their relative coverage on each cell
(appropriate for intensive quantities)
- if ``is_point = True``: average value of points on each cell
- 'sum':
- if ``is_polygon = True``: sum of polygon values downscaled to each cell
(appropriate for extensive quantities)
- if ``is_point = True``: summed value of points in each cell
- 'count': number of points or polygons intersecting each cell
is_polygon : bool
Whether the input frame is a polygon vector dataset or not.
is_point : bool
Whether the input frame is a point vector dataset or not.
Returns
-------
cover : xarray.DataArray
DataArray containing the results.
"""
# Inputs
if is_polygon == is_point:
print("Err: `is_polygon` and `is_point` arguments in geobricks.rasterize_numeric()"
f" cannot be both {is_polygon}")
return
# Retrieve transform parameters
x_coords = grid_gdf.x_centroid.unique()
y_coords = grid_gdf.y_centroid.unique()
coords = [y_coords, x_coords]
dims = ['y', 'x']
shape = (len(y_coords), len(x_coords))
# Initialize a null xr.DataArray
cover = xr.DataArray(np.zeros(shape),
coords = coords,
dims = dims)
# Determine the spatial join with the grid
sjoin = data_gdf[['geometry']].sjoin(grid_gdf,
how = "left",
predicate = "intersects")
sjoin = sjoin[['index_right', 'x_centroid', 'y_centroid']]
    sjoin = sjoin[~sjoin.index_right.isna()] # remove NaN rows
sjoin.reset_index(drop = False, inplace = True)
if mode == "mean":
if is_polygon:
total_area = cover.copy(deep = True) # for normalization
for i in sjoin.index: # for each intersection
intersect_area = grid_gdf.loc[sjoin.loc[i, 'index_right'], 'geometry'].intersection(data_gdf.loc[sjoin.loc[i, 'index'], 'geometry']).area
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += intersect_area * data_gdf.loc[sjoin.loc[i, 'index'], var_name]
total_area.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += intersect_area
cover = cover / total_area
elif is_point:
count = cover.copy(deep = True) # for normalization
for i in sjoin.index: # for each intersection
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += data_gdf.loc[sjoin.loc[i, 'index'], var_name]
count.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += 1
cover = cover / count
elif mode == "sum":
if is_polygon:
for i in sjoin.index: # for each intersection
intersect_area = grid_gdf.loc[sjoin.loc[i, 'index_right'], 'geometry'].intersection(data_gdf.loc[sjoin.loc[i, 'index'], 'geometry']).area
polygon_area = data_gdf.loc[sjoin.loc[i, 'index'], 'geometry'].area
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += intersect_area/polygon_area * data_gdf.loc[sjoin.loc[i, 'index'], var_name]
elif is_point:
for i in sjoin.index: # for each intersection
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += data_gdf.loc[sjoin.loc[i, 'index'], var_name]
elif mode == "count":
for i in sjoin.index: # for each intersection
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += 1
return cover
def rasterize_categorical(data_gdf,
grid_gdf,
*, is_polygon = False,
is_point = False):
"""
Rasterize a categorical column of a geopandas.GeoDataFrame, according to a
vector grid. For consistency, the geopandas.GeoDataFrame should contain
only one level of the categorical variable. This function returns the
count of points in each raster cell (if ``is_point = True``) or the
area covered by polygons in each raster cell (if ``is_polygon = True``).
Parameters
----------
data_gdf : geopandas.GeoDataFrame
Input frame containing the column to rasterize. This function will not
differentiate the different levels of the categorical variable. It is
        expected that ``data_gdf`` contains only one single level.
grid_gdf : geopandas.GeoDataFrame
        Vector grid in which one rectangle corresponds to one raster cell.
is_polygon : bool
Whether the input frame is a polygon vector dataset or not.
is_point : bool
Whether the input frame is a point vector dataset or not.
Returns
-------
cover : xarray.DataArray
DataArray containing the results.
- if ``is_polygon = True``: the sum of area intersecting each cell will be returned
        - if ``is_point = True``: the count of points in each cell will be returned
"""
# Inputs
if is_polygon == is_point:
print("Err: `is_polygon` and `is_point` arguments in geobricks.rasterize_categorical()"
f" cannot be both {is_polygon}")
return
# Retrieve transform parameters
x_coords = grid_gdf.x_centroid.unique()
y_coords = grid_gdf.y_centroid.unique()
coords = [y_coords, x_coords]
dims = ['y', 'x']
shape = (len(y_coords), len(x_coords))
# Initialize a null xr.DataArray
cover = xr.DataArray(np.zeros(shape),
coords = coords,
dims = dims)
data_gdf = data_gdf[['geometry']]
# Determine the spatial join with the grid
sjoin = data_gdf.sjoin(grid_gdf,
how = "left",
predicate = "intersects")
sjoin = sjoin[['index_right', 'x_centroid', 'y_centroid']]
    sjoin = sjoin[~sjoin.index_right.isna()] # remove NaN rows
sjoin.reset_index(drop = False, inplace = True)
if is_polygon:
for i in sjoin.index: # for each intersection
intersect_area = grid_gdf.loc[sjoin.loc[i, 'index_right'], 'geometry'].intersection(data_gdf.loc[sjoin.loc[i, 'index'], 'geometry']).area
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += intersect_area
elif is_point:
for i in sjoin.index: # for each intersection
cover.loc[{'x': sjoin.loc[i, 'x_centroid'], 'y': sjoin.loc[i, 'y_centroid']}] += 1
return cover
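# Usage sketch (added for illustration): how the internal helpers above fit
# together. `data_gdf` is a hypothetical polygon GeoDataFrame with a numeric
# column 'value' and a categorical column 'class'.
# =============================================================================
# grid_gdf, xy_coords = vector_grid(height = 2, width = 3,
#                                   x_min = 0, y_max = 150,
#                                   x_res = 75, y_res = -75,
#                                   crs = 2154)
# mean_da = rasterize_numeric(data_gdf, grid_gdf,
#                             var_name = 'value', mode = 'mean',
#                             is_polygon = True)
# # Categorical variables are rasterized one level at a time:
# forest_cover_da = rasterize_categorical(data_gdf[data_gdf['class'] == 'forest'],
#                                         grid_gdf, is_polygon = True)
# =============================================================================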
#%%% Crop and fill according to mask
###############################################################################
def cropfill(data,
mask):
"""
Crop a xarray.Dataset or DataArray to a mask, and fill the NaN parts inside
the mask with an average value.
Parameters
----------
    data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
        Data to crop and fill.
    mask : str, Path, shapely.geometry, xarray.DataArray or geopandas.GeoDataFrame
        Mask used to clip the data (same meaning as the ``mask`` argument of :func:`clip`).
    Returns
    -------
    data_ds : xarray.Dataset or xarray.DataArray
        Clipped data in which NaN cells located inside the mask are filled with
        the spatial average of each variable.
"""
# Clip data
data_ds = clip(
data,
mask = mask,
)
# Safeguard for input data type
    if not isinstance(data_ds, (xr.DataArray, xr.Dataset)):
        print(f"Err: geobricks.cropfill() only applies to xarray.Datasets or xarray.DataArrays, not to {type(data_ds)}")
        return data_ds
# Convert any mask into a clipped dataset version of it, with the value 1
# inside the mask, and np.nan elsewhere
raster_mask = data_ds.where(False, 1) # 1 everywhere
mask_ds = clip(
raster_mask,
mask = mask,
)
x_var, y_var = main_space_dims(data_ds)[0]
# Fill missing areas
if isinstance(data_ds, xr.Dataset):
# mask_var = main_vars(mask_ds)[0]
for var in main_vars(data_ds):
encod = data_ds[var].encoding.copy()
data_ds[var] = data_ds[var].where(data_ds[var].notnull() | mask_ds[var].isnull(), data_ds[var].mean(dim = [x_var, y_var]))
data_ds[var].encoding = encod
elif isinstance(data_ds, xr.DataArray):
encod = data_ds.encoding.copy()
data_ds = data_ds.where(data_ds.notnull() | mask_ds.isnull(), data_ds.mean(dim = [x_var, y_var]))
data_ds.encoding = encod
return data_ds
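# Usage sketch (added for illustration): the file paths are assumptions.
# =============================================================================
# filled_ds = cropfill(r"D:\folder\data1.nc",
#                      mask = r"D:\folder\watershed.shp")
# =============================================================================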
#%%% Merge and fill years
###############################################################################
def fill_years(data_list,
*, tag = ''):
"""
Combine several datasets covering different periods (years) into one dataset covering
the whole period. Missing years are filled by propagating forward the previous year.
Parameters
----------
data_list : str or pathlib.Path, or list of str or pathlib.Path, or list of xarray.Datasets
Datasets that will be assembled and completed.
The timestep that will be completed is the year: it is assumed that there
is no missing data within each year.
tag : str, optional
If ``data_list`` is a path to a folder, ``tag`` argument specifies which
type of files should be taken into account in this folder.
Returns
-------
    xarray.Dataset
Resulting merged and completed dataset.
"""
    # User can give a folder path as an argument, in which case all the netCDF
# files in this folder will be considered as `data_list`
if isinstance(data_list, (str, Path)):
root, filelist = get_filelist(data_list, extension = 'nc', tag = tag)
root = Path(root)
data_list = [root / f for f in filelist]
# Safeguards
if not isinstance(data_list, (list, tuple)):
print("Err: `data_list` argument should be an iterable")
return data_list
if len(data_list) == 1:
print("Warning: There is no need to use `dynamic_landcover()` on a single dataset.")
return load_any(data_list[0])
# ---- Initialization
# Load and sort the datasets
ds_dict = {}
for d in data_list:
ds = load_any(d)
t = main_time_dims(ds)[0]
first_year = np.min(np.unique(ds[t].dt.year)).item()
ds_dict[first_year] = ds
sorted_years = sorted(ds_dict)
ds_list = [ds_dict[y] for y in sorted_years]
# Initialization of the complete list of datasets
ds = ds_list[-1]
up_year = np.max(np.unique(ds[main_time_dims(ds)[0]].dt.year)).item() + 1
ds_allyears = []
# ---- Browse data list from end to beginning
for ds in ds_list[::-1]:
# Retrieve info from loaded xr.dataset: names of main var and time dimension, and year(s)
t = main_time_dims(ds)[0]
year = np.unique(ds[t].dt.year)
# For each year in loaded xr.dataset, the corresponding xr.DataArray is appended to `ds_list`
for y in year:
ds_allyears.append(ds.loc[{t: slice(str(y), str(y))}])
# The xr.DataArray corresponding to the most recent year (in the loaded
# xr.Dataset) will be taken as the reference data that will be
# propagated to the next available year in `data_list`, called `up_year`
ref_ds = ds.loc[{t: slice(str(np.max(year)), str(np.max(year)))}]
        # For every missing year between `ref_ds` and `up_year`, `ref_ds` will be propagated
for y in range(np.max(year)+1, up_year):
add_ds = ref_ds.copy(deep = True)
add_ds[t] = pd.to_datetime(add_ds[t]) + pd.DateOffset(years = y - np.max(year))
ds_allyears.append(add_ds)
up_year = np.min(year).item()
# ---- Combine datasets
# Sort
ds_dict2 = {}
for ds in ds_allyears:
t = main_time_dims(ds)[0]
first_year = np.min(np.unique(ds[t].dt.year)).item()
ds_dict2[first_year] = ds
sorted_years2 = sorted(ds_dict2)
ds_allyears = [ds_dict2[y] for y in sorted_years2]
ds = merge_data(ds_allyears)
return ds
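# Usage sketch (added for illustration): yearly land-cover rasters stored in one
# folder are combined, and missing years are filled by propagating the previous
# year forward. The folder path and tag value are assumptions.
# =============================================================================
# landcover_ds = fill_years(r"D:\folder\landcover_maps", tag = 'landcover')
# =============================================================================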
#%%% Broadcast in time
###############################################################################
def prolong(data,
*, start_year,
end_year,
sample_years_start = 1,
sample_years_end = 1,
crs = None,
to_file = False):
"""
Extend in time, to cover the duration from ``start_year`` to ``end_year``.
In order to fill the added years, the missing years preceeding the input data
are filled with an average over the first ``sample_years_start`` years, and
the missing years following the input data are filled with an average over
the last ``sample_years_end``.
    This function is particularly handy to extend in time data that do not
    fully cover the time extent of a CWatM simulation. For instance, when running
    CWatM from 2000 to 2030, if withdrawal data only exist between 2005 and 2025,
    we can assume that the withdrawals during the years 2000-2004 are somehow
    similar to the average withdrawal over the first 10 available years (2005-2015). In the same way,
    we can assume that the withdrawals during the years 2026-2030 will somehow
    be consistent with the average withdrawal over the last 5 available years (2020-2025).
We can thus quickly run the simulation by preparing the input data with:
``cwatm.prolong(data, start_year = 2000, end_year = 2030, sample_years_start = 10, sample_years_end = 5, to_file = True)``.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset or xarray.DataArray)
``data`` will be loaded into a standard *GEOP4TH* variable:
- all raster data (ASCII, GeoTIFF) and netCDF will be loaded as a xarray.Dataset
If ``data`` is already a variable, no operation will be executed.
start_year : int
Lower bound of the new time cover.
end_year : int
Upper bound of the new time cover.
sample_years_start : int, optional, default 1
Set the period length at the beginning of the input data time cover
that will be used to compute the average values used for filling data
from ``start_year`` to the beginning of the input data.
sample_years_end : int, optional, default 1
Set the period length at the end of the input data time cover
that will be used to compute the average values used for filling data
from the end of the input data to ``end_year``.
crs : int or str or rasterio.crs.CRS, optional
Coordinate reference system of the source (``data``), that will be embedded in the ``data``.
When passed as an *integer*, ``crs`` refers to the EPSG code.
When passed as a *string*, ``crs`` can be OGC WKT string or Proj.4 string.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to the same location as ``data``, while appending '_prolonged' to its name.
If ``to_file`` is a path, the resulting dataset will be exported to this specified filepath.
Returns
-------
    Time-expanded xarray.Dataset or xarray.DataArray with standard georeferencing.
If ``to_file`` argument is used, the resulting dataset will also be exported to a file.
"""
# ---- Initialization
ds = load_any(data)
time = main_time_dims(ds)[0]
    if not isinstance(ds, (xr.Dataset, xr.DataArray)):
print("Err: only space-time maps are supported so far")
return
if to_file is True:
if isinstance(data, (str, Path)):
data = Path(data)
outpath = data.parent
outname = data.stem
extension = data.suffix
else:
if not isinstance(to_file, str):
print("Warning: a outpath should be provided through `to_file`")
to_file = False
# ---- Expand before
ds_sample_start = ds.loc[
{time: slice(
ds[time][0], pd.to_datetime(ds[time])[0] + pd.DateOffset(years = sample_years_start,
days = -1)
)}
]
ds_sample_start = ds_sample_start.groupby([f'{time}.month', f'{time}.day']).mean().stack(time = ['month', 'day'])
ds_sample_start = ds_sample_start.assign_coords(
dict(time = pd.to_datetime(dict(year = ds[time][0].dt.year.item(), month = ds_sample_start.month, day = ds_sample_start.day))))
# ds_sample_start = ds_sample_start.rename({'dayofyear': time})
# ds_sample_start[time] = ds.loc[
# {time: slice(
# ds[time][0], pd.to_datetime(ds[time])[0] + pd.DateOffset(years = 1, days = -1)
# )}].time
start = []
for y in range(start_year, ds[time][0].dt.year.item()):
ds_sample_start[time] = pd.to_datetime(ds_sample_start[time]) + pd.DateOffset(years = y - ds_sample_start[time][0].dt.year.item())
# Handle leap years
ds_sample_start = ds_sample_start.drop_duplicates(dim = time)
start.append(ds_sample_start.copy())
# ---- Expand after
ds_sample_end = ds.loc[
{time: slice(
pd.to_datetime(ds[time])[-1] - pd.DateOffset(years = sample_years_end, days = -1),
ds[time][-1]
)}
]
ds_sample_end = ds_sample_end.groupby([f'{time}.month', f'{time}.day']).mean().stack(time = ['month', 'day'])
# ds_sample_end = ds_sample_end.groupby(ds_sample_end[time].dt.dayofyear).mean()
ds_sample_end = ds_sample_end.assign_coords(
dict(time = pd.to_datetime(dict(
year = ds[time][-1].dt.year.item(),
# year = (pd.to_datetime(ds[time])[-1] - pd.DateOffset(years = 1, days = -1)).year,
month = ds_sample_end.month,
day = ds_sample_end.day))))
# ds_sample_end = ds_sample_end.rename({'dayofyear': time})
ds_sample_end[time] = ds.loc[
{time: slice(
pd.to_datetime(ds[time])[-1] - pd.DateOffset(years = 1, days = -1),
ds[time][-1]
)}].time
end = []
for y in range(ds[time][-1].dt.year.item() + 1, end_year + 1):
ds_sample_end[time] = pd.to_datetime(ds_sample_end[time]) + pd.DateOffset(years = y - ds_sample_end[time][-1].dt.year.item())
# Handle leap years
ds_sample_end = ds_sample_end.drop_duplicates(dim = time)
end.append(ds_sample_end.copy())
# ---- Merge data and export
expand_ds = merge_data(start + [ds] + end)
expand_ds = georef(expand_ds, crs = crs)
if to_file:
export(expand_ds, outpath / (outname + "_prolonged" + extension))
return expand_ds
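# Usage sketch (added for illustration), mirroring the docstring example above:
# withdrawal data covering 2005-2025 are extended to a 2000-2030 simulation
# period. The file path is an assumption.
# =============================================================================
# prolonged_ds = prolong(r"D:\folder\withdrawals_2005_2025.nc",
#                        start_year = 2000, end_year = 2030,
#                        sample_years_start = 10, sample_years_end = 5,
#                        to_file = True)
# =============================================================================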
#%% COMPRESS & UNCOMPRESS netCDF
###############################################################################
def unzip(data,
to_file = False):
"""
Uncompress gzipped netCDF. Only applies to gzip compression (non-lossy compression).
Even if gzip compression is not destructive, in some GIS softwares
uncompressed netCDF are quicker to manipulate than gzipped netCDF.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
Dataset (netCDF or xarray variable) that will be unzipped.
Note that ``data`` will be loaded into a xarray.Dataset or xarray.DataArray.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_unzip'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
data_ds : xarray.Dataset
Standard *GEOP4TH* variable (xarray.Dataset) with gzip compression removed.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
"""
# Load
data_ds = load_any(data, decode_times = True, decode_coords = 'all')
if not isinstance(data_ds, (xr.Dataset, xr.DataArray)):
print("Error: the `unzip` function is only intended for netCDF datasets")
return
# Get main variable
var_list = main_vars(data_ds)
for var in var_list:
# Deactivate zlib
data_ds[var].encoding['zlib'] = False
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_unzip" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
# Return
return data_ds
###############################################################################
def gzip(data,
complevel = 3,
shuffle = False,
to_file = False):
r"""
Apply a non-lossy compression (gzip) to a netCDF dataset.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
Dataset (netCDF or xarray variable) that will be gzipped (non-lossy).
Note that ``data`` will be loaded into a xarray.Dataset or xarray.DataArray.
complevel : {1, 2, 3, 4, 5, 6, 7, 8, 9}, default 3
Compression level, (1 being fastest, but lowest compression ratio,
9 being slowest but best compression ratio).
shuffle : bool, default False
        HDF5 shuffle filter, which de-interlaces a block of data before gzip
        compression by reordering the bytes.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_gzip'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
data_ds : xarray.Dataset
Standard *GEOP4TH* variable (xarray.Dataset) with gzip compression added.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
Examples
--------
geo.gzip(myDataset, complevel = 4, shuffle = True)
geo.gzip(r"D:\folder\data1.nc", complevel = 5)
"""
# Load
data_ds = load_any(data, decode_times = True, decode_coords = 'all')
if not isinstance(data_ds, (xr.Dataset, xr.DataArray)):
print("Error: the `gzip` function is only intended for netCDF datasets")
return
# Get main variable
var_list = main_vars(data_ds)
for var in var_list:
# Activate zlib
data_ds[var].encoding['zlib'] = True
data_ds[var].encoding['complevel'] = complevel
data_ds[var].encoding['shuffle'] = shuffle
data_ds[var].encoding['contiguous'] = False
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_gzip" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
# Return
return data_ds
###############################################################################
def pack(data,
nbits = 16,
to_file = False):
"""
    Applies a lossy compression to a netCDF dataset, by packing the values into
    a data type with a smaller number of bits. Under the hood, this function
automatically defines the corresponding ``add_offset`` and ``scale_factor``.
Parameters
----------
data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
        Dataset (netCDF or xarray variable) that will be packed (lossy compression).
Note that ``data`` will be loaded into a xarray.Dataset or xarray.DataArray.
nbits : {8, 16}, default 16
Number of bits for the data type of the output values.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_pack'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
data_ds : xarray.Dataset
Standard *GEOP4TH* variable (xarray.Dataset) with lossy compression.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
"""
if (nbits != 16) & (nbits != 8):
print("Err: nbits should be 8 or 16")
return
# Load
data_ds = load_any(data, decode_times = True, decode_coords = 'all')
if not isinstance(data_ds, (xr.Dataset, xr.DataArray)):
print("Error: the `pack` function is only intended for netCDF datasets")
return
# Get main variable
var_list = main_vars(data_ds)
for var in var_list:
# Compress
bound_min = data_ds[var].min().item()
bound_max = data_ds[var].max().item()
# Add an increased max bound, that will be used for _FillValue
bound_max = bound_max + (bound_max - bound_min + 1)/(2**nbits)
scale_factor, add_offset = compute_scale_and_offset(
bound_min, bound_max, nbits)
data_ds[var].encoding['scale_factor'] = scale_factor
data_ds[var].encoding['dtype'] = f'uint{nbits}'
data_ds[var].encoding['_FillValue'] = (2**nbits)-1
data_ds[var].encoding['add_offset'] = add_offset
print(" Compression (lossy)")
# Prevent _FillValue issues
if ('missing_value' in data_ds[var].encoding) & ('_FillValue' in data_ds[var].encoding):
data_ds[var].encoding.pop('missing_value')
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_pack" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
# Return
return data_ds
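# Usage sketch (added for illustration): lossy 16-bit packing, optionally followed
# by the non-lossy gzip() defined above. File paths are assumptions.
# =============================================================================
# packed_ds = pack(r"D:\folder\data1.nc", nbits = 16)
# packed_ds = gzip(packed_ds, complevel = 5,
#                  to_file = r"D:\folder\data1_packed_gzip.nc")
# =============================================================================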
#%%% Packing netcdf (previously packnetcdf.py)
"""
Created on Wed Aug 24 16:48:29 2022
@author: script based on James Hiebert's work (2015):
http://james.hiebert.name/blog/work/2015/04/18/NetCDF-Scale-Factors.html
Reminder of dtypes:
uint8 (unsigned int.) 0 to 255
uint16 (unsigned int.) 0 to 65535
uint32 (unsigned int.) 0 to 4294967295
uint64 (unsigned int.) 0 to 18446744073709551615
int8 (Bytes) -128 to 127
int16 (short integer) -32768 to 32767
int32 (integer) -2147483648 to 2147483647
int64 (integer) -9223372036854775808 to 9223372036854775807
float16 (half precision float) 10 bits mantissa, 5 bits exponent (~ 4 significant digits?)
float32 (single precision float) 23 bits mantissa, 8 bits exponent (~ 8 significant digits?)
float64 (double precision float) 52 bits mantissa, 11 bits exponent (~ 16 significant digits?)
"""
###############################################################################
def compute_scale_and_offset(min, max, n):
"""
    Computes the scale and offset necessary to pack a float32 (or float64?) set
    of values into an int16 or int8 set of values.
Parameters
----------
min : float
Minimum value from the data
max : float
Maximum value from the data
n : int
Number of bits into which you wish to pack (8 or 16)
Returns
-------
scale_factor : float
Parameter for netCDF's encoding
add_offset : float
Parameter for netCDF's encoding
"""
# stretch/compress data to the available packed range
add_offset = min
scale_factor = (max - min) / ((2 ** n) - 1)
return (scale_factor, add_offset)
###############################################################################
def pack_value(unpacked_value, scale_factor, add_offset):
"""
Compute the packed value from the original value, a scale factor and an
offset.
Parameters
----------
unpacked_value : numeric
Original value.
scale_factor : numeric
Scale factor, multiplied to the original value.
add_offset : numeric
Offset added to the original value.
Returns
-------
numeric
Packed value.
"""
# print(f'math.floor: {math.floor((unpacked_value - add_offset) / scale_factor)}')
return int((unpacked_value - add_offset) / scale_factor)
###############################################################################
def unpack_value(packed_value, scale_factor, add_offset):
"""
Retrieve the original value from a packed value, a scale factor and an
offset.
Parameters
----------
packed_value : numeric
Value to unpack.
scale_factor : numeric
Scale factor that was multiplied to the original value to retrieve.
add_offset : numeric
Offset that was added to the original value to retrieve.
Returns
-------
numeric
Original unpacked value.
"""
return packed_value * scale_factor + add_offset
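# Worked example (added for illustration): packing values spanning [0, 150] into
# 16 bits. The scale factor is (150 - 0) / (2**16 - 1) ≈ 0.00229, consistent with
# the ~1/1000-of-a-unit precision loss mentioned in hourly_to_daily() below.
# =============================================================================
# scale_factor, add_offset = compute_scale_and_offset(0.0, 150.0, 16)
# packed = pack_value(42.7, scale_factor, add_offset)     # -> 18655
# unpack_value(packed, scale_factor, add_offset)          # -> ~42.699
# =============================================================================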
#%% OPERATIONS ON UNITS
###############################################################################
def hourly_to_daily(data,
*, mode = 'sum',
to_file = False):
"""
    Converts an hourly dataset to daily values. Implemented only for netCDF so far.
Parameters
----------
    data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
        Hourly dataset (netCDF or xarray variable) to resample to daily values.
    mode : (dict of) {'mean', 'max', 'min', 'sum'}, default 'sum'
        Aggregation mode used for the resampling. A dict ``{'<var>': '<mode>', ...}``
        can be passed to set a mode per variable (unspecified variables default to 'sum').
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_daily'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
    datarsmpl : xarray.Dataset
        Dataset resampled to daily values.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
"""
# ---- Process data
#% Load data:
data_ds = load_any(data, decode_coords = 'all', decode_times = True)
var_list = main_vars(data_ds)
mode_dict = {}
if isinstance(mode, str):
for var in var_list:
mode_dict[var] = mode
elif isinstance(mode, dict):
mode_dict = mode
if len(var_list) > len(mode_dict):
diff = set(var_list).difference(mode_dict)
print(f" _ Warning: {len(diff)} variables were not specified in 'mode': {', '.join(diff)}. They will be assigned the mode 'sum'.")
for d in diff:
mode_dict[d] = 'sum'
time_coord = main_time_dims(data_ds)[0]
datarsmpl = xr.Dataset()
#% Resample:
print(" _ Resampling time...")
for var in var_list:
if mode_dict[var] == 'mean':
datarsmpl[var] = data_ds[var].resample({time_coord: '1D'}).mean(dim = time_coord,
keep_attrs = True)
elif mode_dict[var] == 'max':
datarsmpl[var] = data_ds[var].resample({time_coord: '1D'}).max(dim = time_coord,
keep_attrs = True)
elif mode_dict[var] == 'min':
datarsmpl[var] = data_ds[var].resample({time_coord: '1D'}).min(dim = time_coord,
keep_attrs = True)
elif mode_dict[var] == 'sum':
datarsmpl[var] = data_ds[var].resample({time_coord: '1D'}).sum(dim = time_coord,
skipna = False,
keep_attrs = True)
# ---- Preparing export
# Transfer encodings
datarsmpl[var].encoding = data_ds[var].encoding
# Case of packing
if ('scale_factor' in datarsmpl[var].encoding) | ('add_offset' in datarsmpl[var].encoding):
# Packing (lossy compression) induces a loss of precision of
            # approx. 1/1000 of a unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data.
            if mode_dict[var] == 'sum':
print(" Correcting packing encodings...")
datarsmpl[var].encoding['scale_factor'] = datarsmpl[var].encoding['scale_factor']*24
datarsmpl[var].encoding['add_offset'] = datarsmpl[var].encoding['add_offset']*24
    # Transfer coord encodings
for c in list(datarsmpl.coords):
datarsmpl[c].encoding = data_ds[c].encoding
datarsmpl[c].attrs = data_ds[c].attrs
datarsmpl[time_coord].encoding['units'] = datarsmpl[time_coord].encoding['units'].replace('hours', 'days')
# ---- Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_daily" + os.path.splitext(data)[-1]
export(datarsmpl, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(datarsmpl, to_file)
return datarsmpl
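# Usage sketch (added for illustration): the file path and the variable names
# ('tp', 't2m', typical of ERA5-Land) are assumptions.
# =============================================================================
# daily_ds = hourly_to_daily(r"D:\folder\era5land_hourly.nc",
#                            mode = {'tp': 'sum', 't2m': 'mean'},
#                            to_file = True)
# =============================================================================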
###############################################################################
def to_instant(data,
derivative = False,
to_file = False):
"""
    Convert cumulative (accumulated) values into per-time-step (instantaneous) values.
    Parameters
----------
    data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
        Dataset of cumulative values to convert into per-time-step values.
    derivative : bool, optional
        If True, the time differences are divided by the time-step length
        (i.e. a rate is returned). The default is False.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_instant'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
    inst_ds : xarray.Dataset
        Dataset of per-time-step values (differences along the time dimension).
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
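    Examples
    --------
    Illustrative sketch (the file path is hypothetical):
        # Convert accumulated values (e.g. cumulative precipitation) into
        # per-time-step values
        inst_ds = to_instant(r"D:/era5land_tp_accumulated.nc", to_file = True)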
"""
data_ds = load_any(data, decode_coords = 'all', decode_times = True)
time_coord = main_time_dims(data_ds)[0]
if derivative:
inst_ds = data_ds.diff(dim = time_coord)/data_ds[time_coord].diff(dim = time_coord)
else:
inst_ds = data_ds.diff(dim = time_coord)
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_instant" + os.path.splitext(data)[-1]
export(inst_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(inst_ds, to_file)
return inst_ds
###############################################################################
def convert_unit(data,
operation,
*, var = None,
to_file = False):
"""
Parameters
----------
    data : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray, geopandas.GeoDataFrame or pandas.DataFrame)
        Data whose main variable should be converted to another unit.
    operation : str
        Arithmetic operation to apply to the values, written as an operator
        followed by a factor, for instance ``'* 1000'``, ``'/ 24'``,
        ``'+ 273.15'`` or ``'- 273.15'``.
    var : str, optional, default None
        Name of the variable to convert, required when ``data`` contains
        several variables.
to_file : bool or path (str or pathlib.Path), default False
If True and if ``data`` is a path (str or pathlib.Path), the resulting
dataset will be exported to a file with the same pathname and the
suffix '_units'. If ``to_file`` is a path, the resulting dataset
will be exported to this specified filepath.
Returns
-------
    data_ds : same type as ``data`` once loaded
        Data with converted values and updated unit metadata.
If ``to_file`` argument is used, the resulting dataset can also be exported to a file.
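    Examples
    --------
    Illustrative sketch (file paths and variable names are hypothetical):
        # Convert precipitation values from mm to m (the unit attribute and
        # the packing encodings are updated when possible)
        data_m = convert_unit(r"D:/era5land_tp_daily.nc", '/ 1000')
        # Convert temperature values from K to degrees Celsius
        data_degc = convert_unit(r"D:/era5land_daily.nc", '- 273.15', var = 't2m')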
"""
metric_prefixes = ['p', None, None, 'n', None, None, 'µ', None, None,
'm', 'c', 'd', '', 'da', 'h', 'k', None, None, 'M',
None, None, 'G']
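    # Index i in this list stands for a factor of 10**(i - 12): 'p' (1e-12) at
    # index 0, '' (1e0, no prefix) at index 12, 'G' (1e9) at index 21.
    # None marks powers of ten without a single-letter SI prefix.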
# ---- Load data and operands
data_ds = load_any(data)
if not isinstance(operation, str):
print("Err: 'operation' should be a str.")
return
else:
operation = operation.replace(' ', '').replace('x', '*').replace('×', '*').replace('÷', '/')
operand = operation[0]
factor = float(operation[1:])
if isinstance(data_ds, (pd.DataFrame, gpd.GeoDataFrame)):
if var is None:
mvar = main_vars(data_ds)
else:
mvar = var
# ---- Operation
if operand == '*':
data_ds[mvar] = data_ds[mvar] * factor
elif operand == '/':
data_ds[mvar] = data_ds[mvar] / factor
elif operand == '+':
data_ds[mvar] = data_ds[mvar] + factor
elif operand == '-':
data_ds[mvar] = data_ds[mvar] - factor
return data_ds
elif isinstance(data_ds, xr.Dataset):
mvar = main_vars(data_ds)
if len(mvar) == 1:
data_da = data_ds[mvar[0]]
else: # mvar is a list
if var is not None:
data_da = data_ds[var]
else:
print("Err: convert_unit can only be used on xarray.DataArrays or xarray.Datasets with one variable. Consider passing the argument 'var'.")
return
elif isinstance(data_ds, xr.DataArray):
data_da = data_ds
# ---- Preparing export
attrs = data_da.attrs
encod = data_da.encoding
# ---- Operation
# exec(f"data_da = data_da {operation}") # vulnerability issues
if operand == '*':
data_da = data_da * factor
elif operand == '/':
data_da = data_da / factor
elif operand == '+':
data_da = data_da + factor
elif operand == '-':
data_da = data_da - factor
# ---- Transfert metadata
data_da.encoding = encod
data_da.attrs = attrs # normally unnecessary
for unit_id in ['unit', 'units']:
if unit_id in data_da.attrs:
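            # Update the unit attribute: e.g. applying '/ 1000' to values expressed
            # in 'mm' relabels them as 'm' (smaller numbers, larger unit), provided
            # the target SI prefix exists; otherwise the factor is simply appended
            # to the unit string.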
if operand in ['*', '/']:
significand, exponent = f"{factor:e}".split('e')
significand = float(significand)
exponent = int(exponent)
# if factor_generic%10 == 0:
if significand == 1:
current_prefix = data_da.attrs[unit_id][0]
current_idx = metric_prefixes.index(current_prefix)
# new_idx = current_idx + int(np.log10(factor_generic))
if operand == "*":
new_idx = current_idx - exponent
new_unit = data_da.attrs[unit_id] + f" {operand}{significand}e{exponent}" # By default
elif operand == "/":
new_idx = current_idx + exponent
new_unit = data_da.attrs[unit_id] + f" *{significand}e{-exponent}" # By default
                    if (new_idx >= 0) & (new_idx < len(metric_prefixes)):
if metric_prefixes[new_idx] is not None:
new_unit = metric_prefixes[new_idx] + data_da.attrs[unit_id][1:]
data_da.attrs[unit_id] = new_unit
else:
new_unit = data_da.attrs[unit_id] + f" {operand}{significand}e{exponent}" # By default
data_da.attrs[unit_id] = new_unit
# Case of packing
if ('scale_factor' in data_da.encoding) | ('add_offset' in data_da.encoding):
# Packing (lossy compression) induces a loss of precision of
# apprx. 1/1000 of unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data.
if operand == '+':
data_da.encoding['add_offset'] = data_da.encoding['add_offset'] + factor
elif operand == '-':
data_da.encoding['add_offset'] = data_da.encoding['add_offset'] - factor
elif operand == '*':
data_da.encoding['add_offset'] = data_da.encoding['add_offset'] * factor
data_da.encoding['scale_factor'] = data_da.encoding['scale_factor'] * factor
elif operand == '/':
data_da.encoding['add_offset'] = data_da.encoding['add_offset'] / factor
data_da.encoding['scale_factor'] = data_da.encoding['scale_factor'] / factor
if isinstance(data_ds, xr.Dataset):
if len(mvar) == 1:
data_ds[mvar[0]] = data_da
else: # mvar is a list
if var is not None:
data_ds[var] = data_da
else:
print("Err: convert_unit can only be used on xarray.DataArrays or xarray.Datasets with one variable. Consider passing the argument 'var'.")
return
elif isinstance(data_ds, xr.DataArray):
data_ds = data_da
# Export
if to_file == True:
if isinstance(data, (str, Path)):
print('\nExporting...')
export_filepath = os.path.splitext(data)[0] + "_units" + os.path.splitext(data)[-1]
export(data_ds, export_filepath)
else:
print("Warning; `data` should be a path (str or pathlib.Path) for using `to_file=True`.")
elif isinstance(to_file, (str, Path)):
print('\nExporting...')
export(data_ds, to_file)
return data_ds
#%% OPERATIONS ON VALUES
###############################################################################
def correct_bias(
data: xr.Dataset,
variables: Union[str, List[str], Dict[str, float]],
bias_factors: Optional[Union[float, Dict[str, float]]] = None,
region: Optional[str] = None,
progressive: bool = False,
progressive_factors: Optional[Dict[str, np.ndarray]] = None
) -> xr.Dataset:
"""
Apply bias correction to climate variables using either constant
or monthly progressive factors.
The correction is multiplicative:
corrected = original * factor
Parameters
----------
data : xr.Dataset
Input dataset containing the variables to be corrected.
variables : Union[str, List[str], Dict[str, float]]
Variables to correct. Can be string, list of strings, or dict mapping variable names to factors.
bias_factors : Optional[Union[float, Dict[str, float]]], default None
Base multiplicative factors. Float for uniform factor, dict for per-variable factors.
region : Optional[str], default None
Region key for predefined factors from configuration.
progressive : bool, default False
If True, apply monthly progressive factors instead of base factors.
progressive_factors : Optional[Dict[str, np.ndarray]], default None
Dict mapping variable names to 12-element monthly multiplier arrays.
Returns
-------
xr.Dataset
Dataset with bias correction applied to specified variables.
Raises
------
ValueError
If variables parameter has invalid type.
Notes
-----
Multiplicative correction only. Progressive mode ignores base factors.
Variables not found in dataset are skipped with warning.
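    Examples
    --------
    Illustrative sketch (variable names and factors are hypothetical):
        # Constant multiplicative correction per variable
        corrected = correct_bias(ds, variables = {'precipitation': 1.05, 'temperature': 0.98})
        # Monthly progressive correction with custom 12-element factors
        corrected = correct_bias(ds, variables = ['precipitation'], progressive = True,
                                 progressive_factors = {'precipitation': np.full(12, 1.1)})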
"""
# Import regional factors from standardization workflow only if needed
if (bias_factors is None or progressive) and region is not None:
try:
from geop4th.workflows.standardize.standardize_wl import (
BIAS_CORRECTION_FACTORS,
PROGRESSIVE_CORRECTION_FACTORS,
)
BIAS_FACTORS_BY_REGION = {k: v.get('factors', {}) for k, v in BIAS_CORRECTION_FACTORS.items()}
PROGRESSIVE_FACTORS_BY_REGION = {k: v.get('factors', {}) for k, v in PROGRESSIVE_CORRECTION_FACTORS.items()}
except Exception:
logger.warning("Regional bias factor tables not available")
BIAS_FACTORS_BY_REGION = {}
PROGRESSIVE_FACTORS_BY_REGION = {}
else:
BIAS_FACTORS_BY_REGION = {}
PROGRESSIVE_FACTORS_BY_REGION = {}
# Work on a copy to avoid modifying original data
data_corrected = data.copy()
# Normalize variables input into a dict mapping variable names to base factors
if isinstance(variables, str):
var_factors: Dict[str, Optional[float]] = {variables: None if progressive else bias_factors}
elif isinstance(variables, list):
if progressive:
# Progressive mode ignores base factors
var_factors = {v: None for v in variables}
else:
if bias_factors is None and region is None:
raise ValueError("Provide `bias_factors` or `region` when variables is a list in base mode.")
if isinstance(bias_factors, dict):
var_factors = {v: bias_factors.get(v) for v in variables}
elif isinstance(bias_factors, (float, int)):
var_factors = {v: float(bias_factors) for v in variables}
else:
regional = BIAS_FACTORS_BY_REGION.get(region, {})
var_factors = {v: regional.get(v) for v in variables}
elif isinstance(variables, dict):
var_factors = dict(variables)
else:
raise ValueError("`variables` must be str, List[str], or Dict[str, float].")
for var_name, base_factor in var_factors.items():
if var_name not in data_corrected.data_vars:
logger.warning("Variable '%s' not found; skipping bias correction.", var_name)
continue
if progressive:
# Progressive mode: apply monthly factors
monthly_factors = None
# Try custom factors first
if progressive_factors and var_name in progressive_factors:
monthly_factors = progressive_factors[var_name]
# Fall back to regional factors
if monthly_factors is None and region in PROGRESSIVE_FACTORS_BY_REGION:
reg = PROGRESSIVE_FACTORS_BY_REGION[region]
if var_name in reg:
item = reg[var_name]
monthly_factors = item.get('values') if isinstance(item, dict) and 'values' in item else item
if monthly_factors is None:
logger.warning(
"Progressive correction requested for '%s' but no monthly factors were found for region '%s'. "
"No correction applied to this variable.", var_name, region
)
continue
# Apply monthly correction using groupby
monthly_correction = xr.DataArray(
np.asarray(monthly_factors, dtype=float),
dims=("month",),
coords={"month": list(range(1, 13))},
name=f"{var_name}_monthly_correction",
)
data_corrected[var_name] = (
data_corrected[var_name].groupby("time.month") * monthly_correction
)
# Clean up temporary month coordinate
if "month" in data_corrected.coords and "month" not in data.coords:
try:
data_corrected = data_corrected.drop_vars("month")
except Exception:
pass
continue # Skip base factor processing
# Base mode: apply constant factor
bf = base_factor
if bf is None and region is not None:
bf = BIAS_FACTORS_BY_REGION.get(region, {}).get(var_name)
if bf is None:
logger.warning(
"No base factor found for '%s' (region=%s). Skipping base correction for this variable.",
var_name, region,
)
continue
try:
bf = float(bf)
except Exception:
logger.warning("Invalid base factor for '%s' (%r). Skipping.", var_name, bf)
continue
# Apply multiplicative correction
data_corrected[var_name] = data_corrected[var_name] * bf
return data_corrected
###############################################################################
#%%% * hourly_to_daily (OLD)
def hourly_to_daily_old(*, data, mode = 'mean', **kwargs):
    # This earlier version (kept up to date) also handles folders
"""
Example
-------
import geoconvert as gc
# full example:
gc.hourly_to_daily(input_file = r"D:/2011-2021_hourly Temperature.nc",
mode = 'max',
output_path = r"D:/2011-2021_daily Temperature Max.nc",
fields = ['t2m', 'tp'])
# input_file can also be a folder:
gc.hourly_to_daily(input_file = r"D:\2- Postdoc\2- Travaux\1- Veille\4- Donnees\8- Meteo\ERA5\Brittany\test",
mode = 'mean')
Parameters
----------
input_file : str, or list of str
Can be a path to a file (or a list of paths), or a path to a folder,
        in which case all the files in this folder will be processed.
mode : str, or list of str, optional
= 'mean' (default) | 'max' | 'min' | 'sum'
**kwargs
--------
fields : str or list of str, optional
e.g: ['t2m', 'tp', 'u10', 'v10', ...]
(if not specified, all fields are considered)
output_path : str, optional
e.g: [r"D:/2011-2021_daily Temperature Max.nc"]
(if not specified, output_name is made up according to arguments)
Returns
-------
None. Processed files are created in the output destination folder.
"""
# ---- Get input file path(s)
data_folder, filelist = get_filelist(data, extension = '.nc')
#% Safeguard for output_names:
if len(filelist) > 1 and 'output_path' in kwargs:
        print('Warning: Due to multiple input files, the names of the output files are generated automatically.')
# ---- Format modes
if isinstance(mode, str): mode = [mode]
else: mode = list(mode)
if len(mode) != len(filelist):
if (len(mode) == 1) & (len(filelist)>1):
mode = mode*len(filelist)
else:
print("Error: lengths of input file and mode lists do not match")
return
# ---- Process data
for i, f in enumerate(filelist):
print(f"\n\nProcessing file {i+1}/{len(filelist)}: {f}...")
print("-------------------")
#% Load data:
data_ds = load_any(os.path.join(data_folder, f),
decode_coords = 'all', decode_times = True)
#% Get fields:
if 'fields' in kwargs:
fields = kwargs['fields']
if isinstance(fields, str): fields = [fields]
else: fields = list(fields) # in case fields are string or tuple
else:
fields = list(data_ds.data_vars) # if not input_arg, fields = all
#% Extract subset according to fields
fields_intersect = list(set(fields) & set(data_ds.data_vars))
data_subset = data_ds[fields_intersect]
print(" _ Extracted fields are {}".format(fields_intersect))
if fields_intersect != fields:
print('Warning: ' + ', '.join(set(fields) ^ set(fields_intersect))
+ ' absent from ' + data)
#% Resample:
print(" _ Resampling time...")
if mode[i] == 'mean':
datarsmpl = data_subset.resample(time = '1D').mean(dim = 'time',
keep_attrs = True)
elif mode[i] == 'max':
datarsmpl = data_subset.resample(time = '1D').max(dim = 'time',
keep_attrs = True)
elif mode[i] == 'min':
datarsmpl = data_subset.resample(time = '1D').min(dim = 'time',
keep_attrs = True)
elif mode[i] == 'sum':
datarsmpl = data_subset.resample(time = '1D').sum(dim = 'time',
skipna = False,
keep_attrs = True)
#% Build output name(s):
if len(filelist) > 1 or not 'output_path' in kwargs:
basename = os.path.splitext(f)[0]
output_name = os.path.join(
data_folder, 'daily',
basename + ' daily_' + mode[i] + '.nc')
## Regex solution, instead of splitext:
# _motif = re.compile('.+[^\w]')
# _basename = _motif.search(data).group()[0:-1]
if not os.path.isdir(os.path.join(data_folder, 'daily')):
os.mkdir(os.path.join(data_folder, 'daily'))
else:
output_name = kwargs['output_path']
# ---- Preparing export
# Transfer encodings
for c in list(datarsmpl.coords):
datarsmpl[c].encoding = data_ds[c].encoding
for f in fields_intersect:
datarsmpl[f].encoding = data_ds[f].encoding
# Case of packing
if ('scale_factor' in datarsmpl[f].encoding) | ('add_offset' in datarsmpl[f].encoding):
# Packing (lossy compression) induces a loss of precision of
# apprx. 1/1000 of unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data.
if mode[i] == 'sum':
print(" Correcting packing encodings...")
datarsmpl[f].encoding['scale_factor'] = datarsmpl[f].encoding['scale_factor']*24
datarsmpl[f].encoding['add_offset'] = datarsmpl[f].encoding['add_offset']*24
#% Export
export(datarsmpl, output_name)
def dummy_input(base, value):
"""
Creates a dummy space-time map with the same properties as the base, but with
a dummy value.
Parameters
----------
    base : path (str or pathlib.Path), or variable (xarray.Dataset or xarray.DataArray)
        Template dataset whose grid, coordinates and variables are reused.
    value : int or float
        Constant value assigned to every cell of every variable.
    Returns
    -------
    xarray.Dataset
        Dataset with the same structure as ``base``, filled with ``value``.
"""
data_ds = load_any(base)
var_list = main_vars(data_ds)
for var in var_list:
data_ds[var] = data_ds[var]*0 + value
return data_ds
#%% EXTRACTIONS
###############################################################################
[docs]
def timeseries(data,
*, coords = 'all',
coords_crs = None,
data_crs = None,
mode = 'mean',
start_date = None,
end_date = None,
var_list = None,
cumul = False):
"""
    This function extracts a time series from spatio-temporal data, either over
    the whole domain, at a single location, or over a mask, and aggregates the
    selected cells at each time step.
Parameters
----------
data : path (str, Path) or variable (xarray.Dataset or xarray.DataArray)
timeseries is only intended to handle raster data (ASCII and GeoTIFF) and netCDF.
``data`` will be loaded as a xarray.Dataset.
coords : 'all', str or path, geopandas.GeoDataFrame, shapely.geometry or tuple of (float, float), default 'all'
The keyword, coordinates or mask that will be used to extract the timeseries.
If 'all', all the pixels in data are considered. Mask can be raster
or vector data. If a tuple of coordinates is passed, coordinates should
be ordered as (x, y) or (lon, lat).
coords_crs : any CRS accepted by ``pyproj.CRS.from_user_input``, optional
CRS of ``coords``, in case it is not already embedded in it.
Accepted CRS can be for example:
- EPSG integer codes (such as 4326)
- authority strings (such as “epsg:4326”)
- CRS WKT strings
- pyproj.CRS
- ...
data_crs : any CRS accepted by ``pyproj.CRS.from_user_input``, optional
CRS of ``data``, in case it is not already embedded in it.
Accepted CRS can be for example:
- EPSG integer codes (such as 4326)
- authority strings (such as “epsg:4326”)
- CRS WKT strings
- pyproj.CRS
- ...
mode : {'mean', 'sum', 'max', 'min'}, default 'mean'
How selected data will be aggregated.
start_date : str or datetime, optional
Start of the selected time period to extract.
end_date : str or datetime, optional
End of the selected time period to extract.
var_list : (list of) str, optional
Fields (variables) to extract.
cumul : bool, default False
If True, values will be retrieved as cumulated sums.
Returns
-------
pandas.DataFrame
Frame containing the timeseries.
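    Examples
    --------
    Illustrative sketch (file paths and coordinates are hypothetical):
        # Spatial average over a catchment mask
        ts = timeseries(r"D:/era5land_tp_daily.nc",
                        coords = r"D:/catchment.shp",
                        mode = 'mean')
        # Time series at a single point (lon, lat), with the point CRS made explicit
        ts_point = timeseries(r"D:/era5land_tp_daily.nc",
                              coords = (-1.68, 48.11),
                              coords_crs = 4326)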
"""
# ---- Load
# Handle data
data_ds = load_any(data, decode_times = True, decode_coords = 'all')
if not isinstance(data_ds, (xr.Dataset, xr.DataArray, gpd.GeoDataFrame)):
# raise Exception('timeseries function is intended to handle geospatial raster or netCDF data')
print('Error: timeseries function is intended to handle geospatial vector, raster or netCDF data')
return
#% Get and process other arguments
if var_list is None:
var_list = main_vars(data_ds)
elif isinstance(var_list, str):
var_list = [var_list]
else:
var_list = list(var_list)
# =============================================================================
# fields = list(data_ds.data_vars) # if not input_arg, fields = all
# =============================================================================
print('Variables = {}'.format(str(var_list)))
if data_crs is not None:
data_ds = georef(data_ds, crs = data_crs, var_list = var_list)
else:
data_ds = georef(data_ds, var_list = var_list)
time_dim = main_time_dims(data_ds)[0]
space_dims = main_space_dims(data_ds)[0]
if isinstance(data_ds, gpd.GeoDataFrame):
data_ds = data_ds.set_index(time_dim)
# Handle coords
if isinstance(coords, list): coords = tuple(coords) # in case coords are a list instead of a tuple
if isinstance(coords, (str, Path)):
if coords != 'all':
coords = load_any(coords)
if isinstance(coords, gpd.GeoDataFrame):
if coords_crs is not None:
coords = georef(coords, crs = coords_crs)
else:
coords = georef(coords)
if start_date is not None:
start_date = pd.to_datetime(start_date)
# safeguard
if not isinstance(start_date, (datetime.datetime, pd.Timestamp)):
print("Error: `start_date` is not recognized as a valid date")
return
else:
if isinstance(data_ds, xr.Dataset):
start_date = data_ds[time_dim][0].values
elif isinstance(data_ds, gpd.GeoDataFrame):
start_date = data_ds.index[0]
if end_date is not None:
end_date = pd.to_datetime(end_date)
# safeguard
if not isinstance(end_date, (datetime.datetime, pd.Timestamp)):
print("Error: `end_date` is not recognized as a valid date")
return
else:
if isinstance(data_ds, xr.Dataset):
end_date = data_ds[time_dim][-1].values
elif isinstance(data_ds, gpd.GeoDataFrame):
end_date = data_ds.index[-1]
# =============================================================================
# #% Convert temperature:
# for _field in fields:
# if 'units' in data_ds[_field].attrs:
# if data_ds[_field].units == 'K':
# data_ds[_field] = data_ds[_field]-273.15
# # _datasubset[_field].units = '°C'
# =============================================================================
# ---- Extraction
if isinstance(coords, str):
if coords == 'all':
print("All cells are considered")
sel_ds = data_ds.copy()
# Reprojections when needed
elif isinstance(coords, (gpd.GeoDataFrame, gpd.GeoSeries)):
nrows = coords.shape[0]
if nrows > 1:
if isinstance(coords.geometry[0], Point):
# Safeguard for Frames containing more than one Point:
print("Error: `coords` contains several Points, instead of a single one")
return
            elif isinstance(coords.geometry[0], (Polygon, MultiPolygon)):
print("Warning: `coords` contains several Polygons. The union of all will be considered")
# Note that there is no need for safeguard here, as masking with reproject() will handle several Polygons
# Handling CRS:
## . Set a CRS to `coords` if needed
if coords.crs is None:
if coords_crs is None:
print("Warning: No CRS is associated with `coords`. It is assumed that coords crs is the same as `data`'s. Note that CRS can be passed with `coords_crs` arguments")
coords.set_crs(epsg = data_ds.rio.crs, inplace = True, allow_override = True)
else:
coords.set_crs(epsg = coords_crs, inplace = True, allow_override = True)
## . Reproject (even when not necessary)
if isinstance(data_ds, xr.Dataset):
dst_crs = data_ds.rio.crs
elif isinstance(data_ds, gpd.GeoDataFrame):
dst_crs = data_ds.crs
coords = reproject(coords, dst_crs = dst_crs) # more GEOP4TH-style
# coords.to_crs(epsg = data_ds.rio.crs, inplace = True) # faster
# Case of Point: convert into a tuple
if isinstance(coords.geometry[0], Point):
            coords = (coords.geometry[0].x, coords.geometry[0].y)
elif isinstance(coords, (tuple, Point, Polygon, MultiPolygon)):
if coords_crs is None:
# if data_ds.rio.crs.is_valid: # deprecated in rasterio 2.0.0
if data_ds.rio.crs.is_epsg_code:
print("Warning: No `coords_crs` has been specified. It is assumed that coords crs is the same as `data`'s")
coords_crs = data_ds.rio.crs
else:
print("Warning: No valid CRS is defined for `data` nor `coords`. Note that CRS can be passed with `data_crs` and `coords_crs` arguments")
else:
if isinstance(coords, tuple):
coords = Point(coords)
# Newest method (convert to Frame)
coords = gpd.GeoDataFrame([0], geometry = [coords], crs = coords_crs)
coords = reproject(coords, dst_crs = data_ds.rio.crs)
coords = coords.geometry[0]
# Previous conversion method (rasterio.warp.transform)
# =============================================================================
# coords = rasterio.warp.transform(rasterio.crs.CRS.from_epsg(coords_crs),
# rasterio.crs.CRS.from_epsg(data_crs),
# [coords[0]], [coords[1]])
# coords = (coords[0][0], coords[1][0])
# # (to convert a tuple of arrays into a tuple of float)
# =============================================================================
# Case of Point: convert (back) into a tuple
if isinstance(coords, Point):
coords = (coords.x, coords.y)
# Extract
if isinstance(coords, (Polygon, MultiPolygon, gpd.GeoDataFrame, xr.Dataset)):
sel_ds = reproject(data_ds, mask = coords)
elif isinstance(coords, tuple): # Note that Points have been converted into tuples beforehand
if len(set(space_dims).intersection({'lat', 'lon', 'latitude', 'longitude'})) > 0:
print("Caution: When using a tuple as `coords` with a geographic coordinate system, the order should be (longitude, latitude)")
if isinstance(data_ds, xr.Dataset):
sel_ds = data_ds.sel({space_dims[0]: [coords[0]],
space_dims[1]: [coords[1]]},
method = 'nearest')
elif isinstance(data_ds, gpd.GeoDataFrame):
print("Error: Not implemented yet!")
# ---- Post-processing operations
if isinstance(data_ds, xr.Dataset):
# Select specified fields
sel_ds = sel_ds[var_list]
# Aggregation
if mode == 'mean':
# results = data_ds.mean(dim = list(data_ds.dims)[-2:],
# skipna = True, keep_attrs = True)
results = sel_ds.mean(dim = space_dims,
skipna = True, keep_attrs = True)
elif mode == 'sum':
results = sel_ds.sum(dim = space_dims,
skipna = True, keep_attrs = True)
elif mode == 'max':
results = sel_ds.max(dim = space_dims,
skipna = True, keep_attrs = True)
elif mode == 'min':
results = sel_ds.min(dim = space_dims,
skipna = True, keep_attrs = True)
# Time period selection
results = results.sel({time_dim: slice(start_date, end_date)})
# Cumulative sum option
if cumul:
print('\ncumul == True')
            timespan = results[time_dim].diff(dim = time_dim, label = 'lower')/np.timedelta64(1, 'D')
            _var = main_vars(results)
            results[_var][{time_dim: slice(0, timespan.size)}] = (
                results[_var][{time_dim: slice(0, timespan.size)}] * timespan.values).cumsum(axis = 0)
# Last value:
results[_var][-1] = np.nan
elif isinstance(data_ds, gpd.GeoDataFrame):
# Select specified fields
sel_ds = sel_ds[var_list]
# Aggregation
if mode == 'mean':
# results = data_ds.mean(dim = list(data_ds.dims)[-2:],
# skipna = True, keep_attrs = True)
results = sel_ds.groupby(by = time_dim).mean()
elif mode == 'sum':
results = sel_ds.groupby(by = time_dim).sum()
elif mode == 'max':
results = sel_ds.groupby(by = time_dim).max()
elif mode == 'min':
results = sel_ds.groupby(by = time_dim).min()
# Time period selection
results = results.loc[slice(start_date, end_date)]
# Cumulative sum option
if cumul:
print('\ncumul == True')
if isinstance(results, xr.Dataset):
timespan = results[time_dim
].diff(
dim = time_dim, label = 'lower')/np.timedelta64(1, 'D')
results.loc[slice(0, timespan.size)
] = (results.loc[slice(0, timespan.size)
] * timespan.values).cumsum(axis = 0)
# Last value:
results.iloc[-1] = np.nan
elif isinstance(results, pd.DataFrame):
timespan = results.index.diff()/np.timedelta64(1, 'D')
results = results.mul(timespan, axis = 'index')
# Last value:
results.iloc[-1] = np.nan
# Memory cleaning
del data_ds
del sel_ds
gc.collect()
# ---- Export
# =============================================================================
# # Drop spatial_ref
# if 'spatial_ref' in results.coords or 'spatial_ref' in results.data_vars:
# results = results.drop('spatial_ref')
# =============================================================================
# Drop non-data_vars fields
if isinstance(results, xr.Dataset):
to_discard = list(set(list(results.data_vars)).union(set(list(results.coords))) - set(var_list) - set([time_dim]))
results = results.drop(to_discard)
# Convert to pandas.DataFrame
results = results.to_dataframe()
elif isinstance(results, pd.DataFrame):
        to_discard = list(set(results.columns) - set(var_list) - {time_dim})
results = results.drop(columns = to_discard)
return results
###############################################################################
[docs]
def compare(data1,
data2,
*, mode = "absolute difference",
reprojection = 'finest',
crs1 = None,
crs2 = None):
"""
Compare two spatio-temporal maps by returning a spatio-temporal map of the
difference (absolute or relative, via ``mode`` argument). The 2 input maps
can be heterogeneous (different space and time extents and resolutions,
different variable names).
If the 2 input maps have different resolutions (and CRS), the finest,
coarsest, last or first resolution will be used, depending on ``reprojection``
argument.
    If both datasets have a time axis, their time intersection is kept. If one
    dataset is not temporal, its values are broadcast along the time axis of
    the other dataset.
Parameters
----------
data1 : path (str or pathlib.Path), or variable (xarray.Dataset or xarray.DataArray)
First dataset, that will be the reference to which ``data2`` will be compared.
data2 : path (str or pathlib.Path), or variable (xarray.Dataset, xarray.DataArray)
Second dataset, that will be compared to ``data1``.
    mode : {'absolute difference' (or 'difference' or 'absdif'), 'relative difference' (or 'reldif'), 'ratio'}, default 'absolute difference'
        Method of comparison:
        - 'absolute difference': data2 - data1
        - 'relative difference': (data2 - data1) / data1
        - 'ratio': data2 / data1
reprojection : {'finest' (or 'downscale'), 'coarsest' (or 'upscale'), 'last', or 'first'}, default 'finest'
If the two datasets have different resolutions, ``reprojection`` defines
how they will be reprojected.
crs1 : int or str or rasterio.crs.CRS, optional
Coordinate reference system of the first dataset (``data1``), that will be embedded in the ``data1``.
When passed as an *integer*, ``crs1`` refers to the EPSG code.
When passed as a *string*, ``crs1`` can be OGC WKT string or Proj.4 string.
crs2 : int or str or rasterio.crs.CRS, optional
Coordinate reference system of the second dataset (``data2``), that will be embedded in the ``data2``.
When passed as an *integer*, ``crs2`` refers to the EPSG code.
When passed as a *string*, ``crs2`` can be OGC WKT string or Proj.4 string.
Returns
-------
xarray.Dataset
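    Examples
    --------
    Illustrative sketch (file paths are hypothetical):
        # Relative difference between a simulation and a reference dataset,
        # computed on the finest of the two grids
        diff_ds = compare(r"D:/reference.nc",
                          r"D:/simulation.nc",
                          mode = 'relative difference')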
"""
# ---- Initialization
# Format ``mode``
mode = mode.casefold().replace(' ', '').replace('-', '')
# Load datasets
data_ds1 = load_any(data1)
data_ds2 = load_any(data2)
# Handle Coordinate Reference Systems
if crs1 is None:
crs1 = data_ds1.rio.crs
if crs2 is None:
crs2 = data_ds2.rio.crs
if crs1 is None:
if crs2 is not None:
print("Warning: `data1` has no CRS embedded. The CRS of `data2` is used instead. Consider passing a value for the `crs1` argument.")
data_ds1 = georef(data_ds1, crs = crs2)
else:
print("Err: `data1` and `data2` have no CRS embedded. Consider passing a value for the `crs1` and `crs2` arguments.")
# geo.reproject() requires a CRS
return
else:
data_ds1 = georef(data_ds1, crs = crs1)
if crs2 is None:
if crs1 is not None:
print("Warning: `data2` has no CRS embedded. The CRS of `data1` is used instead. Consider passing a value for the `crs2` argument.")
data_ds2 = georef(data_ds2, crs = crs1)
else:
print("Err: `data1` and `data2` have no CRS embedded. Consider passing a value for the `crs1` and `crs2` arguments.")
# geo.reproject() requires a CRS
return
else:
data_ds1 = georef(data_ds1, crs = crs1)
# ---- Reprojection
if reprojection in ['finest', 'downscale']:
# If one of the two datasets is geographic, and the other is reprojected,
# it is assumed that the reprojected dataset has the finest resolution
if (data_ds1.rio.crs.is_geographic) and not (data_ds2.rio.crs.is_geographic):
data_ds1 = transform(data_ds1, base_template = data_ds2)
elif not (data_ds1.rio.crs.is_geographic) and (data_ds2.rio.crs.is_geographic):
data_ds2 = transform(data_ds2, base_template = data_ds1)
else:
            # if both datasets are projected or geographic, the finest resolution is determined from |x_res * y_res|
            if abs(data_ds1.rio.resolution()[0] * data_ds1.rio.resolution()[1]) > abs(data_ds2.rio.resolution()[0] * data_ds2.rio.resolution()[1]):
                data_ds1 = transform(data_ds1, base_template = data_ds2)
            elif abs(data_ds1.rio.resolution()[0] * data_ds1.rio.resolution()[1]) < abs(data_ds2.rio.resolution()[0] * data_ds2.rio.resolution()[1]):
data_ds2 = transform(data_ds2, base_template = data_ds1)
else: # in case of equality, the smallest dataset will serve as the reference
if (data_ds1.rio.bounds()[2] - data_ds1.rio.bounds()[0]) * (data_ds1.rio.bounds()[3] - data_ds1.rio.bounds()[1]) \
< (data_ds2.rio.bounds()[2] - data_ds2.rio.bounds()[0]) * (data_ds2.rio.bounds()[3] - data_ds2.rio.bounds()[1]):
data_ds2 = transform(data_ds2, base_template = data_ds1)
else:
data_ds1 = transform(data_ds1, base_template = data_ds2)
elif reprojection in ['coarsest', 'upscale']:
# If one of the two datasets is geographic, and the other is reprojected,
# it is assumed that the geographic dataset has the coarsest resolution
if (data_ds1.rio.crs.is_geographic) and not (data_ds2.rio.crs.is_geographic):
data_ds2 = transform(data_ds2, base_template = data_ds1)
elif not (data_ds1.rio.crs.is_geographic) and (data_ds2.rio.crs.is_geographic):
data_ds1 = transform(data_ds1, base_template = data_ds2)
else:
            # if both datasets are projected or geographic, the coarsest resolution is determined from |x_res * y_res|
            if abs(data_ds1.rio.resolution()[0] * data_ds1.rio.resolution()[1]) > abs(data_ds2.rio.resolution()[0] * data_ds2.rio.resolution()[1]):
                data_ds2 = transform(data_ds2, base_template = data_ds1)
            elif abs(data_ds1.rio.resolution()[0] * data_ds1.rio.resolution()[1]) < abs(data_ds2.rio.resolution()[0] * data_ds2.rio.resolution()[1]):
data_ds1 = transform(data_ds1, base_template = data_ds2)
else: # in case of equality, the largest dataset will serve as the reference
if (data_ds1.rio.bounds()[2] - data_ds1.rio.bounds()[0]) * (data_ds1.rio.bounds()[3] - data_ds1.rio.bounds()[1]) \
> (data_ds2.rio.bounds()[2] - data_ds2.rio.bounds()[0]) * (data_ds2.rio.bounds()[3] - data_ds2.rio.bounds()[1]):
data_ds2 = transform(data_ds2, base_template = data_ds1)
else:
data_ds1 = transform(data_ds1, base_template = data_ds2)
elif reprojection == 'last':
        data_ds1 = transform(data_ds1, base_template = data_ds2)
elif reprojection == 'first':
        data_ds2 = transform(data_ds2, base_template = data_ds1)
else:
print(f"Err: the `reprojection` '{reprojection}' is not recognized. Implemented modes are: 'finest' (or 'downscale'), 'coarsest' (or 'upscale'), 'last' or 'first'")
return
# ---- Rename variables
ds1_varnames = main_vars(data_ds1)
ds2_varnames = main_vars(data_ds2)
if set(ds1_varnames) != set(ds2_varnames):
if len(ds1_varnames) != len(ds2_varnames):
print("`data1` and `data2` have a different number of variables, with different names."
"Only the variables with the same name (if any) are compared.")
else: # if len(ds1_varnames) == len(ds2_varnames)
print("Variable names of `data2` will be remapped as follow:")
print("`data1` <-- `data2`")
remap_list = [f"{ds1_varnames[i]} <-- {ds2_varnames[i]}" for i in range(len(ds1_varnames))]
print(' - ' + '\n - '.join(sorted(remap_list)) + '\n')
data_ds2 = data_ds2.rename({ds2_varnames[i]: ds1_varnames[i] for i in range(len(ds1_varnames))})
# ---- Time management
t1 = main_time_dims(data_ds1)[0]
t2 = main_time_dims(data_ds2)[0]
    # - If t1 is a time axis and t2 is nonexistent or contains only one time step,
    #   then the values of data_ds2 are broadcast along t1
    # - Same for t2.
    # - If t1 and t2 are both time axes, their time intersection is kept.
if t1 is not None:
if len(data_ds1[t1]) > 1:
if t2 is None:
data_ds2 = data_ds2.broadcast_like(data_ds1)
else:
if len(data_ds2[t2]) > 1:
time_intersect = data_ds1.indexes[t1].intersection(data_ds2.indexes[t2])
data_ds1 = data_ds1.loc[{t1: time_intersect}]
data_ds2 = data_ds2.loc[{t2: time_intersect}]
else: # if len(data_ds2[t2]) <= 1
data_ds2 = data_ds2[{t2: 0}].broadcast_like(data_ds1)
else: # if len(data_ds1[t1]) <= 1
if t2 is not None:
data_ds1 = data_ds1.broadcast_like(data_ds2)
else: # if t1 is None
if t2 is None:
data_ds1 = data_ds1.broadcast_like(data_ds2)
        else: # if t2 is not None
            if len(data_ds2[t2]) > 1:
                # data_ds1 has no time axis: broadcast its values along t2
                data_ds1 = data_ds1.broadcast_like(data_ds2)
            else: # if len(data_ds2[t2]) <= 1
                data_ds2 = data_ds2[{t2: 0}].broadcast_like(data_ds1)
# ---- Compute difference
if mode in ['difference', 'absolutedifference',
'absolutediff', 'absolutedif', 'absdiff', 'absdif']:
comp_ds = data_ds2 - data_ds1
comp_ds = comp_ds.rename({var: '_'.join([var, 'absdiff']) for var in ds1_varnames})
elif mode in ['relativedifference', 'relativediff', 'relativedif',
'reldiff', 'reldif']:
comp_ds = (data_ds2 - data_ds1) / data_ds1
comp_ds = comp_ds.where(data_ds1 != 0, np.nan)
comp_ds = comp_ds.rename({var: '_'.join([var, 'reldiff']) for var in ds1_varnames})
elif mode in ['ratio']:
comp_ds = data_ds2 / data_ds1
comp_ds = comp_ds.where(data_ds1 != 0, np.nan)
comp_ds = comp_ds.rename({var: '_'.join([var, 'ratio']) for var in ds1_varnames})
else:
print(f"Err: the `mode` '{mode}' is not recognized. The implemented modes are: 'difference' (or 'absolute difference' or 'absdif'), 'relative difference' (or 'reldif') or 'ratio'")
return
# Corrections
# comp_ds = comp_ds.fillna(np.nan)
# comp_ds = comp_ds.where(abs(comp_ds) != np.inf, np.nan)
comp_ds = georef(comp_ds)
return comp_ds
###############################################################################
#%%% * xr.DataSet to DataFrame
def xr_to_pd(xr_data):
"""
Format xr objects (such as those from gc.time_series) into pandas.DataFrames
formatted as in gc.tss_to_dataframe.
Parameters
----------
    xr_data : xarray.Dataset or xarray.DataArray
        Initial data to convert into a pandas.DataFrame.
        NB: xr_data needs to have only one dimension.
Returns
-------
Pandas.DataFrame
"""
print("\n_Infos...")
if type(xr_data) == xr.core.dataset.Dataset:
var_list = main_vars(xr_data)
print(f" Data is a xr.Dataset, with {', '.join(var_list)} as the main variables")
xr_data = xr_data[var_list]
elif type(xr_data) == xr.core.dataarray.DataArray:
print(" Data is a xr.Dataarray")
res = xr_data.to_dataframe(name = 'val')
res = res[['val']]
res['time'] = pd.to_datetime(res.index)
if not res.time.dt.tz:
print(" The timezone is not defined")
# res['time'] = res.time.dt.tz_localize('UTC')
res.index = range(0, res.shape[0])
# =============================================================================
# res['id'] = res.index
# =============================================================================
print('') # newline
return res
###############################################################################
#%%% ° tss_to_dataframe
def tss_to_dataframe(*, input_file, skip_rows, start_date, cumul = False):
"""
Example
-------
base = gc.tss_to_dataframe(input_file = r"D:\2- Postdoc\2- Travaux\3_CWatM_EBR\results\raw_results\001_prelim_cotech\2022-03-19_base\discharge_daily.tss",
skip_rows = 4,
start_date = '1991-08-01')
precip = gc.tss_to_dataframe(input_file = r"D:\2- Postdoc\2- Travaux\3_CWatM_EBR\results\raw_results\003_artif\2022-03-25_base\Precipitation_daily.tss",
skip_rows = 4,
start_date = '2000-01-01')
precip.val = precip.val*534000000/86400
    # (the topographic catchment of the Meu river covers 471,851,238 m2)
precip['rolling_mean'] = precip['val'].rolling(10).mean()
Parameters
----------
    input_file : str
        Path to the input file.
    skip_rows : int
        Number of header lines to skip at the top of the file.
    start_date : str or datetime
        Date of the first value in the file.
        /!\ If a str is passed, it must follow the "%Y-%m-%d" format.
Returns
-------
df : pandas.DataFrame
    Future implementations
    ----------------------
    Retrieve the start_date from the settings file referenced at the beginning
    of the *.tss file, then take the SpinUp into account.
"""
    #%% Retrieve the inputs:
    # -----------------------
if start_date == 'std':
        print('> Not implemented yet...')
        # retrieve the start_date from the settings file
else:
start_date = pd.to_datetime(start_date)
# print('> start_date = ' + str(start_date))
    #%% Build the dataframes:
    # ------------------------
# df = pd.read_csv(input_file, sep = r"\s+", header = 0, names = ['id', 'val'], skiprows = skip_rows)
    if skip_rows == 0: # Case of *.css discharge files, with no info lines,
                       # only a header
        _fields = ['']
        n_col = 1
    else: # Case of *.tss outputs with several info lines, the 2nd line
          # giving the number of columns. In that case, skip_rows should
          # be equal to 2.
with open(input_file) as temp_file:
# temp_file.readline()
# n_col = int(temp_file.readline()[0])
content = temp_file.readlines()
n_col = int(content[skip_rows-1][0])
_fields = [str(n) for n in range(1, n_col)]
_fields[0] = ''
df = pd.read_csv(input_file,
sep = r"\s+",
header = 0,
skiprows = skip_rows -1 + n_col,
names = ['id'] + ['val' + ending for ending in _fields],
)
    # If the id column already contains dates (as text or datetime):
if type(df.id[0]) in [str,
pd.core.indexes.datetimes.DatetimeIndex,
pd._libs.tslibs.timestamps.Timestamp]:
df['time'] = pd.to_datetime(df.id)
    # Otherwise (= if the id column contains indices), the dates must be rebuilt:
else:
date_indexes = pd.date_range(start = start_date,
periods = df.shape[0], freq = '1D')
df['time'] = date_indexes
if cumul:
print('\ncumul == True')
# Values are expected to be expressed in [.../d]
# Cumsum is applied on all columns with values ('val', 'val2', 'val3', ...)
timespan = df.loc[
:, df.columns == 'time'
].diff().shift(-1, fill_value = 0)/np.timedelta64(1, 'D')
# timespan = df.loc[
# :, df.columns == 'time'
# ].diff()/np.timedelta64(1, 'D')
df.iloc[
:].loc[:, (df.columns != 'id') * (df.columns != 'time')
] = (df.iloc[
:].loc[:, (df.columns != 'id') * (df.columns != 'time')
] * timespan.values).cumsum(axis = 0)
# Last value
# df.iloc[-1].loc[:, (df.columns != 'id') * (df.columns != 'time')] = np.nan
# np.diff(df.time)/np.timedelta64(1, 'D')
return df
#%% MNT & WATERSHEDS OPERATIONS
###############################################################################
def extract_watersheds(*,
ldd,
outlets,
output_path,
dirmap = '1-9',
engine: str = 'pysheds',
src_crs = None,
separate_files = True,
extension = None,
auto_snap = True, # 3 pixels
snap_distance = None,
min_area_km2 = 0.0,
prefix = '',
sort_by_area = True,
compute_stats = True):
"""
Extract multiple watersheds from outlet information.
Parameters
----------
ldd : str, pathlib.Path, xarray.Dataset or xarray.DataArray
Local Drain Direction raster data.
outlets : str, pathlib.Path, xarray.Dataset, geopandas.GeoDataFrame or list of tuples
Outlet information in various formats:
- CSV file with columns: watershed_name, x_outlet, y_outlet, [snap_dist], [buff_percent], [crs_proj]
- Raster file: pixels with value=1 will be used as outlets (ignoring 0 and other values)
- Vector file (shp/gpkg): point features or polygon centroids will be used as outlets
- List of tuples: [(name, x, y), (name, x, y, snap_dist, buff_percent), ...]
output_path : str or pathlib.Path
Output directory path (folder only, not filename). Filenames are auto-generated.
dirmap : tuple or str, optional, default '1-9'
Direction codes convention (same as extract_watershed).
engine : str, optional, default 'pysheds'
Processing engine to use.
src_crs : int or str, optional
Source CRS for LDD if not embedded.
separate_files : bool, optional, default True
If True: creates one file per watershed (e.g., "BassinNord.shp", "BassinSud.shp")
If False: creates one combined file with all watersheds (e.g., "all_watersheds.shp")
extension : str, optional
Output file extension (e.g., '.shp', '.gpkg', '.tif', '.nc').
If None, automatically detected from first outlet or defaults to '.shp' for vector, '.tif' for raster.
auto_snap : bool, optional, default True
If True, automatically calculate snap distance based on LDD resolution.
If False, uses snap_dist from outlets data or manual snap_distance.
snap_distance : float, optional, default None
Manual snap distance in meters. If provided, overrides auto_snap and individual outlet snap_dist.
If None, uses auto_snap or individual outlet snap_dist values.
min_area_km2 : float, optional, default 0.0
Minimum area in km² for a watershed to be included in results.
prefix : str, optional, default ''
Prefix for output filenames.
sort_by_area : bool, default True
        If True, sort watersheds by area (largest first) for better display in GIS software:
        larger watersheds are written first, so smaller ones are drawn on top of them.
compute_stats : bool, default True
Whether to compute additional statistics (area, perimeter, etc.).
Returns
-------
dict
Dictionary containing:
- 'watersheds': geopandas.GeoDataFrame or xarray.Dataset with watershed data
- 'outlets_used': Final outlet coordinates after processing
- 'files_created': List of output file paths created
- 'statistics': Summary statistics of extraction process
Warnings
--------
- Raises warning if separate_files=False and raster output with potential overlapping watersheds
- Validates CRS compatibility between LDD and outlets
- Checks if outlets are within LDD domain
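    Examples
    --------
    Illustrative sketch (paths and outlet coordinates are hypothetical):
        res = extract_watersheds(ldd = r"D:/ldd.tif",
                                 outlets = [('BassinNord', 355000.0, 6790000.0)],
                                 output_path = r"D:/watersheds",
                                 dirmap = 'd8',
                                 extension = '.gpkg')
        watersheds_gdf = res['watersheds']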
"""
print(f"\nExtracting watersheds from outlets: {outlets}")
# ---- Initialize result tracking
outlets_used = []
files_created = []
statistics = {
'total_outlets': 0,
'successful_extractions': 0,
'failed_extractions': 0,
'total_area_km2': 0.0,
'warnings': []
}
# ---- Validate and create output directory
output_path = Path(output_path)
if not output_path.exists():
output_path.mkdir(parents=True, exist_ok=True)
print(f" _ Created output directory: {output_path}")
elif output_path.is_file():
raise ValueError(f"output_path must be a directory, not a file: {output_path}")
# ---- Parse input outlets from various formats
if isinstance(outlets, (str, Path)):
outlets_path = Path(outlets)
if not outlets_path.exists():
raise FileNotFoundError(f"Outlets file not found: {outlets_path}")
outlet_ext = outlets_path.suffix.lower()
if outlet_ext == '.csv':
# Use load_any which handles separator auto-detection
outlets_df = load_any(outlets_path)
outlets_df.columns = outlets_df.columns.str.strip()
outlets_df['x_outlet'] = pd.to_numeric(outlets_df['x_outlet'])
outlets_df['y_outlet'] = pd.to_numeric(outlets_df['y_outlet'])
if 'snap_dist' not in outlets_df.columns:
outlets_df['snap_dist'] = 0
if 'buff_percent' not in outlets_df.columns:
outlets_df['buff_percent'] = 0
elif outlet_ext in ['.shp', '.gpkg', '.geojson']:
gdf = load_any(outlets_path)
if src_crs and gdf.crs is None:
gdf.set_crs(src_crs, inplace=True, allow_override=True)
outlets_data = []
for idx, row in gdf.iterrows():
geom = row.geometry
if geom.geom_type == 'Point':
x, y = geom.x, geom.y
elif geom.geom_type in ['Polygon', 'MultiPolygon']:
centroid = geom.centroid
x, y = centroid.x, centroid.y
else:
continue
name_cols = ['name', 'watershed_name', 'basin_name', 'id', 'NAME', 'ID']
watershed_name = None
for col in name_cols:
if col in gdf.columns:
watershed_name = str(row[col])
break
if watershed_name is None:
watershed_name = f"Outlet_{idx+1}"
outlets_data.append({'watershed_name': watershed_name, 'x_outlet': x, 'y_outlet': y, 'snap_dist': 0, 'buff_percent': 0})
outlets_df = pd.DataFrame(outlets_data)
elif outlet_ext in ['.tif', '.nc', '.asc']:
raster_ds = load_any(outlets_path)
if src_crs and raster_ds.rio.crs is None:
raster_ds.rio.write_crs(src_crs, inplace=True)
var_name = main_vars(raster_ds)[0]
raster_data = raster_ds[var_name]
y_coords, x_coords = np.where(raster_data.values == 1)
if len(y_coords) == 0:
raise ValueError("No pixels with value = 1 found in raster outlets")
outlets_data = []
for i, (y_idx, x_idx) in enumerate(zip(y_coords, x_coords)):
x = float(raster_data.x[x_idx].values)
y = float(raster_data.y[y_idx].values)
outlets_data.append({'watershed_name': f"Outlet_{i+1}", 'x_outlet': x, 'y_outlet': y, 'snap_dist': 0, 'buff_percent': 0})
outlets_df = pd.DataFrame(outlets_data)
print(f" _ Extracted {len(outlets_df)} outlets from raster (pixels with value=1)")
else:
raise ValueError(f"Unsupported outlets file format: {outlet_ext}")
elif isinstance(outlets, list):
outlets_data = []
for outlet in outlets:
if len(outlet) == 3:
name, x, y = outlet
outlets_data.append({'watershed_name': name, 'x_outlet': x, 'y_outlet': y, 'snap_dist': 0, 'buff_percent': 0})
elif len(outlet) == 5:
name, x, y, snap_dist, buff_percent = outlet
outlets_data.append({'watershed_name': name, 'x_outlet': x, 'y_outlet': y, 'snap_dist': snap_dist, 'buff_percent': buff_percent})
else:
raise ValueError(f"Invalid outlet tuple length: {len(outlet)}. Expected 3 or 5 elements.")
outlets_df = pd.DataFrame(outlets_data)
else:
raise ValueError("outlets must be a file path or list of tuples")
# ---- Auto-detect extension if not provided
if extension is None:
if isinstance(outlets, (str, Path)) and os.path.exists(outlets):
test_ext = os.path.splitext(outlets)[-1].lower()
if test_ext in ['.shp', '.gpkg', '.geojson']:
extension = '.shp'
elif test_ext in ['.tif', '.nc']:
extension = '.tif'
else:
extension = '.shp'
else:
extension = '.shp'
# ---- Determine output format from extension
raster_extensions = ['.tif', '.nc', '.asc']
vector_extensions = ['.shp', '.gpkg', '.geojson']
if extension.lower() in raster_extensions:
output_format = 'raster'
if not separate_files:
print(" _ Warning: Raster output with separate_files=False may cause issues if watersheds overlap")
statistics['warnings'].append("Potential watershed overlap in combined raster output")
elif extension.lower() in vector_extensions:
output_format = 'vector'
else:
raise ValueError(f"Unsupported extension: {extension}. Supported: {raster_extensions + vector_extensions}")
print(f" _ Output format: {output_format} ({extension})")
print(f" _ Separate files: {separate_files}")
print(f" _ Found {len(outlets_df)} outlets to process")
# ---- Specify directional mapping
if isinstance(dirmap, str):
dirmap = dirmap.casefold().replace(' ', '').replace('-', '')
if dirmap in ['19', '[19]', 'keypad', 'pcraster']:
dirmap = (8, 9, 6, 3, 2, 1, 4, 7)
elif dirmap in ['d8', 'esri']:
dirmap = (64, 128, 1, 2, 4, 8, 16, 32)
elif dirmap in ['d8wbt', 'wbt', 'whiteboxtools']:
dirmap = (128, 1, 2, 4, 8, 16, 32, 64)
else:
dirmap = (8, 9, 6, 3, 2, 1, 4, 7)
# ---- Loading and validating LDD
ds = load_any(ldd, decode_coords='all')
if src_crs is not None:
ds.rio.write_crs(src_crs, inplace=True)
else:
if ds.rio.crs is None:
raise ValueError("Coordinate Reference System is required for LDD")
# ---- Validate outlets are within LDD domain
bounds = ds.rio.bounds()
x_min, y_min, x_max, y_max = bounds
outside_domain = 0
for _, row in outlets_df.iterrows():
x, y = row['x_outlet'], row['y_outlet']
if not (x_min <= x <= x_max and y_min <= y <= y_max):
outside_domain += 1
print(f"Warning: Outlet '{row['watershed_name']}' is outside LDD domain")
if outside_domain > 0:
warning_msg = f"{outside_domain} outlets are outside LDD domain"
statistics['warnings'].append(warning_msg)
print(f" _ Warning: {warning_msg}")
ds, nodata = standardize_fill_value(ds)
var = main_vars(ds)[0]
print(f"Drain direction variable is inferred to be {var}")
x_var, y_var = main_space_dims(ds)[0]
encod = ds[var].encoding
# Replacing nan with appropriate nodata value
print(f" _ LDD variable: {var}")
# ---- Calculate snap distance
if snap_distance is not None:
# Manual snap distance overrides everything
print(f" _ Manual snap distance: {snap_distance:.1f}m")
outlets_df['snap_dist'] = float(snap_distance)
elif auto_snap:
# Auto-calculated snap distance
pixel_size = abs(ds.rio.resolution()[0])
auto_snap_dist = pixel_size * 3
print(f" _ Auto-snap distance: {auto_snap_dist:.1f}m (3x pixel size)")
outlets_df['snap_dist'] = outlets_df['snap_dist'].fillna(auto_snap_dist)
outlets_df.loc[outlets_df['snap_dist'] == 0, 'snap_dist'] = float(auto_snap_dist)
else:
# Use individual outlet snap_dist values or 0
outlets_df['snap_dist'] = outlets_df['snap_dist'].fillna(0)
print(f" _ Using individual outlet snap distances")
# ---- Prepare nodata value
std_nodata = min(dirmap) - 4
if np.isnan(nodata):
nodata = std_nodata
else:
if (not np.int32(nodata) == nodata) or (nodata in dirmap):
nodata = std_nodata
else:
nodata = np.int32(nodata)
ds[var] = ds[var].fillna(nodata)
# ---- Setup pysheds grid
if engine.casefold() in ["pyshed", "pysheds"]:
print(' _ Using Pysheds engine...')
viewfinder = ViewFinder(affine=ds.rio.transform(), shape=ds.rio.shape, crs=ds.rio.crs, nodata=np.int32(nodata))
ldd_raster = Raster(ds[var].astype(np.int32).data, viewfinder=viewfinder)
grid = Grid.from_raster(ldd_raster)
print(' _ Computing flow accumulation...')
acc = grid.accumulation(ldd_raster, dirmap=dirmap, nodata_out=np.int32(-1))
# ---- Process each watershed
watersheds_list = []
watersheds_raster_dict = {} if output_format == 'raster' else None
x_var, y_var = main_space_dims(ds)[0]
for _, row in outlets_df.iterrows():
watershed_name = str(row['watershed_name']).strip('"')
x_outlet = row['x_outlet']
y_outlet = row['y_outlet']
snap_dist = row.get('snap_dist', 0)
print(f" _ Processing watershed: {watershed_name}")
# ---- Snap outlet if needed
try:
col, row_idx = grid.nearest_cell(x_outlet, y_outlet, snap='center')
current_acc = acc[row_idx, col]
if current_acc < 15 and snap_dist > 0:
print(f" . Snapping outlet within {snap_dist:.1f}m...")
try:
x_snap, y_snap = grid.snap_to_mask(acc > 15, (x_outlet, y_outlet))
x_snap += ds.rio.resolution()[0]/2
y_snap += ds.rio.resolution()[1]/2
print(f" . Outlet snapped from ({x_outlet:.1f}, {y_outlet:.1f}) to ({x_snap:.1f}, {y_snap:.1f})")
x_outlet, y_outlet = x_snap, y_snap
except Exception as e:
print(f" . Warning: Could not snap outlet for {watershed_name}: {e}")
statistics['warnings'].append(f"Failed to snap outlet {watershed_name}")
except Exception as e:
print(f" . Error processing outlet location for {watershed_name}: {e}")
statistics['failed_extractions'] += 1
continue
# ---- Extract watershed
try:
shed = grid.catchment(x=x_outlet, y=y_outlet, fdir=ldd_raster, xytype='coordinate',
nodata_out=np.bool_(False), dirmap=dirmap, snap='center')
shed_ds = ds.copy()
shed_ds[var] = ([y_var, x_var], np.array(shed).astype(np.int8))
shed_ds = shed_ds.rename({var: 'watershed'})
shed_ds = shed_ds.where(shed_ds.watershed > 0, drop=True)
if len(shed_ds.watershed.values.flatten()) == 0:
print(f" . Warning: Empty watershed for {watershed_name}")
statistics['failed_extractions'] += 1
continue
mask = shed_ds.watershed.values > 0
pixel_area = abs(ds.rio.resolution()[0] * ds.rio.resolution()[1])
n_pixels = np.sum(mask)
area_km2 = (n_pixels * pixel_area) / 1e6
if area_km2 < min_area_km2:
print(f" . Warning: Watershed {watershed_name} too small ({area_km2:.3f} km² < {min_area_km2} km²)")
statistics['failed_extractions'] += 1
continue
if output_format == 'raster':
                    watershed_raster = georef(shed_ds).watershed.copy()
                    watershed_raster.attrs.update({'watershed_name': watershed_name, 'x_outlet': x_outlet, 'y_outlet': y_outlet, 'area_km2': area_km2, 'n_pixels': int(n_pixels)})
watersheds_raster_dict[watershed_name] = watershed_raster
print(f" . Success: {area_km2:.2f} km² (raster)")
else:
transform = shed_ds.rio.transform()
shapes = list(rasterio.features.shapes(mask.astype(np.uint8), mask=mask, transform=transform))
if not shapes:
print(f" . Warning: Could not vectorize {watershed_name}")
statistics['failed_extractions'] += 1
continue
geometries = [sg.shape(geom) for geom, value in shapes if value == 1]
if not geometries:
print(f" . Warning: No valid geometries for {watershed_name}")
statistics['failed_extractions'] += 1
continue
main_geom = max(geometries, key=lambda x: x.area)
# Use geometry area for more accurate calculation
geom_area_km2 = main_geom.area / 1e6
watershed_info = {'name': watershed_name, 'x_outlet': x_outlet, 'y_outlet': y_outlet, 'area_km2': geom_area_km2, 'n_pixels': n_pixels, 'geometry': main_geom}
if compute_stats:
watershed_info.update({
'perimeter_km': main_geom.length / 1000,
'compactness': (4 * np.pi * main_geom.area) / (main_geom.length ** 2),
'bounds_north': main_geom.bounds[3], 'bounds_south': main_geom.bounds[1],
'bounds_east': main_geom.bounds[2], 'bounds_west': main_geom.bounds[0]
})
watersheds_list.append(watershed_info)
print(f" . Success: {geom_area_km2:.2f} km²")
statistics['successful_extractions'] += 1
statistics['total_area_km2'] += geom_area_km2 if output_format == 'vector' else area_km2
outlets_used.append((watershed_name, x_outlet, y_outlet))
except Exception as e:
print(f" . Error extracting watershed {watershed_name}: {e}")
statistics['failed_extractions'] += 1
continue
# ---- WhiteBox Tools engine (commented out for now)
elif engine.casefold() in ["wbt", "whiteboxtools", "whitebox"]:
print(' _ Using WhiteBox Tools engine...')
print(' This engine is not implemented yet.')
# ---- Avec WhiteToolBox (deprecated)
# ======== discontinued =======================================================
# elif engine.casefold() in ["wtb", "whitetoolbox"]:
# print('WhiteToolBox engine...')
#
# wbt.watershed(
# d8_path,
# outlets_file,
# os.path.join(os.path.split(d8_path)[0], "mask_bassin_xxx_wbt.tif"),
# esri_pntr = True,
# )
# =============================================================================
# ---- Create output dataset
if output_format == 'raster':
if not watersheds_raster_dict:
print("No watersheds were successfully extracted!")
return {'watersheds': None, 'outlets_used': outlets_used, 'files_created': files_created, 'statistics': statistics}
if sort_by_area and len(watersheds_raster_dict) > 1:
watersheds_raster_dict = dict(sorted(watersheds_raster_dict.items(),
key=lambda item: item[1].attrs.get('area_km2', 0), reverse=True))
print(f" _ Sorted watersheds by area (largest first, so smallest appear on top)")
if separate_files:
result = watersheds_raster_dict
else:
# Use the original LDD grid as the template for combined output
watersheds_ds = ds.copy()
watersheds_ds = watersheds_ds.drop_vars([var])
# Initialize combined mask with zeros, same shape as LDD
ldd_data = ds[var]
combined_mask = np.zeros_like(ldd_data.values, dtype=np.int32)
watershed_id = 1
for watershed_name, watershed_raster in watersheds_raster_dict.items():
# Ensure watershed raster has same CRS as LDD data
if watershed_raster.rio.crs is None:
watershed_raster = watershed_raster.rio.write_crs(ldd_data.rio.crs)
# Create mask for this watershed on the original LDD grid
if watershed_raster.ndim > 1 and watershed_raster.shape != ldd_data.shape:
# Multi-dimensional watershed with different shape - resample to LDD grid
resampled = watershed_raster.rio.reproject_match(ldd_data)
mask = resampled.values > 0
else:
# Same shape or single pixel - use directly
if watershed_raster.ndim == 1:
# Single pixel watershed - need to broadcast to LDD shape
mask = np.zeros_like(ldd_data.values, dtype=bool)
# Find the position in the original watershed coordinate system
try:
# Get the coordinates from the watershed if available
if hasattr(watershed_raster, 'x') and hasattr(watershed_raster, 'y'):
x_coord = float(watershed_raster.x.values)
y_coord = float(watershed_raster.y.values)
                                row_idx, col = rasterio.transform.rowcol(ldd_data.rio.transform(), x_coord, y_coord)
if 0 <= row_idx < mask.shape[0] and 0 <= col < mask.shape[1]:
mask[row_idx, col] = bool(watershed_raster.values > 0)
                        except Exception:
# Fallback: skip single pixel watersheds that can't be located
continue
else:
mask = watershed_raster.values > 0
# Add watershed to combined mask where no other watershed exists
watershed_data = np.where(mask, watershed_id, 0)
combined_mask = np.where((combined_mask == 0) & (watershed_data > 0), watershed_data, combined_mask)
watershed_id += 1
# Create combined raster with same coordinates as LDD
combined_raster = ldd_data.copy()
combined_raster.values = combined_mask
combined_raster.attrs = {'long_name': 'Combined watersheds', 'description': 'Watershed ID values (1, 2, 3, ...)', 'nodata': 0}
watersheds_ds['watersheds'] = combined_raster
result = georef(watersheds_ds)
print(f" _ Combined {len(watersheds_raster_dict)} watersheds into single raster")
else:
if not watersheds_list:
print("No watersheds were successfully extracted!")
return {'watersheds': None, 'outlets_used': outlets_used, 'files_created': files_created, 'statistics': statistics}
result = gpd.GeoDataFrame(watersheds_list, crs=ds.rio.crs)
if sort_by_area and len(result) > 1:
result = result.sort_values('area_km2', ascending=False).reset_index(drop=True)
print(f" _ Sorted watersheds by area (largest first, so smallest appear on top)")
print(f" _ Successfully extracted {len(result)} watersheds")
print(f" _ Total area: {result['area_km2'].sum():.2f} km²")
# ---- Export files
try:
if separate_files:
if output_format == 'raster':
for watershed_name, watershed_raster in result.items():
safe_name = watershed_name.replace(' ', '_').replace('-', '_')
safe_name = ''.join(c for c in safe_name if c.isalnum() or c in ['_', '-'])
filename = f"{prefix}{safe_name}{extension}"
filepath = output_path / filename
export_ds = watershed_raster.to_dataset(name='watershed')
export(export_ds, str(filepath))
files_created.append(str(filepath))
else:
for _, row in result.iterrows():
watershed_name = row['name']
safe_name = watershed_name.replace(' ', '_').replace('-', '_')
safe_name = ''.join(c for c in safe_name if c.isalnum() or c in ['_', '-'])
filename = f"{prefix}{safe_name}{extension}"
filepath = output_path / filename
single_gdf = gpd.GeoDataFrame([row], crs=result.crs)
export(single_gdf, str(filepath))
files_created.append(str(filepath))
else:
if output_format == 'raster':
filename = f"{prefix}all_watersheds{extension}"
else:
filename = f"{prefix}watersheds_combined{extension}"
filepath = output_path / filename
export(result, str(filepath))
files_created.append(str(filepath))
print(f" _ Files exported to: {output_path}")
for file in files_created:
print(f" • {Path(file).name}")
except Exception as e:
error_msg = f"Export failed: {e}"
print(f" _ Error: {error_msg}")
statistics['warnings'].append(error_msg)
# ---- Final statistics
statistics['total_outlets'] = len(outlets_df)
print(f"\n _ Extraction complete:")
print(f" • Successful: {statistics['successful_extractions']}/{statistics['total_outlets']}")
print(f" • Total area: {statistics['total_area_km2']:.2f} km²")
print(f" • Files created: {len(files_created)}")
if statistics['warnings']:
print(f" • Warnings: {len(statistics['warnings'])}")
return {
'watersheds': result,
'outlets_used': outlets_used,
'files_created': files_created,
'statistics': statistics
}
###############################################################################
def compute_ldd(dem_path,
dirmap = '1-9',
engine:str = 'pysheds',
src_crs = None):
"""
Convert a Digital Elevation Model (DEM) into a Local Drain Direction map (LDD).
Parameters
----------
dem_path : str, Path, xarray.Dataset or xarray.DataArray
Digital Elevation Model data. Supported file formats are *.tif*, *.asc* and *.nc*.
dirmap : tuple or str, optional, default '1-9'
Direction codes convention.
- ``'1-9'`` (or ``'keypad'``, or ``'pcraster'``): from 1 to 9, upward,
from bottom-left corner, no-flow is 5 [pcraster convention]
- ``'D8'`` (or ``'ESRI'``): from 1 to 128 (base-2), clockwise, from
middle-right position, no-flow is 0 [esri convention]
- ``'D8-WBT'`` (or ``'WhiteBoxTools'``): from 1 to 128 (base-2),
clockwise, from top-right corner, no-flow is 0 [WhiteBoxTools convention]
    engine : {'pysheds', 'whiteboxtools'}, optional, default 'pysheds'
``'whiteboxtools'`` has been deactivated to avoid the need to install whiteboxtools.
Returns
-------
LDD raster, xarray.Dataset.
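
    Examples
    --------
    A minimal usage sketch (the DEM path is hypothetical; the CRS is assumed
    to be embedded in the file, otherwise pass it with ``src_crs``):

    >>> ldd = compute_ldd("dem.tif", dirmap='1-9')
    >>> export(ldd, "ldd.nc")  # export() from this module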
"""
# ---- With pysheds
if engine.casefold() in ["pyshed", "pysheds"]:
"""
Adapted from Luca Guillaumot's work
"""
print('Pysheds engine...')
# Load the pysheds elements (grid & data)
# ===== obsolete: load from file ==============================================
# ext = os.path.splitext(dem_path)[-1]
# if ext == '.tif':
# grid = Grid.from_raster(dem_path, data_name = 'dem')
# dem = grid.read_raster(dem_path)
# elif ext == '.asc':
# grid = Grid.from_ascii(dem_path)
# dem = grid.read_ascii(dem_path)
# =============================================================================
ds = load_any(dem_path, decode_coords = 'all')
if src_crs is not None:
ds.rio.write_crs(src_crs, inplace = True)
else:
if ds.rio.crs is None:
print("Err: The Coordinate Reference System is required. It should be embedded in the input DEM or passed with the 'src_crs' argument")
return
ds, nodata = standardize_fill_value(ds)
var = main_vars(ds)[0]
print(f"Elevation variable is inferred to be {var}")
x_var, y_var = main_space_dims(ds)[0]
encod = ds[var].encoding
# NaN data are problematic when filling
# nan_mask = ds[var].isnull().data
nan_mask = xr.where(~ds[var].isnull(), True, False).data
ds[var] = ds[var].fillna(-9999)
ds[var].encoding = encod
viewfinder = ViewFinder(affine = ds.rio.transform(),
shape = ds.rio.shape,
crs = ds.rio.crs,
nodata = nodata)
dem = Raster(ds[var].data, viewfinder=viewfinder)
grid = Grid.from_raster(dem)
# Fill depressions in DEM
# =============================================================================
# print(' . dem no data is ', grid.nodata)
# =============================================================================
flooded_dem = grid.fill_depressions(dem)
# Resolve flats in DEM
inflated_dem = grid.resolve_flats(flooded_dem)
# Specify directional mapping
if isinstance(dirmap, str):
dirmap = dirmap.casefold().replace(' ', '').replace('-', '')
if dirmap in ['19', '[19]', 'keypad', 'pcraster']:
dirmap = (8, 9, 6, 3, 2, 1, 4, 7)
elif dirmap in ['d8', 'esri']:
dirmap = (64, 128, 1, 2, 4, 8, 16, 32) # ESRI system
elif dirmap in ['d8wbt', 'wbt', 'whiteboxtools']:
dirmap = (128, 1, 2, 4, 8, 16, 32, 64) # WhiteBox Tools system
# Compute flow directions
direc = grid.flowdir(inflated_dem, dirmap=dirmap,
nodata_out = np.int32(-3))
# Replace flats (-1) with value 5 (no flow)
direc = xr.where(direc == -1, 5, direc)
# Replace pits (-2) with value 5 (no flow)
direc = xr.where(direc == -2, 5, direc)
# Output
ds[var] = ([y_var, x_var], np.array(direc))
ds[var] = ds[var].where(nan_mask)
ds = ds.rename({var: 'LDD'})
# =============================================================================
# ds['LDD'] = ds['LDD'].astype(float) # astype(int)
# =============================================================================
# =============================================================================
# ds['LDD'] = ds['LDD'].astype(np.int32)
# =============================================================================
ds['LDD'].encoding = encod
ds['LDD'].encoding['dtype'] = np.int32
ds['LDD'].encoding['rasterio_dtype'] = np.int32
ds['LDD'].encoding['_FillValue'] = -3
# ========= issue with dtypes when exporting ==================================
# if 'scale_factor' in ds['LDD'].encoding:
# ds['LDD'].encoding['scale_factor'] = np.int32(ds['LDD'].encoding['scale_factor'])
# if 'add_offset' in ds['LDD'].encoding:
# ds['LDD'].encoding['add_offset'] = np.int32(ds['LDD'].encoding['add_offset'])
# if '_FillValue' in ds['LDD'].encoding:
# ds['LDD'].encoding['_FillValue'] = np.int32(-1)
# =============================================================================
ds = georef(data = ds)
# ---- With WhiteToolBox (discontinued)
# =============================================================================
# elif engine.casefold() in ["wtb", "whitetoolbox"]:
# print('WhiteToolBox engine...')
# dist_ = 10
#
# # Breach depressions
# wbt.breach_depressions_least_cost(
# dem_path,
# os.path.splitext(dem_path)[0] + f"_breached{dist_}[wtb].tif",
# dist_)
# print(' Fichier intermédiaire créé')
#
# # =============================================================================
# # # Fill depression (alternative)
# # wbt.fill_depressions(
# # dem_path,
# # os.path.splitext(dem_path)[0] + "_filled[wtb].tif",
# # 10)
# # =============================================================================
#
# # Creation du D8
# suffix = "breached{}[wtb]".format(dist_)
# wbt.d8_pointer(
# os.path.splitext(dem_path)[0] + "_" + suffix + ".tif",
# os.path.join(os.path.split(dem_path)[0], "D8_xxx_" + suffix + "_wtb.tif"),
# esri_pntr = True)
# print(' LDD "D8 ESRI" créé')
# =============================================================================
return ds
###############################################################################
def cell_area(data, src_crs = None, engine = 'shapely'):
"""
Compute cell area of a raster.
Parameters
----------
data : path to a .tif, .nc or .asc file (str or pathlib.Path), or variable (xarray.Dataset or xarray.DataArray)
If ``data`` is a raster file (ASCII, GeoTIFF) or a netCDF, it will be
loaded into a standard *GEOP4TH* variable (xarray.Dataset).
src_crs : int or str or rasterio.crs.CRS, optional
Coordinate reference system of the source (``data``), that will be embedded in the ``data``.
When passed as an *integer*, ``src_crs`` refers to the EPSG code.
engine : {'shapely', 'equalearth', 'geographiclib'}, default 'shapely'
Method to compute area.
Returns
-------
    xarray.DataArray
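
    Examples
    --------
    A usage sketch (the raster path is hypothetical):

    >>> areas = cell_area("dem.tif", src_crs=2154, engine='equalearth')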
"""
# ---- Initialization
data_ds = load_any(data)
if src_crs is not None:
data_ds.rio.write_crs(src_crs, inplace = True)
else:
if data_ds.rio.crs is None:
print("Err: The Coordinate Reference System is required. It should be embedded in the input DEM or passed with the 'src_crs' argument")
return
# engine
engine = engine.casefold().replace(' ', '').replace('-', '')
# Initialize spatial info
x_var, y_var = main_space_dims(data_ds)[0]
height = len(data_ds[y_var])
width = len(data_ds[x_var])
x_res, y_res = data_ds.rio.resolution()
x_min, _, _, y_max = data_ds.rio.bounds()
x_coords = np.arange(x_min, x_min + width*x_res, x_res, dtype = np.float32) + 0.5*x_res # aligned on centroids
y_coords = np.arange(y_max, y_max + height*y_res, y_res, dtype = np.float32) + 0.5*y_res
# Initialize a null xr.DataArray
cell_area = xr.DataArray(np.zeros((height, width)),
coords = [y_coords, x_coords],
dims = [y_var, x_var])
if engine == 'geographiclib':
from geographiclib.geodesic import Geodesic
geod = Geodesic.WGS84
for x in x_coords - 0.5*x_res: # aligned on bounds
# for y in np.arange(y_max, y_max + height*y_res, y_res):
for y in y_coords - 0.5*y_res:
if engine == 'shapely':
area = Polygon([(x, y),
(x + x_res, y),
(x + x_res, y + y_res),
(x, y + y_res),
(x, y)]).area
elif engine == 'equalearth':
# Equal Earth projection is a worldwide equal area projection
# defined in: Šavrič, B., Patterson, T., & Jenny, B. (2018). The Equal
# Earth map projection. International Journal of Geographical Information
# Science, 33(3), 454–465. https://doi.org/10.1080/13658816.2018.1504949
reprj_coords = rasterio.warp.transform(
data_ds.rio.crs,
8857, # pyproj.CRS("+proj=eqearth").to_authority()
# https://epsg.io/8857 ; https://epsg.io/1078-method
[x, x + x_res, x + x_res, x], [y, y, y + y_res, y + y_res])
area = Polygon(
[(reprj_coords[0][i], reprj_coords[1][i]) \
for i in range(0, len(reprj_coords[0]))]).area
elif engine == 'geographiclib':
p = geod.Polygon()
reprj_coords = rasterio.warp.transform(
data_ds.rio.crs,
4326,
[x, x + x_res, x + x_res, x], [y, y, y + y_res, y + y_res])
for i in range(0, len(reprj_coords[0])):
p.AddPoint(reprj_coords[0][i], reprj_coords[1][i])
_, _, area = p.Compute()
# Store area
            cell_area.loc[{x_var: x + 0.5*x_res, y_var: y + 0.5*y_res}] = area
return cell_area
# =============================================================================
# for another solution: https://gis.stackexchange.com/questions/413349/calculating-area-of-lat-lon-polygons-without-transformation-using-geopandas
# =============================================================================
# ====== prev version that does not work ======================================
# # ---- With pysheds
# if engine.casefold() in ["pyshed", "pysheds"]:
# print('Pysheds engine...')
#
# # Load the pysheds grid
# ds = load_any(data, decode_coords = 'all')
# if src_crs is not None:
# ds.rio.write_crs(src_crs, inplace = True)
# else:
# if ds.rio.crs is None:
# print("Err: The Coordinate Reference System is required. It should be embedded in the input DEM or passed with the 'src_crs' argument")
# return
# ds, nodata = standardize_fill_value(ds)
# var = main_vars(ds)[0]
# x_var, y_var = main_space_dims(ds)[0]
# encod = ds[var].encoding
# # =============================================================================
# # # NaN data are problematic when filling
# # # nan_mask = ds[var].isnull().data
# # nan_mask = xr.where(~ds[var].isnull(), True, False).data
# # ds[var] = ds[var].fillna(-9999)
# # =============================================================================
# ds[var].encoding = encod
#
# # ===== useless because pGrid only takes files as inputs ======================
# # viewfinder = ViewFinder(affine = ds.rio.transform(),
# # shape = ds.rio.shape,
# # crs = ds.rio.crs,
# # nodata = nodata)
# # raster = Raster(ds[var].data, viewfinder=viewfinder)
# # =============================================================================
#
# export(ds, r"temp_raster.tif")
# grid = pGrid.from_raster(r"temp_raster.tif", data_name = 'area')
# grid.cell_area()
# os.remove(r"temp_raster.tif")
# print(r" _ The temporary file 'temp_raster.tif' has been removed")
#
# # Output
# ds[var] = ([y_var, x_var], np.array(grid.area))
# ds = ds.rename({var: 'area'})
# ds['area'].encoding = encod
#
# print("\nWarning: This function does not work as expected yet: area are only computed from the resolution")
# return ds
# =============================================================================
###############################################################################
#%%% ° Convert LDD code
def switch_direction_map(input_file, input_mapping, output_mapping):
"""
    Switch a flow-direction raster between direction code conventions
    ('LDD', 'D8 ESRI', 'WTB', 'AGNPS'). The converted raster is written as a
    GeoTIFF next to the input file.
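
    Example (hypothetical file names):

    switch_direction_map("ldd.tif", "ldd", "d8")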
"""
#%%% Inputs
mapping_dict = {'input': input_mapping, 'output': output_mapping}
for m in mapping_dict:
if mapping_dict[m].casefold().replace("_", "").replace(" ", "") in ["ldd","localdraindirections"]:
mapping_dict[m] = "LDD"
elif mapping_dict[m].casefold().replace("_", "").replace(" ", "") in ["d8", "esri", "d8esri", "esrid8", "d8standard", "standardd8"]:
mapping_dict[m] = "D8 ESRI"
elif mapping_dict[m].casefold().replace("_", "").replace(" ", "") in ["wtb", "whitetoolbox", "d8whitetoolbox", "d8wtb", "wtbd8"]:
mapping_dict[m] = "WTB"
elif mapping_dict[m].casefold().replace("_", "").replace(" ", "") in ["agnps"]:
mapping_dict[m] = "AGNPS"
else:
return "Error: mapping unknown"
print(f"{m} direction: {mapping_dict[m]}")
#%%% Conversion
    # Load the data
data_in = rasterio.open(input_file, 'r')
data_profile = data_in.profile
val = data_in.read()
data_in.close()
# Conversion
    # columns of keys_: 0:'LDD', 1:'D8 ESRI', 2:'WTB', 3:'AGNPS'
col = ['LDD', 'D8 ESRI', 'WTB', 'AGNPS']
keys_ = np.array(
[[8, 64, 128, 1,],#N
[9, 128, 1, 2,], #NE
[6, 1, 2, 3,], #E
[3, 2, 4, 4,], #SE
[2, 4, 8, 5,], #S
         [1, 8, 16, 6,], #SW
         [4, 16, 32, 7,], #W
         [7, 32, 64, 8,], #NW
[5, 0, 0, None,]]) #-
for d in range(0, 9):
val[val == keys_[d,
col.index(mapping_dict['input'])
]
] = -keys_[d,
col.index(mapping_dict['output'])]
    val = -val # A negative intermediate value avoids cascading replacements
               # such as: 3 --> 2, then 2 --> 4
#%%% Export
output_file = os.path.splitext(input_file)[0] + "_{}.tif".format(mapping_dict['output'])
with rasterio.open(output_file, 'w', **data_profile) as output_f:
output_f.write(val)
print("\nFile created")
###############################################################################
#%%% ° Alter modflow_river_percentage
def river_pct(input_file, value):
"""
Creates artificial modflow_river_percentage inputs (in *.nc) to use for
drainage.
Parameters
----------
input_file : str
Original modflow_river_percentage.tif file to duplicate/modify
value : float
        Value to impose on cells (between 0 and 1, not a percentage!)
This value is added to original values as a fraction of the remaining
"non-river" fraction:
For example, value = 0.3 (30%):
- cells with 0 are filled with 0.3
- cells with 1 remain the same
            - cells with 0.8 take the value 0.86, because 30% of what should
              have been capillary rise becomes baseflow (0.8 + 0.3*(1-0.8))
- cells with 0.5 take the value 0.65 (0.5 + 0.3*(1-0.5))
Returns
-------
None.
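
    Example call (the input path is user-specific):

    river_pct("modflow_river_percentage.tif", 0.3)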
"""
#%% Loading
# ---------
if os.path.splitext(input_file)[-1] == '.tif':
with xr.open_dataset(input_file, # .tif
decode_times = True,
) as ds:
ds.load()
elif os.path.splitext(input_file)[-1] == '.nc':
with xr.open_dataset(input_file,
decode_times = True,
decode_coords = 'all',
) as ds:
ds.load()
#%% Computing
# -----------
# ds['band_data'] = ds['band_data']*0 + value
ds_ones = ds.copy(deep = True)
ds_ones['band_data'] = ds_ones['band_data']*0 + 1
#% modflow_river_percentage_{value}.nc:
ds1 = ds.copy(deep = True)
ds1['band_data'] = ds1['band_data'] + (ds_ones['band_data'] - ds1['band_data'])*value
#% drainage_river_percentage_{value}.nc :
ds2 = ds1 - ds
#%% Formatting
# ------------
ds1.rio.write_crs(2154, inplace = True)
ds1.x.attrs = {'standard_name': 'projection_x_coordinate',
'long_name': 'x coordinate of projection',
'units': 'Meter'}
ds1.y.attrs = {'standard_name': 'projection_y_coordinate',
'long_name': 'y coordinate of projection',
'units': 'Meter'}
# To avoid conflict when exporting to netcdf:
ds1.x.encoding['_FillValue'] = None
ds1.y.encoding['_FillValue'] = None
ds2.rio.write_crs(2154, inplace = True)
ds2.x.attrs = {'standard_name': 'projection_x_coordinate',
'long_name': 'x coordinate of projection',
'units': 'Meter'}
ds2.y.attrs = {'standard_name': 'projection_y_coordinate',
'long_name': 'y coordinate of projection',
'units': 'Meter'}
# To avoid conflict when exporting to netcdf:
ds2.x.encoding['_FillValue'] = None
ds2.y.encoding['_FillValue'] = None
#%% Exporting
# -----------
(folder, file) = os.path.split(input_file)
(file, extension) = os.path.splitext(file)
output_file1 = os.path.join(folder, "_".join([file, str(value)]) + '.nc')
ds1.to_netcdf(output_file1)
output_file2 = os.path.join(folder, "_".join(['drainage_river_percentage', str(value)]) + '.nc')
ds2.to_netcdf(output_file2)
#%% QUANTITIES OPERATIONS
###############################################################################
# Compute ETref and EWref from ERA5-Land pan evaporation
def compute_Erefs_from_Epan(input_file):
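    """
    Derive the reference grass evapotranspiration (ETref, i.e. ET0) and the
    reference open-water evaporation (EWref) from pan evaporation, by applying
    pan coefficients of respectively 0.675 and 0.75 (values used here; they may
    need to be adapted to the pan type and local climate).
    """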
print("\nDeriving standard grass evapotranspiration and standard water evapotranspiration from pan evaporation...")
Epan = load_any(input_file, decode_coords = 'all', decode_times = True)
    var = main_vars(Epan)[0]
print(" _ Computing ETref (ET0) from Epan...")
ETref = Epan.copy()
ETref[var] = ETref[var]*0.675
print(" _ Computing EWref from Epan...")
EWref = Epan.copy()
EWref[var] = EWref[var]*0.75
print(" _ Transferring encodings...")
ETref[var].encoding = Epan[var].encoding
EWref[var].encoding = Epan[var].encoding
# Case of packing
if ('scale_factor' in Epan[var].encoding) | ('add_offset' in Epan[var].encoding):
# Packing (lossy compression) induces a loss of precision of
# apprx. 1/1000 of unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data
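        # CF packing convention: unpacked = packed*scale_factor + add_offset,
        # so scaling the data by a constant requires scaling both encoding
        # terms by the same constant for the packed integers to stay valid.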
ETref[var].encoding['scale_factor'] = ETref[var].encoding['scale_factor']*0.675
ETref[var].encoding['add_offset'] = ETref[var].encoding['add_offset']*0.675
EWref[var].encoding['scale_factor'] = EWref[var].encoding['scale_factor']*0.75
EWref[var].encoding['add_offset'] = EWref[var].encoding['add_offset']*0.75
return ETref, EWref
###############################################################################
def compute_wind_speed(u_wind_data, v_wind_data):
"""
    Compute the 10 m wind speed from its U- and V-components (this function
    assumes the ERA5-Land variable names 'u10' and 'v10').
    The U-component of wind is parallel to the x-axis (eastward);
    the V-component of wind is parallel to the y-axis (northward).
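
    For example, with u10 = 3 m/s and v10 = 4 m/s, the resulting wind speed
    is sqrt(3**2 + 4**2) = 5 m/s.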
"""
# =============================================================================
# print("\nIdentifying files...")
# U_motif = re.compile('U-component')
# U_match = U_motif.findall(input_file)
# V_motif = re.compile('V-component')
# V_match = V_motif.findall(input_file)
#
# if len(U_match) > 0:
# U_input_file = '%s' % input_file # to copy the string
# V_input_file = '%s' % input_file
# V_input_file = V_input_file[:U_motif.search(input_file).span()[0]] + 'V' + V_input_file[U_motif.search(input_file).span()[0]+1:]
# elif len(V_match) > 0:
# V_input_file = '%s' % input_file # to copy the string
# U_input_file = '%s' % input_file
# U_input_file = U_input_file[:V_motif.search(input_file).span()[0]] + 'U' + U_input_file[V_motif.search(input_file).span()[0]+1:]
# =============================================================================
print("\nComputing wind speed from U- and V-components...")
U_ds = load_any(u_wind_data, decode_coords = 'all', decode_times = True)
V_ds = load_any(v_wind_data, decode_coords = 'all', decode_times = True)
wind_speed_ds = U_ds.copy()
wind_speed_ds = wind_speed_ds.rename(u10 = 'wind_speed')
wind_speed_ds['wind_speed'] = np.sqrt(U_ds.u10*U_ds.u10 + V_ds.v10*V_ds.v10)
# nan remain nan
print(" _ Transferring encodings...")
wind_speed_ds['wind_speed'].encoding = V_ds.v10.encoding
wind_speed_ds['wind_speed'].attrs['long_name'] = '10 metre wind speed'
# Case of packing
if ('scale_factor' in V_ds.v10.encoding) | ('add_offset' in V_ds.v10.encoding):
# Packing (lossy compression) induces a loss of precision of
# apprx. 1/1000 of unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data
# Theoretical max wind speed:
max_speed = 56 # m/s = 201.6 km/h
(scale_factor, add_offset) = compute_scale_and_offset(-max_speed, max_speed, 16)
# Out: (0.0017090104524299992, 0.0008545052262149966)
wind_speed_ds['wind_speed'].encoding['scale_factor'] = scale_factor
wind_speed_ds['wind_speed'].encoding['add_offset'] = add_offset
# wind_speed_ds['wind_speed'].encoding['FillValue_'] = -32767
# To remain the same as originally
# Corresponds to -55.99829098954757 m/s
return wind_speed_ds
###############################################################################
def compute_relative_humidity(*, dewpoint_input_file,
temperature_input_file,
pressure_input_file,
method = "Penman-Monteith"):
"""
cf formula on https://en.wikipedia.org/wiki/Dew_point
gc.compute_relative_humidity(
dewpoint_input_file = r"D:\2- Postdoc\2- Travaux\1- Veille\4- Donnees\8- Meteo\ERA5\Brittany\2011-2021 Dewpoint temperature.nc",
temperature_input_file = r"D:\2- Postdoc\2- Travaux\1- Veille\4- Donnees\8- Meteo\ERA5\Brittany\2011-2021 Temperature.nc",
pressure_input_file = r"D:\2- Postdoc\2- Travaux\1- Veille\4- Donnees\8- Meteo\ERA5\Brittany\2011-2021 Surface pressure.nc",
method = "Sonntag")
"""
# ---- Loading data
# --------------
print("\nLoading data...")
# Checking that the time period matches:
    years_motif = re.compile(r'\d{4}-\d{4}')
years_dewpoint = years_motif.search(dewpoint_input_file).group()
years_pressure = years_motif.search(pressure_input_file).group()
years_temperature = years_motif.search(temperature_input_file).group()
if (years_dewpoint == years_pressure) and (years_dewpoint == years_temperature):
print(" Years are matching: {}".format(years_dewpoint))
else:
print(" /!\ Years are not matching: {}\n{}\n{}".format(years_dewpoint, years_pressure, years_temperature))
# return 0
with xr.open_dataset(dewpoint_input_file, decode_coords = 'all') as Dp:
Dp.load() # to unlock the resource
with xr.open_dataset(temperature_input_file, decode_coords = 'all') as T:
T.load() # to unlock the resource
with xr.open_dataset(pressure_input_file, decode_coords = 'all') as Pa:
Pa.load() # to unlock the resource
# ---- Sonntag formula
# -----------------
if method.casefold() in ['sonntag', 'sonntag1990']:
print("\nComputing the relative humidity, using the Sonntag 1990 formula...")
# NB : air pressure Pa is not used in this formula
# Constants:
alpha_ = 6.112 # [hPa]
beta_ = 17.62 # [-]
lambda_ = 243.12 # [°C]
# Temperature in degrees Celsius:
Tc = T.copy()
Tc['t2m'] = T['t2m'] - 273.15
Dpc = Dp.copy()
Dpc['d2m'] = Dp['d2m'] - 273.15
# Saturation vapour pressure [hPa]:
Esat = Tc.copy()
Esat = Esat.rename(t2m = 'vpsat')
Esat['vpsat'] = alpha_ * np.exp((beta_ * Tc['t2m']) / (lambda_ + Tc['t2m']))
# Vapour pressure [hPa]:
E = Dp.copy()
E = E.rename(d2m = 'vp')
E['vp'] = alpha_ * np.exp((Dpc['d2m'] * beta_) / (lambda_ + Dpc['d2m']))
# Relative humidity [%]:
RHS = Dp.copy()
RHS = RHS.rename(d2m = 'rh')
RHS['rh'] = E['vp']/Esat['vpsat']*100
elif method.casefold() in ['penman', 'monteith', 'penman-monteith']:
print("\nComputing the relative humidity, using the Penman Monteith formula...")
# NB : air pressure Pa is not used in this formula
# Used in evaporationPot.py
# http://www.fao.org/docrep/X0490E/x0490e07.htm equation 11/12
# Constants:
alpha_ = 0.6108 # [kPa]
beta_ = 17.27 # [-]
lambda_ = 237.3 # [°C]
# Temperature in degrees Celsius:
Tc = T.copy()
Tc['t2m'] = T['t2m'] - 273.15
Dpc = Dp.copy()
Dpc['d2m'] = Dp['d2m'] - 273.15
# Saturation vapour pressure [kPa]:
Esat = Tc.copy()
Esat = Esat.rename(t2m = 'vpsat')
Esat['vpsat'] = alpha_ * np.exp((beta_ * Tc['t2m']) / (lambda_ + Tc['t2m']))
# Vapour pressure [kPa]:
E = Dp.copy()
E = E.rename(d2m = 'vp')
E['vp'] = alpha_ * np.exp((beta_ * Dpc['d2m']) / (lambda_ + Dpc['d2m']))
# Relative humidity [%]:
# https://www.fao.org/3/X0490E/x0490e07.htm Eq. (10)
RHS = Dp.copy()
RHS = RHS.rename(d2m = 'rh')
RHS['rh'] = E['vp']/Esat['vpsat']*100
#% Attributes
print("\nTransferring encodings...")
RHS['rh'].attrs['units'] = '%'
RHS['rh'].attrs['long_name'] = 'Relative humidity (from 2m dewpoint temperature)'
RHS['rh'].encoding = Dp['d2m'].encoding
# Case of packing
if ('scale_factor' in Dp['d2m'].encoding) | ('add_offset' in Dp['d2m'].encoding):
# Packing (lossy compression) induces a loss of precision of
# apprx. 1/1000 of unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data.
# RHS['rh'].encoding['scale_factor'] = 0.0016784924086366065
# RHS['rh'].encoding['add_offset'] = 55.00083924620432
# RHS['rh'].encoding['_FillValue'] = 32767
# RHS['rh'].encoding['missing_value'] = 32767
(scale_factor, add_offset) = compute_scale_and_offset(-1, 100, 16)
# Out: (0.0015411612115663385, 49.50077058060578)
RHS['rh'].encoding['scale_factor'] = scale_factor
RHS['rh'].encoding['add_offset'] = add_offset
# RHS['rh'].encoding['_FillValue'] = -32767
# To match with original value
# Corresponds to -0.9984588387884301 %
return RHS
###############################################################################
# Convert radiation data (J/m2/h) to W/m2
def convert_downwards_radiation(input_file, is_dailysum = False):
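    """
    Convert downward radiation accumulated over one hour (J/m2, as in hourly
    ERA5-Land data) into a mean flux in W/m2 by dividing by 3600 s.
    With ``is_dailysum = True`` the accumulation period is one day (86400 s).
    """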
print("\nConverting radiation units...")
rad = load_any(input_file, decode_coords = 'all', decode_times = True)
    var = main_vars(rad)[0]
print(" _ Field is: {}".format(var))
print(" _ Computing...")
rad_W = rad.copy()
if not is_dailysum:
conv_factor = 3600 # because 3600s in 1h
else:
conv_factor = 86400 # because 86400s in 1d
rad_W[var] = rad_W[var]/conv_factor
print(" _ Transferring encodings...")
rad_W[var].attrs['units'] = 'W m**-2'
rad_W[var].encoding = rad[var].encoding
# Case of packing
if ('scale_factor' in rad_W[var].encoding) | ('add_offset' in rad_W[var].encoding):
# Packing (lossy compression) induces a loss of precision of
# apprx. 1/1000 of unit, for a quantity with an interval of 150
# units. The packing is initially used in some original ERA5-Land data.
rad_W[var].encoding['scale_factor'] = rad_W[var].encoding['scale_factor']/conv_factor
rad_W[var].encoding['add_offset'] = rad_W[var].encoding['add_offset']/conv_factor
# NB:
# rad_W[var].encoding['FillValue_'] = -32767
# To remain the same as originally
    # Corresponds to -472.11... W m**-2
    # NB: For a few specific times, data are unavailable. Such data are coded
# with the value -1, packed into -32766
return rad_W
#%% * OBSOLETE ? Shift rasters (GeoTIFF or NetCDF)
###############################################################################
# Not fully finished. Code taken from 'datatransform.py'
def transform_tif(*, input_file, x_shift = 0, y_shift = 0, x_size = 1, y_size = 1):
"""
EXAMPLE:
import datatransform as dt
dt.transform_tif(input_file = r"D:\2- Postdoc\2- Travaux\3_CWatM_EBR\data\input_1km_LeMeu\areamaps\mask_cwatm_LeMeu_1km.tif",
x_shift = 200,
y_shift = 300)
"""
    # Open the file:
data = rasterio.open(input_file, 'r')
    # Get the profile:
_prof_base = data.profile
trans_base = _prof_base['transform']
    # Just for display:
    print('\nThe initial affine profile is:')
print(trans_base)
    # Modify the profile:
trans_modf = Affine(trans_base[0]*x_size, trans_base[1], trans_base[2] + x_shift,
trans_base[3], trans_base[4]*y_size, trans_base[5] + y_shift)
    print('\nThe modified profile is:')
print(trans_modf)
_prof_modf = _prof_base
_prof_modf.update(transform = trans_modf)
    # Export:
_basename = os.path.splitext(input_file)[0]
add_name = ''
if x_shift != 0 or y_shift !=0:
add_name = '_'.join([add_name, 'shift'])
if x_shift != 0:
add_name = '_'.join([add_name, 'x' + str(x_shift)])
if y_shift != 0:
add_name = '_'.join([add_name, 'y' + str(y_shift)])
if x_size != 1 or y_size !=1:
add_name = '_'.join([add_name, 'size'])
if x_size != 1:
add_name = '_'.join([add_name, 'x' + str(x_size)])
if y_size != 1:
add_name = '_'.join([add_name, 'y' + str(y_size)])
output_file = '_'.join([_basename, add_name]) + '.tif'
with rasterio.open(output_file, 'w', **_prof_modf) as out_f:
out_f.write_band(1, data.read()[0])
data.close()
def transform_nc(*, input_file, x_shift = 0, y_shift = 0, x_size = 1, y_size = 1):
"""
EXAMPLE:
import datatransform as dt
dt.transform_nc(input_file = r"D:\2- Postdoc\2- Travaux\3_CWatM_EBR\data\input_1km_LeMeu\landsurface\topo\demmin.nc",
x_shift = 200,
y_shift = 400)
"""
with xr.open_dataset(input_file) as data:
data.load() # to unlock the resource
    # Modify:
data['x'] = data.x + x_shift
data['y'] = data.y + y_shift
    # Export:
_basename = os.path.splitext(input_file)[0]
add_name = ''
if x_shift != 0 or y_shift !=0:
add_name = '_'.join([add_name, 'shift'])
if x_shift != 0:
add_name = '_'.join([add_name, 'x' + str(x_shift)])
if y_shift != 0:
add_name = '_'.join([add_name, 'y' + str(y_shift)])
if x_size != 1 or y_size !=1:
add_name = '_'.join([add_name, 'size'])
if x_size != 1:
add_name = '_'.join([add_name, 'x' + str(x_size)])
if y_size != 1:
add_name = '_'.join([add_name, 'y' + str(y_size)])
output_file = '_'.join([_basename, add_name]) + '.nc'
data.to_netcdf(output_file)
#%% * tools for computing coordinates
###############################################################################
def convert_coord(pointXin, pointYin, inputEPSG = 2154, outputEPSG = 4326):
"""
    There is an issue in this function: X and Y end up swapped.
    It is better to use the rasterio functions directly (see above):
coords_conv = rasterio.warp.transform(rasterio.crs.CRS.from_epsg(inputEPSG),
rasterio.crs.CRS.from_epsg(outputEPSG),
[pointXin], [pointYin])
pointXout = coords_conv[0][0]
pointYout = coords_conv[1][0]
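
    Example call (Rennes coordinates, from Lambert-93 to WGS84):

    convert_coord(350556.92318, 6791719.72296, inputEPSG=2154, outputEPSG=4326)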
"""
#% Inputs (standards)
# =============================================================================
# # Projected coordinates in Lambert-93
# pointXin = 350556.92318 #Easthing
# pointYin = 6791719.72296 #Northing
# (Rennes coordinates)
# =============================================================================
# =============================================================================
# # Geographical coordinates in WGS84 (2D)
# pointXin = 48.13222 #Latitude (Northing)
# pointYin = -1.7 #Longitude (Easting)
# (Rennes coordinates)
# =============================================================================
# =============================================================================
# # Spatial Reference Systems
# inputEPSG = 2154 #Lambert-93
# outputEPSG = 4326 #WGS84 (2D)
# =============================================================================
    # Conversion into EPSG codes
    # For convenience, inputEPSG and outputEPSG can also be given as identifier strings
switchEPSG = {
'L93': 2154, #Lambert-93
'L-93': 2154, #Lambert-93
'WGS84': 4326, #WGS84 (2D)
'GPS': 4326, #WGS84 (2D)
'LAEA': 3035, #LAEA Europe
}
if isinstance(inputEPSG, str):
inputEPSG = switchEPSG.get(inputEPSG, False)
# If the string is not a valid identifier:
if not inputEPSG:
print('Unknown input coordinates system')
return
if isinstance(outputEPSG, str):
outputEPSG = switchEPSG.get(outputEPSG, False)
# If the string is not a valid identifier:
if not outputEPSG:
print('Unknown output coordinates system')
return
#% Outputs
# =============================================================================
# # Méthode osr
# # create a geometry from coordinates
# point = ogr.Geometry(ogr.wkbPoint)
# point.AddPoint(pointXin, pointYin)
#
# # create coordinate transformation
# inSpatialRef = osr.SpatialReference()
# inSpatialRef.ImportFromEPSG(inputEPSG)
#
# outSpatialRef = osr.SpatialReference()
# outSpatialRef.ImportFromEPSG(outputEPSG)
#
# coordTransform = osr.CoordinateTransformation(inSpatialRef, outSpatialRef)
#
# # transform point
# point.Transform(coordTransform)
# pointXout = point.GetX()
# pointYout = point.GetY()
# =============================================================================
    # rasterio method
coords_conv = rasterio.warp.transform(rasterio.crs.CRS.from_epsg(inputEPSG),
rasterio.crs.CRS.from_epsg(outputEPSG),
[pointXin], [pointYin])
pointXout = coords_conv[0][0]
pointYout = coords_conv[1][0]
# Return point coordinates in output format
return(pointXout, pointYout)
#%% date tools for QGIS
"""
Pour faire facilement la conversion "numéro de bande - date" dans QGIS lorsqu'on
ouvre les fichers NetCDF comme rasters.
/!\ Dans QGIS, le numéro de 'band' est différent du 'time'
(parfois 'band' = 'time' + 1, parfois il y a une grande différence)
C'est le 'time' qui compte.
"""
###############################################################################
def date_to_index(_start_date, _date, _freq):
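    """
    Convert a date into its 0-based time index along a regular time axis.
    Example (illustrative values, daily frequency):
    date_to_index('2000-01-01', '2000-01-10', 'D') returns 9.
    """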
time_index = len(pd.date_range(start = _start_date, end = _date, freq = _freq))-1
    print('Date {} corresponds to time index {}'.format(_date, str(time_index)))
return time_index
###############################################################################
def index_to_date(_start_date, _time_index, _freq):
date_index = pd.date_range(start = _start_date, periods = _time_index+1, freq = _freq)[-1]
    print('Time index {} corresponds to date {}'.format(_time_index, str(date_index)))
return date_index
def get_geop4th_version() -> str:
"""Get the current version of geop4th package."""
try:
return importlib.metadata.version('geop4th')
except importlib.metadata.PackageNotFoundError:
return 'unknown'
def scan_files(
paths: Union[str, Path, List[Union[str, Path]]],
*,
variables_to_find: Optional[Union[str, List[str]]] = None,
bbox: Optional[Tuple[float, float, float, float]] = None,
mask: Optional[Union[str, Path, xr.Dataset, gpd.GeoDataFrame]] = None,
start_date: Optional[Union[str, pd.Timestamp]] = None,
end_date: Optional[Union[str, pd.Timestamp]] = None,
frequency: Optional[str] = None,
file_extension: str = "*.nc"
) -> Dict[str, List[Dict[str, Any]]]:
"""
Find and analyze existing files across multiple paths, extracting metadata
including spatial bounds, temporal range, and variables.
Parameters
----------
paths : Union[str, Path, List[Union[str, Path]]]
Directory path(s) to scan. Can be a single path or list of paths.
variables_to_find : Optional[Union[str, List[str]]]
Variable name(s) to search for. Can be a single variable or list of variables.
If None, returns all files.
bbox : Optional[Tuple[float, float, float, float]]
Target bounding box (North, West, South, East)
mask : Optional[Union[str, Path, xr.Dataset, gpd.GeoDataFrame]]
Mask to extract bbox from. Can be path or dataset. Overrides bbox if provided.
start_date : Optional[Union[str, pd.Timestamp]]
Start date for filtering
end_date : Optional[Union[str, pd.Timestamp]]
End date for filtering
frequency : Optional[str]
Target frequency for filtering
file_extension : str, default "*.nc"
File pattern to search. Accepts formats like "*.nc", ".nc", or "nc".
Returns
-------
Dict[str, List[Dict[str, Any]]]
Dict mapping variable names to lists of file metadata dicts.
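
    Examples
    --------
    A usage sketch (the directory and variable names are hypothetical):

    >>> found = scan_files("/data/era5", variables_to_find=["t2m", "tp"],
    ...                    start_date="2020-01-01", end_date="2020-12-31")
    >>> for var, files in found.items():
    ...     print(var, len(files))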
"""
# Convert single values to lists
if isinstance(paths, (str, Path)):
paths = [paths]
if isinstance(variables_to_find, str):
variables_to_find = [variables_to_find]
# Extract bbox from mask if provided
if mask is not None:
mask_ds = load_any(mask) if isinstance(mask, (str, Path)) else mask
if isinstance(mask_ds, xr.Dataset):
bbox_extracted = mask_ds.rio.bounds()
# Convert from (left, bottom, right, top) to (North, West, South, East)
bbox = (bbox_extracted[3], bbox_extracted[0], bbox_extracted[1], bbox_extracted[2])
elif isinstance(mask_ds, gpd.GeoDataFrame):
bbox_extracted = mask_ds.total_bounds
# total_bounds returns (minx, miny, maxx, maxy) = (West, South, East, North)
bbox = (bbox_extracted[3], bbox_extracted[0], bbox_extracted[1], bbox_extracted[2])
logger.info(f"Extracted bbox from mask: {bbox}")
# Normalize file_extension format
if file_extension and not file_extension.startswith('*'):
if not file_extension.startswith('.'):
file_extension = '.' + file_extension
file_extension = '*' + file_extension
logger.info(f"Scanning {len(paths)} paths for files with pattern: {file_extension}")
# Convert dates to pandas timestamps
if start_date:
start_date = pd.to_datetime(start_date)
if end_date:
end_date = pd.to_datetime(end_date)
# Convert to set for faster lookups
variables_set = set(variables_to_find) if variables_to_find else None
# Results and tracking
files_by_variable = {}
processed_files = set() # Avoid processing duplicates
for path in paths:
path_obj = Path(path)
if not path_obj.exists():
logger.warning(f"Path does not exist: {path}")
continue
logger.debug(f"Scanning path: {path_obj}")
# Handle both files and directories
        if path_obj.is_file() and path_obj.match(file_extension):
files_to_process = [path_obj]
else:
files_to_process = list(path_obj.rglob(file_extension))
logger.debug(f"Found {len(files_to_process)} files in {path_obj}")
for file_path in files_to_process:
# Skip already processed files
file_path_resolved = file_path.resolve()
if file_path_resolved in processed_files:
continue
processed_files.add(file_path_resolved)
            ds = None
            try:
# Load metadata efficiently
if str(file_path).endswith('.nc'):
ds = xr.open_dataset(file_path, decode_times=True, decode_coords='all')
else:
# Fallback for non-NetCDF files
ds = load_any(file_path, decode_times=True, decode_coords='all')
if ds is None:
logger.warning(f"Could not load file: {file_path}")
continue
# Extract spatial bounds using rio.bounds()
file_bbox = None
if isinstance(ds, xr.Dataset) and hasattr(ds, 'rio'):
try:
bounds = ds.rio.bounds()
# rio.bounds() returns (left, bottom, right, top)
file_bbox = (
bounds[3], # North (top)
bounds[0], # West (left)
bounds[1], # South (bottom)
bounds[2] # East (right)
)
except Exception as e:
logger.debug(f"Could not extract bbox from {file_path.name}: {e}")
temporal_range = None
file_frequency = None
time_coord = main_time_dims(ds)
if time_coord and isinstance(time_coord, list):
time_coord = time_coord[0]
if time_coord and time_coord in ds.coords:
try:
time_values = ds[time_coord].values
if len(time_values) > 0:
file_start = pd.to_datetime(time_values[0])
file_end = pd.to_datetime(time_values[-1])
temporal_range = (file_start, file_end)
if len(time_values) > 1:
time_diff = pd.Timedelta(time_values[1] - time_values[0])
if time_diff <= pd.Timedelta(hours=6):
file_frequency = 'hourly'
elif time_diff <= pd.Timedelta(hours=30):
file_frequency = 'daily'
elif time_diff >= pd.Timedelta(days=25):
file_frequency = 'monthly'
else:
file_frequency = 'unknown'
except Exception as e:
logger.debug(f"Could not extract temporal info from {file_path.name}: {e}")
# Apply filters
skip_file = False
# Spatial filter
if bbox and file_bbox:
# Check spatial overlap
target_north, target_west, target_south, target_east = bbox
file_north, file_west, file_south, file_east = file_bbox
if (file_south > target_north or file_north < target_south or
file_east < target_west or file_west > target_east):
skip_file = True
# Temporal filter
if not skip_file and start_date and end_date and temporal_range:
file_start, file_end = temporal_range
# Check for temporal overlap
if file_end < start_date or file_start > end_date:
skip_file = True
# Frequency filter (skip invariant files)
if not skip_file and frequency and file_frequency and file_frequency != frequency:
if temporal_range is not None: # Has time dimension
skip_file = True
if skip_file:
logger.debug(f"Skipping {file_path.name} due to filters")
continue
# Extract variables from file
all_file_variables = main_vars(ds)
# Filter variables if specific ones are requested
if variables_set:
matching_variables = [var for var in all_file_variables if var in variables_set]
if not matching_variables:
logger.debug(f"Skipping {file_path.name} - no matching variables")
continue
target_variables = matching_variables
else:
target_variables = all_file_variables
# Create file metadata
file_metadata = {
'file_path': file_path,
'bbox': file_bbox,
'temporal_range': temporal_range,
'frequency': file_frequency,
'variables': all_file_variables,
'has_time_dimension': temporal_range is not None
}
# Organize files by variable
for var in target_variables:
if var not in files_by_variable:
files_by_variable[var] = []
# Create variable-specific metadata
var_specific_metadata = file_metadata.copy()
var_specific_metadata['primary_variable'] = var
files_by_variable[var].append(var_specific_metadata)
except Exception as e:
logger.warning(f"Error processing file {file_path}: {e}")
continue
finally:
# Ensure dataset is closed to avoid resource leaks
if ds is not None and hasattr(ds, 'close'):
try:
ds.close()
                    except Exception:
pass
# Sort files by temporal range for each variable
for var_name in files_by_variable:
files_by_variable[var_name].sort(
key=lambda x: x['temporal_range'][0] if x['temporal_range'] else pd.Timestamp.min
)
logger.info(f"Found files for {len(files_by_variable)} variables: {list(files_by_variable.keys())}")
for var, files in files_by_variable.items():
logger.debug(f" {var}: {len(files)} files")
return files_by_variable
# Backward compatibility alias
find_existing_files = scan_files
#%% EXAMPLES
###############################################################################
def example(data):
"""
Quick way to load GEOP4TH examples.
Parameters
----------
    data : {'mask', 'points', 'raster', 'netcdf', 'faulty'}
        Type of example data to load.
Returns
-------
ds : xr.Dataset, gpd.GeoDataFrame, pd.DataFrame
Data loaded as a GEOP4TH variable.
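
    Examples
    --------
    >>> dem = example('dem')     # bundled example DEM (GeoTIFF)
    >>> mask = example('mask')   # bundled example polygon mask (GeoPackage)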
"""
root = Path(os.path.dirname(__file__)) / 'examples'
if data.casefold() in ['mask', 'polygon', 'polygons', 'vector', 'geopackage', 'gpkg', '.gpkg', 'shapefile', 'shp', '.shp']:
ds = load_any(root / "mask.gpkg")
    elif data.casefold() in ['point', 'points', 'bnpe', 'json', '.json']:
ds = load_any(root / "BNPE.json")
elif data.casefold() in ['raster', 'tif', 'dem', '.tif']:
ds = load_any(root / "BDALTI_1000m.tif")
    elif data.casefold() in ['netcdf', 'nc', '.nc']:
ds = load_any(root / "SIM2_EVAP_2024-2025.nc")
elif data.casefold() in ['faulty', 'defective']:
ds = load_any(root / "faulty.nc")
else:
print(f"Err: keyword '{data}' not recognized. Possible options are:")
print(" . 'mask', 'polygon', 'polygons', 'vector', 'geopackage', 'gpkg', '.gpkg', 'shapefile', 'shp', '.shp'")
print(" . 'point', 'points', 'bnpe', 'json', '.json'")
print(" . 'raster', 'tif', 'dem', '.tif'")
print(" . 'netcdf', 'nc', '.nc'")
print(" . 'faulty', 'defective'")
return
return ds
#%% main
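# Command-line usage sketch (module name and values are illustrative; the two
# arguments are X and Y coordinates in the default input CRS, EPSG:2154):
#   python <this_module>.py 350556.9 6791719.7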
if __name__ == "__main__":
# Format the inputs (assumed to be strings) into floats
sys.argv[1] = float(sys.argv[1])
sys.argv[2] = float(sys.argv[2])
# Print some remarks
print('Arguments:')
print(sys.argv[1:])
    # Execute the convert_coord function
(a,b) = convert_coord(*sys.argv[1:])
print(a,b)