# Source code for standardize_era5

# -*- coding: utf-8 -*-
"""
ERA5-Land data standardization module.

Created on Wed Jun 03 2025
@author: Alexandre Kenshilik Coche & Bastien Boivin
@contact: alexandre.co@hotmail.fr | bastien.boivin@proton.me
"""

import pathlib
from typing import Union, Dict, Optional

import numpy as np
import pandas as pd
import xarray as xr

from geop4th import (
    geobricks as geo,
    utils,
    )
from geop4th.datasets import (
    ERA5_LAND,
    get_temporal_aggregation_modes,
    get_accumulation_variables,
    get_bias_correction_factors,
    get_radiation_variables,
    get_sign_correction_variables,
    get_variable_info,
    get_short_to_long_name_mapping
)
from geop4th.workflows.standardize._common import (
    BaseStandardizer,
    detect_frequency,
    clean_encoding_conflicts,
    polish_cf_compliance,
)


# Configuration constants
# Every value below is derived from the ERA5_LAND dataset definition when
# this module is imported.

# Per-variable aggregation settings; read further down as
# TEMPORAL_AGGREGATION_MODES[var]['mode'].
TEMPORAL_AGGREGATION_MODES = get_temporal_aggregation_modes(ERA5_LAND)
# Secondary-variable configurations keyed by variable name; the standardizer
# reads their 'requires', 'function', 'description' and 'units' entries.
SECONDARY_VARIABLES = ERA5_LAND['secondary_variables']
# Variables that are de-accumulated before hourly data is resampled
# (only used for membership tests in this file).
ACCUMULATION_VARS = get_accumulation_variables(ERA5_LAND)
# Sets used for membership tests when applying unit conversions.
SIGN_CORRECTION_VARS = set(get_sign_correction_variables(ERA5_LAND))
RADIATION_VARS = set(get_radiation_variables(ERA5_LAND))
# NOTE(review): the region key 'ERA5-brittany-daily' is hard-coded here,
# independently of the `bias_region` constructor argument -- confirm this is
# intended.
BIAS_CORRECTION_FACTORS = get_bias_correction_factors(ERA5_LAND, 'ERA5-brittany-daily')
PROGRESSIVE_CORRECTION_FACTORS = get_bias_correction_factors(ERA5_LAND, 'ERA5-brittany-daily', progressive=True)

# Helper functions
def available():
    """Delegate to ``utils.available`` for this module.

    The internal helpers listed in ``hidden_names`` are passed through the
    ``ignore`` argument.
    """
    hidden_names = {
        '_create_standardizer',
        '_validate_era5_batch',
        '_validate_era5_input',
    }
    utils.available(__name__, ignore=hidden_names)

# detect_frequency is imported from _common.py


def _deaccumulate_hourly_cumulative(da: xr.DataArray) -> xr.DataArray:
    """Turn hourly cumulative-since-00-UTC values into per-hour increments.

    ERA5-Land accumulation fields (precipitation, radiation, evaporation,
    runoff, snowfall, ...) hold the running total since 00 UTC of the
    current day. The amount for the period [t-1h, t] is therefore the
    difference along ``time``, except for the step that crosses the daily
    00 UTC reset, where the raw post-reset value already is the increment.

    The reset step is identified from the time-of-day label (hour == 1)
    and not from the sign of the difference, because some fields
    (potential_evaporation, total_evaporation) accumulate as negative
    quantities.

    Time labels of the result are moved back by one hour so that every
    increment is labeled at the START of its 1-hour period, which lines
    the increments up with the calendar-day bins of daily resampling.

    The result has one time step fewer than the input, since the first
    time step has no predecessor to be compared with.
    """
    # diff() labels each increment at the END of its period.
    increments = da.diff('time')
    # Raw accumulations carrying exactly the same time labels as `increments`.
    raw_at_period_end = da.isel(time=slice(1, None))

    # The increment labeled 01:00 spans 00:00 -> 01:00, i.e. it straddles
    # the daily reset: use the post-reset accumulation instead of the diff.
    crosses_reset = increments['time'].dt.hour == 1
    increments = xr.where(crosses_reset, raw_at_period_end, increments)

    # Relabel every increment at the start of its period.
    start_labels = increments['time'] - pd.Timedelta(hours=1)
    increments = increments.assign_coords(time=start_labels)

    increments.attrs.update(da.attrs)
    return increments


def _validate_era5_input(
    data_input: Union[str, pathlib.Path, xr.Dataset],
) -> Dict[str, Union[bool, str, None]]:
    """
    Check if input contains ERA5-Land variables and report processing status.

    Parameters
    ----------
    data_input : Union[str, pathlib.Path, xr.Dataset]
        An in-memory dataset, or a path. A directory path is accepted
        without inspection (its files are expected to be validated one by
        one by the caller). Any other path is opened with
        ``xr.open_dataset`` and closed again before returning.

    Returns
    -------
    dict
        ``is_era5_land`` (bool): at least one data variable matches a known
        ERA5-Land short or long name.
        ``is_processed`` (bool): the ``standardization_applied`` global
        attribute equals ``'True'``.
        ``error`` (str or None): message of any exception raised while
        inspecting the input; both flags are False in that case.

    Notes
    -----
    This function never raises: every failure is reported through the
    ``error`` entry.
    """
    try:
        if isinstance(data_input, xr.Dataset):
            # Caller-owned dataset: must be left open.
            ds, owns_ds = data_input, False
        else:
            path = pathlib.Path(data_input)
            if path.is_dir():
                return {'is_era5_land': True, 'is_processed': False, 'error': None}
            ds, owns_ds = xr.open_dataset(path), True

        try:
            short_to_long = get_short_to_long_name_mapping(ERA5_LAND)
            all_era5_vars = set(short_to_long.keys()) | set(short_to_long.values())
            is_era5_land = bool(set(ds.data_vars.keys()) & all_era5_vars)
            is_processed = ds.attrs.get('standardization_applied') == 'True'
        finally:
            # Release the file handle even when the inspection above fails,
            # otherwise a failing file stays open for the rest of the run.
            if owns_ds:
                ds.close()

        return {'is_era5_land': is_era5_land, 'is_processed': is_processed, 'error': None}

    except Exception as e:
        # Deliberately broad: validation is a best-effort probe and the
        # callers branch on the 'error' entry instead of catching exceptions.
        return {'is_era5_land': False, 'is_processed': False, 'error': str(e)}

# Main standardizer class

class ERA5StandardizerWL(BaseStandardizer, name="era5_land"):
    """
    Standardize ERA5-Land reanalysis data for hydrological modeling applications.
    
    This class provides comprehensive standardization of ERA5-Land data including
    temporal aggregation, unit corrections, secondary variable computation, and
    bias correction for improved accuracy in hydrological modeling.

    Parameters
    ---------- 
    data : Union[str, pathlib.Path, xr.Dataset]
        Path to NetCDF file, directory containing NetCDF files, or xarray Dataset.
        If directory path is provided, will process each ERA5-Land file individually.
    target_frequency : str, default 'auto'
        Target temporal frequency ('auto', 'hourly', 'daily', 'monthly').
        'auto' will automatically detect the optimal frequency based on input data.
    apply_bias_correction : bool, default False
        Whether to apply bias correction factors.
    bias_region : Optional[str], default None
        Region identifier for predefined bias factors.
    custom_bias_factors : Optional[Dict[str, float]], default None
        Custom bias correction factors.
    progressive_bias : bool, default False
        Use progressive (monthly) bias instead of fixed bias.
    custom_progressive_factors : Optional[Dict[str, np.ndarray]], default None
        Custom monthly factors per variable (arrays of length 12).
    output_path : Union[bool, str, pathlib.Path, None], default None
        Output directory control:
        - False: No files saved, return datasets in memory only
        - None or True: Save to current directory or input file directory 
        - str/Path: Save to specified directory
    output_prefix : str, default 'ERA5Land'
        Prefix for output filenames. Final names: {prefix}_{variable}_{frequency}.nc
    force_reprocess : bool
        Force reprocessing even if already standardized
    compute_secondary : bool, list of str, or None
        Control secondary variable computation:
        - None or False: No secondary variables computed
        - True: Compute all possible secondary variables 
        - List[str]: Compute only specified secondary variables
        Available: ['wind_speed', 'relative_humidity', 'ET0', 'EW0']
    """
    
    def __init__(
        self,
        data: Union[str, pathlib.Path, xr.Dataset],
        target_frequency: str = 'auto',
        apply_bias_correction: bool = False,
        bias_region: str = None,
        custom_bias_factors: Optional[Dict[str, float]] = None,
        progressive_bias: bool = False,
        custom_progressive_factors: Optional[Dict[str, np.ndarray]] = None,
        output_path: Union[bool, str, pathlib.Path, None] = None,
        output_prefix: str = 'ERA5Land',
        force_reprocess: bool = False,
        compute_secondary: Optional[Union[bool, list]] = None,
        target_crs: Union[int, str, None] = 4326,
    ):
        self.data = data
        self.target_frequency = target_frequency
        self.apply_bias_correction = apply_bias_correction
        self.bias_region = bias_region
        self.custom_bias_factors = custom_bias_factors or {}
        self.progressive_bias = progressive_bias
        self.custom_progressive_factors = custom_progressive_factors or {}
        self.force_reprocess = force_reprocess
        self.compute_secondary = compute_secondary
        self.output_prefix = output_prefix
        self.target_crs = target_crs
        
        # Track actual bias correction applied
        self.bias_correction_applied = False
        self.applied_bias_factors = {}
        
        # Handle output_path
        if output_path is False:
            self.output_dir = None
            self._want_save = False  # Explicitly don't save
        elif isinstance(output_path, (str, pathlib.Path)):
            self.output_dir = pathlib.Path(output_path)
            self._want_save = True
        elif output_path is None:
            self.output_dir = None  # Will use source directory when saving
            self._want_save = True
        else:
            self.output_dir = "auto"
            self._want_save = True

        # Disable progressive bias if no bias correction is applied
        if self.apply_bias_correction is False:
            self.progressive_bias = False
            self.bias_region = None

    def _scan_directory_for_era5_files(self, directory_path: pathlib.Path) -> list:
        """
        Scan directory for ERA5-Land NetCDF files with validation and user confirmation.

        Every ``*.nc`` file directly inside ``directory_path`` is validated
        once with ``_validate_era5_input`` and sorted into one of three
        groups: files to process, files already standardized, and files
        that are unusable (validation error or no ERA5-Land variable).

        Already standardized files are skipped unless
        ``self.force_reprocess`` is set, in which case they are returned
        with the other files to process.

        When more than 50 files are to be processed, the user is asked for
        confirmation on standard input. If no answer can be read
        (non-interactive session), the scan is treated as cancelled.

        Parameters
        ----------
        directory_path : pathlib.Path
            Path to directory to scan

        Returns
        -------
        list
            Paths of the files to process, in sorted order. Empty when no
            NetCDF file is found or when the user declines.
        """
        print(f"Scanning directory for ERA5-Land files: {directory_path}")

        # Sorted so that the processing order does not depend on the
        # arbitrary order in which glob() yields directory entries.
        nc_files = sorted(directory_path.glob("*.nc"))

        if not nc_files:
            print(f"No NetCDF files found in {directory_path}")
            return []

        print(f"Found {len(nc_files)} NetCDF files, validating ERA5-Land format...")

        # Check each file and get status in one go
        valid_files = []
        already_processed = []
        forced_reprocess = []
        invalid_files = []

        for file_path in nc_files:
            validation_result = _validate_era5_input(file_path)

            if validation_result['error']:
                print(f"Validation error for {file_path.name}: {validation_result['error']}")
                invalid_files.append(file_path)
            elif not validation_result['is_era5_land']:
                invalid_files.append(file_path)
            elif validation_result['is_processed'] and not self.force_reprocess:
                already_processed.append(file_path)
            else:
                # Stamped files only reach this branch when force_reprocess
                # is set: honor the option instead of silently skipping them.
                if validation_result['is_processed']:
                    forced_reprocess.append(file_path)
                valid_files.append(file_path)

        # Report results
        print("Validation complete:")
        print(f"  Files with ERA5-Land variables: {len(valid_files)}")
        print(f"  Already standardized: {len(already_processed)}")
        print(f"  Files without ERA5-Land variables: {len(invalid_files)}")

        if forced_reprocess:
            print(f"Found {len(forced_reprocess)} already standardized files (will reprocess, force_reprocess=True)")

        if already_processed:
            print(f"Found {len(already_processed)} already standardized files (will skip):")
            for f in already_processed[:5]:  # Show first 5
                print(f"  - {f.name}")
            if len(already_processed) > 5:
                print(f"  ... and {len(already_processed) - 5} more")

        if invalid_files:
            print(f"Found {len(invalid_files)} files without ERA5-Land variables (will skip):")
            for f in invalid_files[:3]:
                print(f"  - {f.name}")
            if len(invalid_files) > 3:
                print(f"  ... and {len(invalid_files) - 3} more")

        # Safety check for large number of files
        if len(valid_files) > 50:
            print(f"\nWarning: Found {len(valid_files)} ERA5-Land files to process.")
            print("This could take a significant amount of time and storage space.")
            try:
                response = input("Do you want to continue? (yes/no): ").lower().strip()
            except EOFError:
                # No stdin available (batch job, notebook kernel without
                # input support, ...): treat as a refusal rather than crash.
                response = ''
            if response not in ['yes', 'y', 'oui']:
                print("Processing cancelled by user")
                return []

        return valid_files

    def _process_directory_files(self, directory_path: pathlib.Path) -> Dict[str, pathlib.Path]:
        """
        Standardize, one by one, the ERA5-Land NetCDF files of a directory.

        The directory is first scanned with
        ``_scan_directory_for_era5_files``. Each retained file is loaded,
        handed to a dedicated standardizer built with the same options as
        this instance (but with saving disabled), and the result is then
        saved by this instance when saving is enabled.

        A failure on one file is printed and does not stop the others.

        Parameters
        ----------
        directory_path : pathlib.Path
            Path to directory containing NetCDF files.

        Returns
        -------
        Dict[str, pathlib.Path]
            Maps each original file name to the path of its output file.
            When saving is disabled (``self._want_save`` is False), the
            values are the standardized datasets themselves.
        """
        candidates = self._scan_directory_for_era5_files(directory_path)
        if not candidates:
            print(f"No files with ERA5-Land variables found in {directory_path}")
            return {}

        print(f"Processing {len(candidates)} files individually")

        outputs = {}
        for source_file in candidates:
            print(f"Processing file: {source_file.name}")

            try:
                loaded = geo.load(source_file)
                if loaded is None:
                    print(f"Could not load file: {source_file.name}")
                    continue

                # The child works on the in-memory dataset and never saves:
                # writing to disk is done below with this instance's own
                # output settings, which avoids duplicated files.
                child = _create_standardizer(
                    data_input=loaded,
                    target_frequency=self.target_frequency,
                    apply_bias_correction=self.apply_bias_correction,
                    bias_region=self.bias_region,
                    custom_bias_factors=self.custom_bias_factors,
                    progressive_bias=self.progressive_bias,
                    custom_progressive_factors=self.custom_progressive_factors,
                    output_path=False,
                    output_prefix=self.output_prefix,
                    force_reprocess=self.force_reprocess,
                    compute_secondary=self.compute_secondary,
                    target_crs=self.target_crs,
                )
                standardized = child.standardize()

                if not self._want_save:
                    outputs[source_file.name] = standardized
                    print(f"Processed: {source_file.name} (in memory)")
                    continue

                # File name label: the variable's (long) name when the
                # result holds a single variable, a generic label otherwise.
                data_names = list(standardized.data_vars)
                label = data_names[0] if len(data_names) == 1 else 'data'
                saved_to = self._save_dataset(standardized, source_file, variable_name=label)

                if saved_to:
                    outputs[source_file.name] = saved_to
                    print(f"Processed: {source_file.name} -> {saved_to.name}")
                else:
                    print(f"Failed to save: {source_file.name}")

            except Exception as exc:
                print(f"Error processing {source_file.name}: {exc}")
                continue

        return outputs
            
    def standardize(self) -> Union[xr.Dataset, Dict[str, pathlib.Path]]:
        """
        Main standardization workflow.

        Steps, in order: input validation; directory dispatch; loading;
        early exit for already standardized data; renaming of short
        variable names; time-axis standardization and temporal aggregation
        (time-varying data only); unit conversions; optional secondary
        variables; optional bias correction; georeferencing in EPSG:4326;
        reprojection to the requested CRS; CF and standardization metadata;
        encoding clean-up and final CF polish; optional saving.

        Side effects: ``self.target_frequency`` is overwritten with the
        detected frequency when it was ``'auto'``, and with ``'invariant'``
        when the data has no time dimension.

        Returns
        -------
        Union[xr.Dataset, Dict[str, pathlib.Path]]
            For single files/datasets: standardized dataset
            For directories: the dictionary returned by
            ``_process_directory_files``

        Raises
        ------
        ValueError
            If input is not ERA5-Land format or cannot be loaded
        """
        # Check if input contains ERA5-Land variables
        validation = _validate_era5_input(self.data)
        if not validation['is_era5_land']:
            # The 'error' entry is always present and is None when the input
            # was readable but simply holds no ERA5-Land variable, so a
            # .get() default would never apply: only append a real detail.
            message = "Input data does not contain recognized ERA5-Land variables."
            detail = validation.get('error')
            if detail:
                message = f"{message} {detail}"
            raise ValueError(message)

        # Process directory files individually
        if isinstance(self.data, (str, pathlib.Path)):
            input_path = pathlib.Path(self.data)
            if input_path.is_dir():
                print(f"Directory input detected: {input_path}")
                return self._process_directory_files(input_path)

        def save_if_requested(dataset, source_path):
            # Single-variable datasets are labeled with that variable's
            # name; otherwise an empty label is passed to _save_dataset.
            if not self._want_save:
                return
            names = list(dataset.data_vars)
            var_name = names[0] if len(names) == 1 else ''
            self._save_dataset(dataset, source_path, variable_name=var_name)

        # Process single file or dataset
        if isinstance(self.data, xr.Dataset):
            ds = self.data.copy()
            input_path = None
        else:
            input_path = pathlib.Path(self.data)

            # Reuse the validation we already did on disk: avoids reopening
            # the same file just to check the standardization stamp.
            if not self.force_reprocess and validation['is_processed']:
                print(f"File {input_path.name} is already standardized (use force_reprocess=True to override)")
                ds = geo.load(input_path)
                if ds is None:
                    raise ValueError(f"Could not load already processed file from {input_path}")
                save_if_requested(ds, input_path)
                return ds

            ds = geo.load(input_path)
            if ds is None:
                raise ValueError(f"Could not load data from {input_path}")

        print(f"Processing {len(ds.data_vars)} variables: {list(ds.data_vars)}")

        # Skip if already processed
        if not self.force_reprocess and self._is_already_processed(ds):
            print("Dataset already standardized, skipping processing (use force_reprocess=True to override)")
            save_if_requested(ds, input_path)
            return ds

        # Convert short variable names to long names
        ds = self._rename_variables_to_long_names(ds)

        # Determine if dataset has time dimension
        has_time_dimension = 'time' in ds.dims or any(dim_name in ds.dims for dim_name in ['valid_time', 't', 'time0'])

        if has_time_dimension:
            print("Time-varying dataset detected")

            ds = self._standardize_time_dimension(ds)

            current_freq = detect_frequency(ds)
            if self.target_frequency == 'auto':
                self.target_frequency = current_freq
                print(f"Auto-detected frequency: {self.target_frequency}")

            ds = self._apply_temporal_aggregation(ds, current_freq=current_freq)
        else:
            print("Invariant dataset detected (no time dimension)")

            # For invariant variables, always override target frequency
            if self.target_frequency != 'invariant':
                original_frequency = self.target_frequency
                self.target_frequency = 'invariant'
                print(f"Overriding target frequency from '{original_frequency}' to 'invariant' for time-independent data")

        ds = self._apply_unit_conversions(ds)

        if self.compute_secondary is not None and self.compute_secondary is not False:
            ds = self._compute_secondary_variables(ds)
        else:
            print("Secondary variable computation disabled")

        # Apply bias corrections
        if self.apply_bias_correction:
            ds = self._apply_bias_correction(ds)

        # Add spatial reference (ERA5-Land native CRS = WGS84)
        ds = geo.georef(ds, crs=4326)

        # Reproject to the user-requested target CRS (no-op when the
        # request matches WGS84 or is None).
        ds = self._reproject_to_target(ds)

        # Add CF metadata
        ds = self._add_cf_metadata(ds)

        ds = self._add_standardization_metadata(ds)

        # Clean encoding conflicts
        ds = clean_encoding_conflicts(ds)

        # Final CF-compliance polish (drop GRIB ensemble leftovers,
        # _FillValue on coordinates, cast int64 stored types to int32)
        ds = self._polish_cf_compliance(ds)

        save_if_requested(ds, input_path)

        return ds
    
    def _is_already_processed(self, ds: xr.Dataset) -> bool:
        """Check if dataset was already standardized."""
        is_standardized = ds.attrs.get('standardization_applied', 'False') == 'True'
        stored_freq = ds.attrs.get('target_frequency', None)
        # 'auto' matches whatever frequency the previous run resolved to.
        freq_matches = (
            self.target_frequency == 'auto'
            or stored_freq == self.target_frequency
        )
        bias_matches = ds.attrs.get('standardization_apply_bias_correction', 'False') == str(self.apply_bias_correction)

        return is_standardized and freq_matches and bias_matches

    def _is_file_already_processed(self, file_path: pathlib.Path) -> bool:
        """Return the ``is_processed`` flag reported by ``_validate_era5_input`` for ``file_path``."""
        status = _validate_era5_input(file_path)
        return status.get('is_processed', False)


    def _complete_variable_metadata(self, ds: xr.Dataset, var_name: str) -> xr.Dataset:
        """
        Fill in the ``units`` and ``long_name`` attributes of one variable.

        Attributes already present (and non-empty) on the variable are kept.
        Only the missing ones are taken from the ERA5_LAND variable
        database (``units`` and ``description`` entries). The variable's
        attributes are updated in place on ``ds``.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset containing the variable.
        var_name : str
            Variable name.

        Returns
        -------
        xr.Dataset
            The same dataset, returned unchanged when ``var_name`` is not
            one of its data variables.
        """
        if var_name not in ds.data_vars:
            return ds

        attrs = ds[var_name].attrs
        units_present = attrs.get('units')
        description_present = attrs.get('long_name')

        # Nothing to do when both pieces of metadata are already there.
        if units_present and description_present:
            print(f"   _ Using existing metadata for '{var_name}' (units: {attrs['units']})")
            return ds

        db_entry = get_variable_info(ERA5_LAND, var_name)
        if not db_entry:
            print(f"   _ Warning: No metadata found for '{var_name}'")
            return ds

        # Only the missing attributes are taken from the database.
        additions = {}
        if not units_present:
            additions['units'] = db_entry['units']
        if not description_present:
            additions['long_name'] = db_entry['description']

        ds[var_name].attrs.update(additions)
        print(f"   _ Added metadata for '{var_name}' from database: {list(additions.keys())}")
        return ds

    def _rename_variables_to_long_names(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Rename data variables from ERA5-Land short names to long names.

        Variables whose name is not a known short name are left untouched.
        After renaming, the metadata of each renamed variable is completed
        with ``_complete_variable_metadata``. The input dataset is returned
        as-is when there is nothing to rename.
        """
        mapping = get_short_to_long_name_mapping(ERA5_LAND)

        renames = {name: mapping[name] for name in ds.data_vars if name in mapping}
        for short, long_ in renames.items():
            print(f"Will rename '{short}' to '{long_}'")

        if not renames:
            print("No short names to rename found")
            return ds

        renamed = ds.rename(renames)
        print(f"Renamed {len(renames)} variables to long names: {list(renames.values())}")

        # Complete metadata for renamed variables
        for long_ in renames.values():
            renamed = self._complete_variable_metadata(renamed, long_)

        return renamed
        
    def _standardize_time_dimension(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Standardize the time coordinate and make sure the dimension is ``time``.

        The dataset first goes through ``geo.standardize_time_coord``. If
        the result has a ``time`` coordinate but no ``time`` dimension, the
        first dimension found among ``valid_time``, ``t`` and ``time0`` is
        renamed to ``time``.

        Raises
        ------
        ValueError
            If the dataset still has no ``time`` dimension afterwards.
        """
        result = geo.standardize_time_coord(ds)

        needs_dim_rename = 'time' in result.coords and 'time' not in result.dims
        if needs_dim_rename:
            legacy_dim = next(
                (name for name in ('valid_time', 't', 'time0') if name in result.dims),
                None,
            )
            if legacy_dim is not None:
                result = result.rename({legacy_dim: 'time'})

        if 'time' not in result.dims:
            raise ValueError("No time dimension found in dataset")

        print("Standardized time coordinates using geobricks")
        return result
        
    def _apply_temporal_aggregation(self, ds: xr.Dataset, current_freq: Optional[str] = None) -> xr.Dataset:
        """
        Resample every data variable from its current frequency to the target one.

        Each variable is aggregated with the mode registered for it in
        ``TEMPORAL_AGGREGATION_MODES`` (``sum``, ``min``, ``max``; anything
        else, and unregistered variables, use the mean). Only the
        ``'daily'`` and ``'monthly'`` targets are implemented; for any
        other target the dataset is returned unchanged.

        When the input is hourly, accumulation variables are first
        converted from cumulative-since-00-UTC values to per-hour
        increments, and the other variables are trimmed to the same last
        time step.

        Sums use ``min_count=1``: a cell/bin with no valid value stays NaN
        instead of becoming a spurious 0.0 (e.g. masked cells or empty
        bins).

        Parameters
        ----------
        ds : xr.Dataset
            Dataset with a ``time`` dimension.
        current_freq : Optional[str]
            Frequency of ``ds``; detected with ``detect_frequency`` when
            not given.

        Returns
        -------
        xr.Dataset
            Aggregated dataset carrying the global attributes of ``ds``, or
            ``ds`` itself when no aggregation is needed or possible.
        """
        if current_freq is None:
            current_freq = detect_frequency(ds)

        if current_freq == self.target_frequency:
            print("No temporal aggregation needed")
            return ds

        print(f"Aggregating from {current_freq} to {self.target_frequency}")

        # Set aggregation mode for each variable
        agg_dict = {}
        for var in ds.data_vars:
            if var in TEMPORAL_AGGREGATION_MODES:
                mode = TEMPORAL_AGGREGATION_MODES[var]['mode']
                agg_dict[var] = mode
                print(f"Using {mode} aggregation for '{var}'")
            else:
                agg_dict[var] = 'mean'

        if self.target_frequency == 'daily':
            resample_rule = 'D'
        elif self.target_frequency == 'monthly':
            resample_rule = 'MS'
        else:
            print(f"Aggregation to {self.target_frequency} not implemented")
            return ds

        # ERA5-Land hourly accumulation variables are stored as
        # cumulative-since-00-UTC, so naively summing 24 of them inflates
        # the daily total by ~10x. Convert them to per-hour increments
        # before resampling. Work on DataArrays directly to avoid the
        # xarray-Dataset alignment that would re-pad the deaccumulated
        # arrays back to the original (1h longer) time index with NaNs.
        inputs = {}
        for var in ds.data_vars:
            if current_freq == 'hourly' and var in ACCUMULATION_VARS:
                print(f"Deaccumulating cumulative hourly variable '{var}'")
                inputs[var] = _deaccumulate_hourly_cumulative(ds[var])
            else:
                inputs[var] = ds[var]

        # If any deaccumulation happened, those arrays are 1h shorter
        # than the rest. Trim mean/min/max variables to the same cutoff
        # so the final merged Dataset has a consistent time axis (and so
        # the boundary hour fetched by download_era5 doesn't leak into
        # the output as a spurious extra day).
        shorter = [a for a in inputs.values() if a.sizes['time'] < ds.sizes['time']]
        if shorter:
            reference = shorter[0]
            cutoff = reference['time'].max().values
            for var, da in inputs.items():
                if da.sizes['time'] > reference.sizes['time']:
                    inputs[var] = da.sel(time=slice(None, cutoff))

        # Apply resampling
        aggregated_vars = {}
        for var, mode in agg_dict.items():
            var_resampler = inputs[var].resample(time=resample_rule)
            if mode == 'sum':
                # min_count=1: with the default NaN-skipping sum, a bin
                # holding only NaNs would be reported as 0.0.
                aggregated_vars[var] = var_resampler.sum(keep_attrs=True, min_count=1)
            elif mode == 'min':
                aggregated_vars[var] = var_resampler.min(keep_attrs=True)
            elif mode == 'max':
                aggregated_vars[var] = var_resampler.max(keep_attrs=True)
            else:
                aggregated_vars[var] = var_resampler.mean(keep_attrs=True)

        # xr.merge keeps each DataArray's own coordinates (including
        # scalar ones such as height=2 for t2m), unlike rebuilding a
        # Dataset from a single variable's coords.
        if aggregated_vars:
            merged = xr.merge(
                [da.rename(name) for name, da in aggregated_vars.items()]
            )
            merged.attrs = dict(ds.attrs)
            return merged
        else:
            print("No variables to aggregate, returning original dataset")
            return ds
        
            
    def _apply_unit_conversions(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Convert radiation variables and flip the sign of sign-corrected ones.

        Works on a copy of ``ds``. Variables listed in ``RADIATION_VARS``
        go through ``geo.convert_downwards_radiation`` (told whether the
        target frequency is daily). Variables listed in
        ``SIGN_CORRECTION_VARS`` are negated, keeping their attributes; a
        ``(NEGATIVE)`` marker in ``long_name`` becomes ``(positive)``, and
        any ``add_offset`` / ``scale_factor`` encoding entry is negated as
        well. A variable in both sets is only treated as radiation.
        """
        converted = ds.copy()
        daily_sum = self.target_frequency == 'daily'

        for name in ds.data_vars:
            # Convert radiation units
            if name in RADIATION_VARS:
                print(f"Converting {name} from J/m² to W/m² using geobricks")
                single_var = xr.Dataset({name: converted[name]})
                converted[name] = geo.convert_downwards_radiation(single_var, is_dailysum=daily_sum)[name]
                print(f"Successfully converted {name} using geobricks")
                continue

            if name not in SIGN_CORRECTION_VARS:
                continue

            # Fix evaporation sign (ERA5 convention)
            print(f"Applying sign correction to {name}")
            saved_attrs = converted[name].attrs.copy()

            converted[name] = -converted[name]

            # Put the attributes back on the negated variable.
            converted[name].attrs.update(saved_attrs)
            if 'long_name' in saved_attrs:
                converted[name].attrs['long_name'] = saved_attrs['long_name'].replace('(NEGATIVE)', '(positive)')

            # Keep packing parameters consistent with the flipped sign.
            if hasattr(converted[name], 'encoding'):
                converted[name].encoding = {
                    key: (-value if key in ('add_offset', 'scale_factor') else value)
                    for key, value in converted[name].encoding.items()
                }

            print(f"Applied sign correction to {name}")

        return converted
        
    def _compute_secondary_variables(self, ds: xr.Dataset) -> xr.Dataset:
        """Derive the requested secondary variables and add them to a copy of ``ds``.

        ``self.compute_secondary`` selects what is computed:

        - ``True``: every entry of ``SECONDARY_VARIABLES``, in config order.
        - ``list``: only the listed names, in the caller's order; unknown
          names are reported and dropped.
        - anything else: nothing is computed.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset whose variables serve as inputs. It is not modified; the
            derived variables are attached to a copy.

        Returns
        -------
        xr.Dataset
            Copy of ``ds`` with each successfully computed variable added.
            A variable that already exists, lacks inputs, or fails to compute
            is reported with ``print`` and skipped.
        """
        ds_with_secondary = ds.copy()

        # Ordered, de-duplicated lists are used instead of sets so that the
        # computation order (and therefore the order in which variables are
        # inserted into the dataset) is the same on every run.
        if self.compute_secondary is True:
            requested_vars = list(SECONDARY_VARIABLES)
        elif isinstance(self.compute_secondary, list):
            requested_vars = list(dict.fromkeys(self.compute_secondary))
            invalid_vars = [v for v in requested_vars if v not in SECONDARY_VARIABLES]
            if invalid_vars:
                print(f"Unknown secondary variables requested: {invalid_vars}. Available: {list(SECONDARY_VARIABLES.keys())}")
                requested_vars = [v for v in requested_vars if v in SECONDARY_VARIABLES]
        else:
            return ds_with_secondary

        if not requested_vars:
            print("No valid secondary variables to compute")
            return ds_with_secondary

        print(f"Computing secondary variables: {requested_vars}")

        for var_name in requested_vars:
            if var_name in ds.data_vars:
                print(f"Variable {var_name} already exists, skipping")
                continue

            var_config = SECONDARY_VARIABLES[var_name]

            # Variables have already been renamed to long names earlier in
            # the pipeline, so we only need to look them up by long name.
            required_vars = [r for r in var_config['requires'] if r in ds.data_vars]
            missing_vars = [r for r in var_config['requires'] if r not in ds.data_vars]

            if missing_vars:
                print(f"Cannot compute {var_name}: missing {missing_vars} (available: {list(ds.data_vars)})")
                continue

            # Best-effort: one failing variable must not abort the others,
            # so any error is reported and the loop moves on.
            try:
                func_name = var_config['function']
                geo_function = getattr(geo, func_name)

                if var_name == 'wind_speed':
                    result_ds = geo_function(ds[required_vars[0]], ds[required_vars[1]])
                    result = result_ds['wind_speed']
                elif var_name == 'relative_humidity':
                    result_ds = geo_function(
                        temperature_input_file=ds[required_vars[0]],
                        dewpoint_input_file=ds[required_vars[1]],
                        pressure_input_file=ds[required_vars[2]],
                    )
                    result = result_ds['relative_humidity']
                elif var_name in ('ET0', 'EW0'):
                    et0_ds, ew0_ds = geo.compute_Erefs_from_Epan(ds[required_vars[0]])
                    chosen = et0_ds if var_name == 'ET0' else ew0_ds
                    # The helper returns a Dataset whose variable name
                    # mirrors the input (e.g. 'potential_evaporation');
                    # extract the DataArray by taking the first non-CRS var.
                    pick = next(v for v in chosen.data_vars if v != 'spatial_ref')
                    result = chosen[pick]
                else:
                    result = geo_function(ds[required_vars[0]])
                    if isinstance(result, xr.Dataset):
                        pick = next(v for v in result.data_vars if v != 'spatial_ref')
                        result = result[pick]

                # Override metadata so the secondary var carries its own
                # description/units from the dataset config rather than
                # the input variable's.
                result.attrs.update({
                    'long_name': var_config['description'],
                    'units': var_config['units'],
                })

                # Drop scale/offset since the value range is different.
                if hasattr(result, 'encoding'):
                    result.encoding = {
                        k: v for k, v in result.encoding.items()
                        if k not in ('scale_factor', 'add_offset')
                    }

                ds_with_secondary[var_name] = result
                print(f"Successfully computed {var_name} from {required_vars} using geo.{func_name}")

            except Exception as e:
                print(f"Failed to compute {var_name}: {e}")

        return ds_with_secondary
        
    def _apply_bias_correction(self, ds: xr.Dataset) -> xr.Dataset:
        """Apply progressive (monthly) or fixed bias correction to a copy of ``ds``.

        For each data variable, a user-supplied custom factor takes precedence
        over the regional factor looked up for ``self.bias_region``. When at
        least one factor is found, ``geo.correct_bias`` is called and
        ``self.bias_correction_applied`` / ``self.applied_bias_factors`` are
        updated; otherwise the copy is returned uncorrected.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset to correct. It is copied, not modified in place.

        Returns
        -------
        xr.Dataset
            The corrected dataset, or an unmodified copy when no factor applies.
        """
        ds_corrected = ds.copy()

        # Region keys embed the frequency they were calibrated against
        # ('ERA5-brittany-daily', ...). Emit a warning when the explicit
        # target frequency disagrees with that suffix.
        if self.bias_region and self.target_frequency not in (None, 'auto'):
            mismatched_tag = next(
                (
                    tag for tag in ('hourly', 'daily', 'monthly')
                    if self.bias_region.endswith(f'-{tag}') and tag != self.target_frequency
                ),
                None,
            )
            if mismatched_tag is not None:
                print(
                    f"Warning: bias region '{self.bias_region}' was calibrated for "
                    f"'{mismatched_tag}' data but target frequency is '{self.target_frequency}'. "
                    "Factors are dimensionally tied to a frequency; results may be wrong."
                )

        if self.progressive_bias:
            print("Checking for progressive (monthly) bias correction factors")
            if self.bias_region:
                region_table = get_bias_correction_factors(ERA5_LAND, self.bias_region, progressive=True)
            else:
                region_table = {}

            # Custom factors win over regional ones, variable by variable.
            selected = {}
            for name in ds.data_vars:
                if name in self.custom_progressive_factors:
                    selected[name] = self.custom_progressive_factors[name]
                    print(f"Using custom progressive factors for '{name}'")
                elif name in region_table:
                    selected[name] = region_table[name]
                    print(f"Found progressive factors for '{name}' in region '{self.bias_region}'")

            if not selected:
                print(f"No progressive bias factors found for region '{self.bias_region}' - no correction applied")
                return ds_corrected

            print(f"Applying progressive bias correction to {len(selected)} variables")
            ds_corrected = geo.correct_bias(
                ds_corrected,
                variables=list(selected.keys()),
                progressive=True,
                progressive_factors=selected
            )
            self.bias_correction_applied = True
            self.applied_bias_factors = selected
            return ds_corrected

        print("Checking for fixed bias correction factors")
        if self.bias_region:
            region_table = get_bias_correction_factors(ERA5_LAND, self.bias_region)
        else:
            region_table = {}

        # Custom factors win over regional ones, variable by variable.
        selected = {}
        for name in ds.data_vars:
            if name in self.custom_bias_factors:
                selected[name] = self.custom_bias_factors[name]
                print(f"Using custom fixed factor for '{name}': {self.custom_bias_factors[name]}")
            elif name in region_table:
                selected[name] = region_table[name]
                print(f"Found fixed factor for '{name}' in region '{self.bias_region}': {region_table[name]}")

        if not selected:
            print(f"No bias factors found for region '{self.bias_region}' - no correction applied")
            return ds_corrected

        print(f"Applying fixed bias correction to {len(selected)} variables")
        ds_corrected = geo.correct_bias(
            ds_corrected,
            variables=selected,
            progressive=False
        )
        self.bias_correction_applied = True
        self.applied_bias_factors = selected
        return ds_corrected
    
    def _add_cf_metadata(self, ds: xr.Dataset) -> xr.Dataset:
        """Attach CF-style global, per-variable and time-axis metadata to ``ds``.

        Global attributes (title, institution, source, history, ...) are
        written into ``ds.attrs``; provenance attributes already present on
        the dataset are kept. Each data variable then has its metadata
        completed, its ``standard_name`` placeholder resolved, Fortran-style
        unit exponents simplified and, for non-hourly outputs, a
        ``cell_methods`` entry added.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset to annotate. Its attributes are updated in place.

        Returns
        -------
        xr.Dataset
            The annotated dataset (the object returned by the last call to
            ``self._complete_variable_metadata``, or ``ds`` itself when it has
            no data variables).
        """
        import datetime

        # Global attributes
        if self.target_frequency == 'invariant':
            # For invariant variables, adapt the title and comment
            title_text = 'Standardized ERA5-Land reanalysis data - time-independent variables'
            comment_text = f'Time-independent data processed, bias correction: {self.bias_correction_applied}'
        else:
            title_text = f'Standardized ERA5-Land reanalysis data - {self.target_frequency} frequency'
            comment_text = f'Data processed with temporal aggregation: {self.target_frequency}, bias correction: {self.bias_correction_applied}'

        # 'history' is an audit trail: put the new entry in front of whatever
        # the input already carried instead of discarding it.
        history_text = f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: Standardized using geop4th ERA5StandardizerWL'
        previous_history = ds.attrs.get('history')
        if isinstance(previous_history, str) and previous_history.strip():
            history_text = f'{history_text}\n{previous_history}'

        global_attrs = {
            # We declare both CF-1.7 (the latest version recognized by
            # the IOOS compliance-checker) and CF-1.8 (which adds nothing
            # we use but is the most recent version explicitly supported
            # by mainstream tooling). The CF spec accepts a space- or
            # comma-separated list.
            'Conventions': 'CF-1.7 CF-1.8',

            'title': title_text,
            'institution': 'Processed using geop4th - Original data: European Centre for Medium-Range Weather Forecasts (ECMWF)',
            'source': 'ERA5-Land atmospheric reanalysis - standardized for hydrological modeling',
            'history': history_text,
            'references': 'Muñoz Sabater, J., et al. (2021): ERA5-Land: a state-of-the-art global reanalysis dataset for land applications. Earth Syst. Sci. Data, 13, 4349–4383. https://doi.org/10.5194/essd-13-4349-2021',
            'comment': comment_text,

            'geop4th_version': geo.get_geop4th_version(),
            'processing_level': 'L3',  # Aggregated and standardized data
        }

        # Preserve provenance metadata from original dataset
        provenance_attrs = ['data_source', 'download_date', 'clip_date', 'computation_date', 'data_provenance']
        for attr in provenance_attrs:
            if attr in ds.attrs:
                global_attrs[attr] = ds.attrs[attr]

        ds.attrs.update(global_attrs)

        # Variable metadata. Snapshot the names because ``ds`` is rebound
        # inside the loop.
        for var_name in list(ds.data_vars):
            # Complete missing metadata first, then fetch the variable from
            # the dataset that call returned: if the helper hands back a new
            # object, a reference taken earlier would be stale and the
            # attribute edits below would be lost.
            ds = self._complete_variable_metadata(ds, var_name)
            var = ds[var_name]

            # Add standard_name and cell_methods from database. The raw
            # ERA5-Land NetCDF returned by CDS often carries a literal
            # 'unknown' as standard_name (a leftover from the GRIB->NC
            # conversion), which is not a CF standard name. Replace it
            # with the value from the dataset config when available, or
            # drop it so the file remains CF-valid.
            var_info = get_variable_info(ERA5_LAND, var_name)
            current_sn = str(var.attrs.get('standard_name', '')).strip().lower()
            placeholder = current_sn in ('', 'unknown')
            if var_info and 'standard_name' in var_info and placeholder:
                var.attrs['standard_name'] = var_info['standard_name']
            elif placeholder and 'standard_name' in var.attrs:
                del var.attrs['standard_name']

            # CDS uses Fortran-style exponents in the unit strings (e.g.
            # 'm s**-1', 'J m**-2'). UDUNITS-2 accepts both, but CF
            # convention idiomatically prefers the bare form ('m s-1').
            u = var.attrs.get('units')
            if isinstance(u, str) and '**' in u:
                var.attrs['units'] = u.replace('**', '')

            if var_info and self.target_frequency != 'hourly':
                if var_name in TEMPORAL_AGGREGATION_MODES:
                    mode = TEMPORAL_AGGREGATION_MODES[var_name]['mode']
                    var.attrs['cell_methods'] = f'time: {mode}'

        # Time coordinate attributes: declare the axis unless already set.
        if 'time' in ds.coords:
            ds['time'].attrs.setdefault('axis', 'T')

        return ds
        
    def _add_standardization_metadata(self, ds: xr.Dataset) -> xr.Dataset:
        """Record how the dataset was standardized in its global attributes.

        Always writes the standardization flag, timestamp, target frequency
        and bias-correction switches. When a correction was actually applied,
        the region (if any), the correction type and one ``bias_factors_<var>``
        entry per corrected variable are added; otherwise the type is set to
        ``'none'`` and a requested-but-unused region is noted.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset whose ``attrs`` are updated in place.

        Returns
        -------
        xr.Dataset
            The same dataset object.
        """
        processing_attrs = {
            'standardization_applied': 'True',
            'standardization_date': pd.Timestamp.now().isoformat(),
            'target_frequency': self.target_frequency,
            'standardization_apply_bias_correction': str(self.bias_correction_applied),
            'standardization_progressive_bias': str(self.progressive_bias),
            'standardization_secondary_variables': 'automatic',
        }

        if not self.bias_correction_applied:
            # Make the absence of a correction explicit in the file.
            processing_attrs['bias_correction_type'] = 'none'
            if self.bias_region is not None:
                processing_attrs['standardization_bias_region_requested'] = f'{self.bias_region} (no factors found)'
        else:
            if self.bias_region is not None:
                processing_attrs['standardization_bias_region'] = self.bias_region

            if self.progressive_bias:
                processing_attrs['bias_correction_type'] = 'progressive_monthly'
                for name, monthly in self.applied_bias_factors.items():
                    # Iterable factors (but not strings) are rendered as a
                    # bracketed list with six decimals; anything else via str().
                    is_sequence = hasattr(monthly, '__iter__') and not isinstance(monthly, str)
                    if is_sequence:
                        rendered = '[' + ', '.join(f'{f:.6f}' for f in monthly) + ']'
                    else:
                        rendered = str(monthly)
                    processing_attrs[f'bias_factors_{name}'] = rendered
            else:
                processing_attrs['bias_correction_type'] = 'fixed_multiplicative'
                for name, fixed in self.applied_bias_factors.items():
                    processing_attrs[f'bias_factors_{name}'] = str(fixed)

        ds.attrs.update(processing_attrs)
        return ds

    def _reproject_to_target(self, ds: xr.Dataset) -> xr.Dataset:
        """Reproject to ``self.target_crs``; no-op when CRS already matches.

        The dataset is returned untouched when no target CRS is configured,
        when the dataset has no CRS, or when its EPSG code equals the target.
        String targets such as ``'EPSG:4326'`` or ``'4326'`` are reduced to
        their integer code before the comparison, so they hit the no-op path
        just like the integer ``4326``.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset exposing the ``rio`` accessor.

        Returns
        -------
        xr.Dataset
            ``ds`` itself, or the result of ``geo.reproject``.
        """
        if self.target_crs is None or ds.rio.crs is None:
            return ds
        try:
            current_epsg = ds.rio.crs.to_epsg()
        except Exception:
            current_epsg = None

        # ``to_epsg()`` yields an int, so a string target would never compare
        # equal to it. Normalise 'EPSG:<code>' / '<code>' strings to an int;
        # strings in any other form (WKT, PROJ, ...) cannot be matched here
        # and fall through to the reprojection.
        target_epsg = self.target_crs
        if isinstance(target_epsg, str):
            code = target_epsg.strip()
            if code.upper().startswith('EPSG:'):
                code = code[len('EPSG:'):].strip()
            target_epsg = int(code) if code.isdecimal() else None

        if current_epsg is not None and current_epsg == target_epsg:
            return ds
        return geo.reproject(ds, dst_crs=self.target_crs)

    def _polish_cf_compliance(self, ds: xr.Dataset) -> xr.Dataset:
        """Remove the ERA5 ``number`` leftovers, then run the shared CF polish.

        The ``number`` coordinate (ensemble member, added by ECMWF's
        GRIB->NetCDF conversion) is dropped when present, and the token
        ``number`` is removed from the ``coordinates`` attribute of every data
        variable that lists it. An attribute left empty by that removal is
        deleted. The cleaned dataset is then passed to the module-level
        ``polish_cf_compliance`` helper.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset to clean.

        Returns
        -------
        xr.Dataset
            Whatever ``polish_cf_compliance`` returns for the cleaned dataset.
        """
        if 'number' in ds.coords:
            ds = ds.drop_vars('number')

        for name in ds.data_vars:
            listed = ds[name].attrs.get('coordinates', '')
            if not isinstance(listed, str):
                continue
            tokens = listed.split()
            if 'number' not in tokens:
                continue

            remaining = [token for token in tokens if token != 'number']
            if remaining:
                ds[name].attrs['coordinates'] = ' '.join(remaining)
            else:
                # Nothing else was referenced: remove the attribute entirely.
                del ds[name].attrs['coordinates']

        return polish_cf_compliance(ds)

    def _extract_variable_name(self, ds: xr.Dataset, input_path: Optional[pathlib.Path] = None, variable_name: str = '') -> str:
        """Return the variable name to embed in the output filename.

        A non-empty ``variable_name`` is returned as is. Otherwise the first
        data variable of ``ds`` is used, falling back to ``'data'`` when the
        dataset has none. ``input_path`` is accepted but not used.
        """
        if not variable_name:
            # First key of data_vars, or the generic fallback when empty.
            return next(iter(ds.data_vars), 'data')
        return variable_name
        
    def _save_dataset(self, ds: xr.Dataset, input_path: Optional[pathlib.Path] = None, variable_name: str = '') -> pathlib.Path:
        """Write ``ds`` to a NetCDF file whose name encodes its content.

        The destination directory depends on ``self.output_dir``:

        - ``None``: the input file's directory, or the current working
          directory when no ``input_path`` is given;
        - ``"auto"``: an ``auto`` sub-directory of that same location;
        - anything else: used as the directory path.

        The filename is ``<prefix>_<variable>_<STD|DER>[_<frequency>]`` plus a
        date suffix and ``.nc``. ``DER`` marks variables listed in
        ``SECONDARY_VARIABLES``; the frequency is omitted for invariant data.
        The date suffix is ``_YYYYMMDD`` for a single day, or the start and
        end dates concatenated without a separator (``_YYYYMMDDYYYYMMDD``).

        Parameters
        ----------
        ds : xr.Dataset
            Dataset to export.
        input_path : pathlib.Path, optional
            Path of the source file, used to locate the output directory.
        variable_name : str
            Explicit variable name for the filename; derived from ``ds`` when
            empty.

        Returns
        -------
        pathlib.Path
            Path of the written file.
        """
        # Resolve the destination directory.
        if self.output_dir is None or self.output_dir == "auto":
            base_dir = input_path.parent if input_path else pathlib.Path.cwd()
            output_dir = base_dir if self.output_dir is None else base_dir / "auto"
        else:
            output_dir = pathlib.Path(self.output_dir)

        var_name = self._extract_variable_name(ds, input_path, variable_name)

        # Date suffix taken from the first and last time stamps, if any.
        date_range = ""
        has_time_steps = 'time' in ds.dims and ds.sizes['time'] > 0
        if has_time_steps:
            try:
                stamps = ds['time'].values
                first_day = pd.to_datetime(stamps[0]).strftime('%Y%m%d')
                last_day = pd.to_datetime(stamps[-1]).strftime('%Y%m%d')
                date_range = f"_{first_day}" if first_day == last_day else f"_{first_day}{last_day}"
                print(f"Extracted date range: {date_range}")
            except Exception as e:
                # Best-effort: an unreadable time axis only costs the suffix.
                print(f"Could not extract date range: {e}")
                date_range = ""

        # Assemble the filename; None components are left out.
        processing_level = "DER" if var_name in SECONDARY_VARIABLES else "STD"
        name_fields = [self.output_prefix, var_name, processing_level]
        if self.target_frequency != 'invariant':
            # Invariant variables carry no frequency in their filename.
            name_fields.append(self.target_frequency)
        filename = '_'.join(field for field in name_fields if field is not None) + date_range + '.nc'

        output_file_path = output_dir / filename
        output_file_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"Saving to: {output_file_path}")
        geo.export(ds, str(output_file_path))

        return output_file_path

# Helper functions

def _create_standardizer(data_input, **kwargs):
    """Build an ``ERA5StandardizerWL`` for ``data_input``.

    Keyword arguments whose value is ``None`` are not forwarded, so the
    standardizer's own defaults apply for them.
    """
    forwarded = {}
    for option, value in kwargs.items():
        if value is not None:
            forwarded[option] = value
    return ERA5StandardizerWL(data=data_input, **forwarded)

def _validate_era5_batch(file_dict: Dict[str, Union[str, pathlib.Path]]) -> list:
    """Check all files in batch contain ERA5-Land variables."""
    invalid_files = []

    for var_name, file_path in file_dict.items():
        file_path_obj = pathlib.Path(file_path)
        if not file_path_obj.exists():
            invalid_files.append(f"{var_name}: File not found - {file_path}")
            continue

        if file_path_obj.is_file():
            validation_result = _validate_era5_input(file_path_obj)
            if not validation_result.get('is_era5_land', False):
                error_msg = validation_result.get('error', 'No ERA5-Land variables found')
                invalid_files.append(f"{var_name}: {error_msg} - {file_path}")

    return invalid_files

# Public API functions
def standardize_era5_land(
    data: Union[str, pathlib.Path, xr.Dataset, Dict[str, Union[str, pathlib.Path]]],
    target_frequency: str = 'auto',
    apply_bias_correction: bool = False,
    bias_region: Optional[str] = None,
    custom_bias_factors: Optional[Dict[str, float]] = None,
    progressive_bias: bool = False,
    custom_progressive_factors: Optional[Dict[str, np.ndarray]] = None,
    output_path: Union[bool, str, pathlib.Path, None] = None,
    output_prefix: str = 'ERA5Land',
    force_reprocess: bool = False,
    compute_secondary: Optional[Union[bool, list]] = None,
    target_crs: Union[int, str, None] = 4326,
) -> Union[xr.Dataset, Dict[str, pathlib.Path]]:
    """Standardize ERA5-Land data.

    Processes files individually without merging.

    Parameters
    ----------
    data : str, pathlib.Path, xr.Dataset, or dict
        Input data:

        - File path: process single file
        - Directory path: process all ERA5-Land files individually
        - Dataset: process in memory
        - Dict: process each file path individually
    target_frequency : str
        Target temporal frequency ('auto', 'hourly', 'daily', 'monthly')
    apply_bias_correction : bool
        Whether to apply bias correction
    bias_region : str
        Region for bias correction factors
    custom_bias_factors : dict, optional
        Custom bias factors {var_name: factor}
    progressive_bias : bool
        Use progressive (monthly) bias instead of fixed bias
    custom_progressive_factors : dict, optional
        Custom monthly factors {var_name: array[12]}
    output_path : bool, str, pathlib.Path, or None
        Output directory control:

        - False: No files saved, return datasets in memory only
        - None or True: Save to current directory or input file directory
        - str/Path: Save to specified directory
    output_prefix : str, default 'ERA5Land'
        Prefix for output filenames
    force_reprocess : bool
        Force reprocessing even if already standardized
    compute_secondary : Optional[Union[bool, list]]
        Control secondary variable computation:

        - None or False: No secondary variables computed
        - True: Compute all possible secondary variables
        - List[str]: Compute only specified secondary variables

        Available: ['wind_speed', 'relative_humidity', 'ET0', 'EW0']
    target_crs : int, str or None, default 4326
        Target coordinate reference system, forwarded to the standardizer.

    Returns
    -------
    Union[xr.Dataset, Dict[str, pathlib.Path]]
        - Single file/dataset input: standardized dataset or output path
        - Directory/dict input: dict of output paths

    Raises
    ------
    ValueError
        For dict input, when ``_validate_era5_batch`` reports at least one
        missing or non-ERA5-Land file.
    """
    # Every standardizer, batch or single, is built from the same options.
    options = dict(
        target_frequency=target_frequency,
        apply_bias_correction=apply_bias_correction,
        bias_region=bias_region,
        custom_bias_factors=custom_bias_factors,
        progressive_bias=progressive_bias,
        custom_progressive_factors=custom_progressive_factors,
        output_path=output_path,
        output_prefix=output_prefix,
        force_reprocess=force_reprocess,
        compute_secondary=compute_secondary,
        target_crs=target_crs,
    )

    # Handle batch input
    if isinstance(data, dict):
        print(f"Batch processing {len(data)} files individually")

        invalid_files = _validate_era5_batch(data)
        if invalid_files:
            raise ValueError(
                "Files without ERA5-Land variables detected in batch input:\n"
                + "\n".join(invalid_files)
            )

        standardized_files = {}
        for var_name, file_path in data.items():
            print(f"Processing file: {var_name}")
            # Best-effort batch: a failing file is reported and left out of
            # the result instead of aborting the remaining files.
            try:
                standardizer = _create_standardizer(data_input=file_path, **options)
                standardized_files[var_name] = standardizer.standardize()
            except Exception as e:
                print(f"Failed to process {var_name}: {e}")

        return standardized_files

    # Handle single input: the result is returned unchanged, whatever its type.
    standardizer = _create_standardizer(data_input=data, **options)
    return standardizer.standardize()