# Source code for download_era5

# -*- coding: utf-8 -*-
"""
This script is designed to group classes and functions for downloading
all types of data at a worldwide scale.

Created on Wed Jun 03 2025
@author: Bastien Boivin
@contact: bastien.boivin@proton.me
"""

import pathlib
from typing import Union, List, Dict, Optional, Tuple, Any
from datetime import datetime
import hashlib
import time
import concurrent.futures
import shutil

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import rioxarray  # ensure .rio accessor is registered
# CDS API key must be configured in ~/.cdsapirc
import cdsapi

from geop4th import (
    geobricks as geo,
    utils,
    )
import json

from geop4th.datasets import (
    ERA5_LAND,
    get_variables,
    get_variable_info,
    get_variable_mapping,
    get_short_to_long_name_mapping
)

# ERA5-Land configuration from centralized config (geop4th.datasets).
# ERA5_INVARIANTS: time-independent fields -- presumably topography/land mask; confirm in datasets config
ERA5_INVARIANTS = get_variables(ERA5_LAND, variable_type='invariants')
# Mapping of variable categories to hourly (temporal) variable definitions
ERA5_VARIABLES = ERA5_LAND['temporal_variables']
# Named profiles: predefined sets of variables users can request by one name
VARIABLE_PROFILES = ERA5_LAND['variable_profiles']
# Derived variables computed post-download; each entry lists its 'requires' inputs
SECONDARY_VARIABLES = get_variables(ERA5_LAND, variable_type='secondary_variables')

def available():
    """List this module's public helpers (delegates to geop4th.utils.available)."""

    utils.available(__name__)


class ERA5LandDownloader:
    """
    Download ERA5-Land reanalysis data from the Copernicus Climate Data Store (CDS API).
    
    This class provides a unified interface for downloading ERA5-Land data with
    support for multiple data sources, flexible area specification, and optimized
    request queuing for the CDS API.
    
    Note: CDS API doesn't support true parallel downloads. n_parallel controls
    request queuing to minimize latency between consecutive downloads.
    
    Parameters
    ----------
    area : Union[str, pathlib.Path, tuple, gpd.GeoDataFrame]
        Area specification (bounding box, file path, or GeoDataFrame).
        For bounding box, use a tuple (North, West, South, East) in WGS84.
    variables : Union[str, List[str]]
        Variables to download (names, categories, or profiles).
    start_date : Union[str, datetime]
        Start date for download period. Supports partial dates:
        - '2025-04-23' -> exact date
        - '2025-04' -> first day of month (2025-04-01)  
        - '2025' -> first day of year (2025-01-01)
    end_date : Union[str, datetime]
        End date for download period. Supports partial dates:
        - '2025-04-23' -> exact date
        - '2025-04' -> last day of month (2025-04-30)
        - '2025' -> last day of year (2025-12-31)
    output_dir : Optional[Union[str, pathlib.Path]], default None
        Output directory for downloaded files.
        Defaults to current working directory with 'ERA5Land_data' subdirectory.
    n_parallel : int, default 1
        Number of concurrent request handlers for optimized queuing.
        Not true parallelization due to CDS API limitations. Recommended: 3-4.
        CDS recommends maximum 20 concurrent requests.
    keep_chunks : bool, default False
        Whether to keep intermediate monthly chunk files (raw data).
        If True, files are not deleted after processing and can be reused.
    force_overwrite : bool, default False
        Whether to overwrite existing files.
    output_prefix : str, default 'ERA5LAND'
        Prefix for output filenames. Final names: {prefix}_{variable}_RAW_{dates}.nc
    """
    
    def __init__(
        self,
        area: Union[str, pathlib.Path, tuple, gpd.GeoDataFrame],
        variables: Union[str, List[str]],
        start_date: Union[str, datetime],
        end_date: Union[str, datetime],
        output_dir: Optional[Union[str, pathlib.Path]] = None,
        n_parallel: int = 1,
        keep_chunks: bool = False,
        force_overwrite: bool = False,
        output_prefix: str = 'ERA5LAND'
    ):
        """Validate inputs, initialize the CDS API client and index reusable files.

        See the class docstring for parameter details. Raises ValueError for
        invalid dates/areas/variables and re-raises any cdsapi.Client failure
        (e.g. missing ~/.cdsapirc configuration).
        """
        self.area = area
        # Normalize to a list so single-variable calls work transparently
        self.variables = variables if isinstance(variables, list) else [variables]
        
        # Parse dates with smart handling of partial formats (YYYY, YYYY-MM, YYYY-MM-DD)
        try:
            self.start_date = self._parse_smart_date(start_date, is_end_date=False)
            self.end_date = self._parse_smart_date(end_date, is_end_date=True)
        except Exception as e:
            raise ValueError(f"Invalid date format: {e}")
        
        if self.start_date > self.end_date:
            raise ValueError(f"Invalid date range: start_date ({start_date}) cannot be after end_date ({end_date})")

        # Validate date ranges against ERA5-Land availability (roughly 1950 to present)
        # Out-of-range dates are clamped with a warning rather than rejected.
        max_future_date = pd.Timestamp.now()
        if self.end_date > max_future_date:
            print(f"Warning: End date ({self.end_date.strftime('%Y-%m-%d')}) is too far in future. "
                  f"ERA5-Land data may not be available beyond {max_future_date.strftime('%Y-%m-%d')}")
            self.end_date = max_future_date
            print(f"End date adjusted to: {self.end_date.strftime('%Y-%m-%d')}")
        
        min_date = pd.Timestamp('1950-01-01')
        if self.start_date < min_date:
            print(f"Warning: Start date ({self.start_date.strftime('%Y-%m-%d')}) is too far in past. "
                  f"ERA5-Land data may not be available before {min_date.strftime('%Y-%m-%d')}")
            self.start_date = min_date
            print(f"Start date adjusted to: {self.start_date.strftime('%Y-%m-%d')}")
        
        self.n_parallel = n_parallel
        self.keep_chunks = keep_chunks
        self.retry_attempts = 3  # Number of download retry attempts
        self.retry_delay = 10.0  # Base delay between retries (seconds)
        self.force_overwrite = force_overwrite
        self.output_prefix = output_prefix
        
        # Set up output directory (created if missing)
        if output_dir is None:
            output_dir = pathlib.Path.cwd() / 'ERA5Land_data'
        self.output_dir = pathlib.Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Initialize CDS API client (requires ~/.cdsapirc configuration)
        try:
            self.client = cdsapi.Client()
        except Exception as e:
            print(f"Error: Failed to initialize CDS API client: {e}")
            print("Please ensure you have configured your CDS API key")
            raise
        
        # Process spatial area and variables
        self.bbox = self._process_area(area)
        self.area_segments = self._area_to_segments(self.bbox)  # Handle anti-meridian crossing
        self.processed_variables, self.secondary_vars_requested = self._process_variables(self.variables)
        
        # Must run after variable/area processing: it checks processed_variables
        self._validate_parameters()
        
        # Create temporary directory for monthly chunks
        self.temp_dir = self.output_dir / 'temp_chunks'
        self.temp_dir.mkdir(exist_ok=True)
        
        # Scan for existing files to avoid redundant downloads
        print("Scanning for existing files...")
        all_variable_names = self._get_all_possible_variable_names()
        print(f"Searching for {len(all_variable_names)} possible variable names")
        
        # existing_files maps variable name -> list of file-info dicts (see geo.scan_files)
        self.existing_files = geo.scan_files(
            paths=[self.output_dir, self.temp_dir],
            variables_to_find=all_variable_names,
            bbox=self.bbox,
            start_date=self.start_date,
            end_date=self.end_date,
            frequency=None,
            file_extension='*.nc'
        )
        
        total_existing_files = sum(len(files) for files in self.existing_files.values())
        if total_existing_files > 0:
            print(f"Found {total_existing_files} existing files for {len(self.existing_files)} variables")
            for var, files in self.existing_files.items():
                print(f"  {var}: {len(files)} files")
        else:
            print("No compatible existing files found")
        
        print(f"ERA5-Land downloader initialized")
        print(f"Area: {self.bbox}")
        if len(self.area_segments) > 1:
            print("Area crosses anti-meridian, will split into segments")
        print(f"Variables: {len(self.processed_variables)} primary variables")
        print(f"Period: {self.start_date.strftime('%Y-%m-%d')} to {self.end_date.strftime('%Y-%m-%d')}")
    
    def _parse_smart_date(self, date_input: Union[str, datetime], is_end_date: bool = False) -> pd.Timestamp:
        """Parse date with smart handling of partial dates."""
        if isinstance(date_input, datetime):
            return pd.Timestamp(date_input)
        
        date_str = str(date_input).strip()
        
        # Handle year-only format (e.g., "2024")
        if len(date_str) == 4 and date_str.isdigit():
            year = int(date_str)
            if is_end_date:
                return pd.Timestamp(year, 12, 31)  # Last day of year
            else:
                return pd.Timestamp(year, 1, 1)    # First day of year
        
        # Handle year-month format (e.g., "2024-03")
        if '-' in date_str:
            parts = date_str.split('-')
            if len(parts) == 2:
                try:
                    year = int(parts[0])
                    month = int(parts[1])
                    if 1 <= month <= 12:
                        if is_end_date:
                            return pd.Timestamp(year, month, 1) + pd.offsets.MonthEnd(0)  # Last day of month
                        else:
                            return pd.Timestamp(year, month, 1)  # First day of month
                except ValueError:
                    pass
        
        return pd.to_datetime(date_input)

    def download(self) -> Dict[str, pathlib.Path]:
        """
        Execute the full download workflow.

        For each primary variable: reuse compatible existing files when
        possible, otherwise download and merge monthly chunks and export the
        result. Secondary (derived) variables are then computed from the
        downloaded primaries, and temporary chunks are cleaned up unless
        ``keep_chunks`` is set.

        Returns
        -------
        Dict[str, pathlib.Path]
            Mapping of variable name to the path of its output file.
        """
        print("Starting ERA5-Land data download...")
        
        output_files = {}
        
        for i, variable in enumerate(self.processed_variables, 1):
            print(f"\nProcessing variable {i}/{len(self.processed_variables)}: {variable}")
            
            filename = self._generate_descriptive_filename(variable)
            output_file = self.output_dir / filename
            
            # Generate a unique (hashed) filename if the target already exists
            if output_file.exists():
                request_hash = self._generate_request_hash()
                filename_with_hash = f"{output_file.stem}_{request_hash}{output_file.suffix}"
                output_file = self.output_dir / filename_with_hash
                print(f"Output file exists, using hashed name: {filename_with_hash}")
            
            existing_variable_key = self._find_existing_variable_key(variable)
            
            # Check if we can reuse existing files (unless forcing overwrite)
            if not self.force_overwrite and existing_variable_key:
                invariant_vars = get_variables(ERA5_LAND, variable_type='invariants')
                is_invariant = any(variable == var_info.get('name') for var_info in invariant_vars.values())
                
                if is_invariant:
                    # Invariant (time-independent) variables: any existing file
                    # without a time dimension can be reused directly.
                    reused_existing = False
                    for file_info in self.existing_files[existing_variable_key]:
                        if not file_info.get('has_time_dimension', True):
                            existing_file = file_info['file_path']
                            print(f"Found existing invariant data for {variable}: {existing_file.name}")
                            output_files[variable] = existing_file
                            reused_existing = True
                            break
                    if reused_existing:
                        continue
                    # BUGFIX: previously this branch skipped the variable even
                    # when no reusable file was found; now fall through and
                    # download it as the message announces.
                    print(f"No invariant file found for {variable}, will download")
                else:
                    print(f"Time-varying variable {variable} found in inventory, will analyze chunks")
            
            try:
                merged_data = self._download_variable(variable)
                if merged_data is not None:
                    geo.export(merged_data, output_file)
                    output_files[variable] = output_file
                    description = self._get_file_description(variable)
                    print(f"Saved: {output_file.name}")
                    print(f"  Content: {description}")
                else:
                    print(f"Failed to download: {variable}")
            except Exception as e:
                # One failed variable must not abort the whole run
                print(f"Error downloading {variable}: {e}")
                continue
        
        # Compute derived variables (wind speed, relative humidity, ET0, EW0)
        if self.secondary_vars_requested:
            print(f"\nComputing {len(self.secondary_vars_requested)} secondary variables...")
            secondary_files = self._compute_secondary_variables(output_files)
            output_files.update(secondary_files)
        
        # Clean up temporary files unless user wants to keep them
        if not self.keep_chunks:
            self._cleanup_temp_files()
        
        print(f"\nDownload complete: {len(output_files)} variables processed")
        return output_files
    
    def _process_area(self, area: Union[str, pathlib.Path, tuple, gpd.GeoDataFrame]) -> Tuple[float, float, float, float]:
        """
        Resolve an area specification into a (North, West, South, East) bbox.

        Tuples are validated and returned as-is; any other input is loaded
        via ``geo.load`` and its bounds (reprojected to EPSG:4326 if needed)
        are used.

        Raises
        ------
        ValueError
            If the area cannot be loaded, lacks a CRS, or yields invalid
            WGS84 coordinates.
        """
        if isinstance(area, tuple):
            # BUGFIX: reject malformed tuples explicitly instead of passing
            # them to geo.load and failing with a confusing message.
            if len(area) != 4:
                raise ValueError(f"Bounding box tuple must have 4 elements (North, West, South, East), got {len(area)}")
            north, west, south, east = area
            if not (-90 <= south < north <= 90):
                raise ValueError(f"Invalid latitudes: South={south}, North={north}")
            if not (-180 <= west <= 180 and -180 <= east <= 180):
                raise ValueError(f"Invalid longitudes: West={west}, East={east}")
            if west == east:
                raise ValueError("Invalid longitudes: West and East cannot be equal")
            return area
        
        try:
            spatial_data = geo.load(area)
        except Exception as e:
            raise ValueError(f"Failed to load area from {area}: {e}")
        
        if isinstance(spatial_data, gpd.GeoDataFrame):
            if spatial_data.crs is None:
                raise ValueError("GeoDataFrame has no CRS defined")
            if spatial_data.crs.to_epsg() != 4326:
                spatial_data = geo.reproject(spatial_data, dst_crs=4326)
            # total_bounds is (minx, miny, maxx, maxy) = (W, S, E, N)
            bounds = spatial_data.total_bounds
            return self._validate_computed_bbox((bounds[3], bounds[0], bounds[1], bounds[2]))
        
        elif isinstance(spatial_data, (xr.Dataset, xr.DataArray)):
            if spatial_data.rio.crs is None:
                raise ValueError("Raster data has no CRS defined")
            if spatial_data.rio.crs.to_epsg() != 4326:
                spatial_data = geo.reproject(spatial_data, dst_crs=4326)
            # rio.bounds() is (left, bottom, right, top) = (W, S, E, N)
            left, bottom, right, top = spatial_data.rio.bounds()
            return self._validate_computed_bbox((top, left, bottom, right))
        
        else:
            raise ValueError(f"Unsupported area data type: {type(spatial_data)}")
    
    def _validate_computed_bbox(self, bbox: Tuple[float, float, float, float]) -> Tuple[float, float, float, float]:
        """Check a computed (N, W, S, E) bbox for valid WGS84 coordinates and return it."""
        north, west, south, east = bbox
        if not (-90 <= south <= 90 and -90 <= north <= 90):
            raise ValueError(f"Computed area has invalid latitude values: South={south}, North={north}")
        if not (-180 <= west <= 180 and -180 <= east <= 180):
            raise ValueError(f"Computed area has invalid longitude values: West={west}, East={east}")
        if west == east:
            raise ValueError("Computed area has invalid longitude extent (zero width)")
        return bbox

    def _area_to_segments(self, bbox: Tuple[float, float, float, float]) -> List[Tuple[float, float, float, float]]:
        """Split anti-meridian bbox into segments."""
        north, west, south, east = bbox
        if west <= east:
            return [bbox]  # Normal case, no anti-meridian crossing
        # Area crosses the anti-meridian (date line), split into two segments
        seg1 = (north, west, south, 180.0)    # Western segment
        seg2 = (north, -180.0, south, east)   # Eastern segment
        return [seg1, seg2]

    def _process_variables(self, variables: List[str]) -> Tuple[List[str], List[str]]:
        """
        Expand user-facing variable specifications into concrete variable names.

        Each entry may be a secondary (derived) variable, a profile name, a
        category name, or a plain variable name — checked in that order.
        Order of first appearance is preserved; duplicates are dropped.

        Returns
        -------
        Tuple[List[str], List[str]]
            (primary variables to download, secondary variables to compute).
        """
        processed = []
        secondary_vars_requested = []
        
        def _register_secondary(sec_var: str, announce: bool):
            """Queue a derived variable and ensure its required inputs are downloaded."""
            # BUGFIX: avoid duplicate entries when the same secondary variable
            # is requested both directly and through a profile
            if sec_var not in secondary_vars_requested:
                secondary_vars_requested.append(sec_var)
            for req_var in SECONDARY_VARIABLES[sec_var]['requires']:
                if req_var not in processed:
                    processed.append(req_var)
                    if announce:
                        print(f"Added required variable '{req_var}' for secondary variable '{sec_var}'")
        
        for var in variables:
            # Handle secondary variables (computed from primary variables)
            if var in SECONDARY_VARIABLES:
                _register_secondary(var, announce=True)
                continue
            
            # Handle variable profiles (predefined sets of variables)
            if var in VARIABLE_PROFILES:
                profile_vars = VARIABLE_PROFILES[var]['variables']
                if profile_vars == 'all':  # Special case: all available variables
                    for category in ERA5_VARIABLES.values():
                        for var_info in category.values():
                            if var_info['name'] not in processed:
                                processed.append(var_info['name'])
                    for var_info in ERA5_INVARIANTS.values():
                        if var_info['name'] not in processed:
                            processed.append(var_info['name'])
                else:
                    for profile_var in profile_vars:
                        if profile_var in SECONDARY_VARIABLES:
                            _register_secondary(profile_var, announce=False)
                        elif profile_var not in processed:
                            processed.append(profile_var)
            
            elif var in ERA5_VARIABLES:
                # Category name: add every variable in the category
                for var_info in ERA5_VARIABLES[var].values():
                    if var_info['name'] not in processed:
                        processed.append(var_info['name'])
            
            else:
                var_info = get_variable_info(ERA5_LAND, var)
                if var_info:
                    # BUGFIX: a known variable that is already queued no longer
                    # triggers the misleading "Unknown variable" warning
                    if var not in processed:
                        processed.append(var)
                else:
                    print(f"Warning: Unknown variable '{var}', skipping")
        
        if secondary_vars_requested:
            print(f"Secondary variables to be computed: {secondary_vars_requested}")
        
        return processed, secondary_vars_requested
    
    
    def _validate_parameters(self):
        """
        Validate input parameters for consistency and correctness.
        
        Raises
        ------
        ValueError
            If n_parallel is less than 1.
        """
        if self.n_parallel < 1:
            raise ValueError("n_parallel must be at least 1")
        if self.n_parallel > 20:
            print("Warning: CDS API recommends maximum 20 parallel downloads. Recommended: 3-4.")
        if self.retry_attempts < 1:
            raise ValueError("retry_attempts must be at least 1")
        if self.retry_delay < 0:
            raise ValueError("retry_delay must be non-negative")
        if not self.processed_variables:
            raise ValueError("No valid variables specified")
    
    def _download_variable(self, variable: str) -> Optional[xr.Dataset]:
        """
        Obtain data for a single variable, reusing existing chunks when possible.

        Invariant (time-independent) variables are fetched with one
        single-month request and their time dimension dropped. Time-varying
        variables are assembled from monthly chunks — reused, clipped or
        downloaded according to the chunk strategy — then concatenated along
        the time dimension.

        Parameters
        ----------
        variable : str
            Concrete ERA5-Land variable name.

        Returns
        -------
        Optional[xr.Dataset]
            Merged dataset for the requested period, or None on failure.
        """
        invariant_vars = get_variables(ERA5_LAND, variable_type='invariants')
        is_invariant = any(variable == var_info.get('name') for var_info in invariant_vars.values())
        
        # Optimized handling for time-independent variables (topography, land mask, etc.)
        if is_invariant:
            first_period = pd.period_range(start=self.start_date, end=self.start_date, freq='M')[0]
            chunk_file = self._download_chunk(variable, first_period.year, first_period.month)
            
            if chunk_file and chunk_file.exists():
                invariant_data = geo.load(chunk_file)
                
                # Remove the (singleton) time dimension for invariant variables
                if 'time' in invariant_data.dims and invariant_data.sizes['time'] == 1:
                    invariant_data = invariant_data.isel(time=0, drop=True)
                
                # Add provenance metadata for direct download
                invariant_data = self._add_provenance_metadata(invariant_data, 'download')
                
                if not self.keep_chunks:
                    chunk_file.unlink()
                    
                print(f"Optimized download for invariant variable: {variable} (single request)")
                return invariant_data
            else:
                return None
        
        # For time-varying variables, download monthly chunks and merge
        periods = pd.period_range(start=self.start_date, end=self.end_date, freq='M')
        
        existing_variable_key = self._find_existing_variable_key(variable)
        strategy = self._get_chunk_strategy(variable, periods, existing_variable_key)
        
        all_chunk_datasets = []
        
        for chunk_info in strategy['chunks_from_existing']:
            year, month = chunk_info['year'], chunk_info['month']
            file_path = chunk_info['file_path']
            action = chunk_info['action']
            
            # BUGFIX: reset per iteration so an unrecognized action can neither
            # raise NameError nor re-append the previous iteration's dataset
            # under the wrong period.
            chunk_data = None
            
            try:
                if action == 'use':
                    print(f"Using existing chunk: {year}-{month:02d} from {file_path.name}")
                    if file_path.parent == self.temp_dir:
                        chunk_data = geo.load(file_path)
                    else:
                        # Copy foreign files into the temp dir so later cleanup
                        # never touches the user's original data
                        temp_path = self.temp_dir / f"{variable}_{year}{month:02d}.nc"
                        shutil.copy2(file_path, temp_path)
                        chunk_data = geo.load(temp_path)
                    
                    # Preserve provenance metadata if the chunk already carries
                    # it; otherwise mark it as coming from an existing file
                    # (original source unknown)
                    if not any(attr in chunk_data.attrs for attr in ['data_source', 'download_date', 'clip_date']):
                        chunk_data = self._add_provenance_metadata(chunk_data, 'clip')
                
                elif action in ['clip_spatial', 'clip_temporal', 'clip_both']:
                    print(f"Clipping existing chunk: {year}-{month:02d} from {file_path.name}")
                    
                    # Use lazy loading for memory efficiency during clipping
                    chunk_data = xr.open_dataset(file_path)
                    
                    chunk_start = pd.Timestamp(year, month, 1)
                    chunk_end = chunk_start + pd.offsets.MonthEnd(1)
                    
                    # Detect time dimension name, by priority: 'time', 't', 'valid_time'
                    time_dim = geo.main_time_dims(chunk_data)[0]
                    
                    clipped_data = self._memory_efficient_clipping(
                        chunk_data, action, time_dim, chunk_start, chunk_end
                    )
                    
                    clipped_data = clipped_data.load()
                    chunk_data = self._update_grib_metadata_after_clipping(clipped_data)
                    
                    # Add provenance metadata for clipped data
                    chunk_data = self._add_provenance_metadata(chunk_data, 'clip')
                    print(f"Successfully clipped chunk: {year}-{month:02d}")
                
                if chunk_data is not None:
                    all_chunk_datasets.append((pd.Period(year=year, month=month, freq='M'), chunk_data))
                
            except Exception as e:
                print(f"Error processing existing chunk {year}-{month:02d}: {e}")
                # Fall back to downloading this chunk
                strategy['chunks_to_download'].append((year, month))
        
        # Download missing chunks in parallel (within CDS API limits)
        if strategy['chunks_to_download']:
            print(f"Downloading {len(strategy['chunks_to_download'])} missing chunks...")
            
            chunk_files = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel) as executor:
                future_to_chunk = {
                    executor.submit(self._download_chunk, variable, year, month): (year, month)
                    for year, month in strategy['chunks_to_download']
                }
                
                for future in concurrent.futures.as_completed(future_to_chunk):
                    year, month = future_to_chunk[future]
                    try:
                        chunk_file = future.result()
                        if chunk_file and chunk_file.exists():
                            chunk_files.append((pd.Period(year=year, month=month, freq='M'), chunk_file))
                            print(f"Downloaded chunk: {year}-{month:02d}")
                        else:
                            print(f"Failed chunk: {year}-{month:02d}")
                    except Exception as e:
                        print(f"Error downloading chunk {year}-{month:02d}: {e}")
            
            for period, chunk_file in chunk_files:
                try:
                    chunk_data = geo.load(chunk_file)
                    if chunk_data is not None:
                        # Add provenance metadata for downloaded chunks
                        chunk_data = self._add_provenance_metadata(chunk_data, 'download')
                        all_chunk_datasets.append((period, chunk_data))
                except Exception as e:
                    print(f"Error loading downloaded chunk {period}: {e}")
        
        # Merge all chunks into a single dataset
        if not all_chunk_datasets:
            return None
        
        try:
            all_chunk_datasets.sort(key=lambda x: (x[0].year, x[0].month))
            # Drop failed chunks and harmonize dimension names in one pass
            # (ERA5 sometimes uses 'valid_time' instead of 'time')
            datasets_only = []
            for _, ds in all_chunk_datasets:
                if ds is None:
                    continue
                if 'valid_time' in ds.dims and 'time' not in ds.dims:
                    ds = ds.rename({'valid_time': 'time'})
                datasets_only.append(ds)
            
            if len(datasets_only) == 1:
                merged_data = datasets_only[0]
            else:
                time_dim = 'time'
                if datasets_only and hasattr(datasets_only[0], 'dims') and 'valid_time' in datasets_only[0].dims:
                    time_dim = 'valid_time'
                
                merged_data = xr.concat(datasets_only, dim=time_dim).sortby(time_dim)
                
                # Remove any duplicate timestamps that might occur at chunk boundaries
                try:
                    idx = merged_data[time_dim].to_index()
                    unique_mask = ~idx.duplicated(keep='first')
                    merged_data = merged_data.isel({time_dim: unique_mask})
                except Exception:
                    pass  # Skip deduplication if it fails
            
            print(f"Successfully merged {len(all_chunk_datasets)} chunks for {variable}")
            return merged_data
        except Exception as e:
            print(f"Error merging chunks for {variable}: {e}")
            return None
    
    def _generate_request_hash(self) -> str:
        """Generate deterministic hash for download context."""

        def norm_bbox(b):
            return tuple(round(float(v), 4) for v in b)

        bbox_norm = norm_bbox(self.bbox)
        segments_norm = [norm_bbox(seg) for seg in getattr(self, 'area_segments', [self.bbox])]

        params = {
            'area_bbox': bbox_norm,
            'area_segments': segments_norm,
            'start_date': pd.Timestamp(self.start_date).strftime('%Y-%m-%d'),
            'end_date': pd.Timestamp(self.end_date).strftime('%Y-%m-%d'),
            'output_prefix': self.output_prefix,
        }

        params_str = json.dumps(params, sort_keys=True, separators=(',', ':'))
        return hashlib.md5(params_str.encode('utf-8')).hexdigest()[:8]
    
    def _generate_descriptive_filename(self, var_name: str) -> str:
        """Build the output filename: {prefix}_{var}_RAW[_{start}_{end}].nc.

        Invariant (time-independent) variables carry no date-range suffix.
        """
        invariants = get_variables(ERA5_LAND, variable_type='invariants')
        time_independent = any(
            var_name == info.get('name') for info in invariants.values()
        )

        pieces = [self.output_prefix, var_name, "RAW"]
        if not time_independent:
            # Temporal variables carry the covered period in the name
            span = f"{self.start_date.strftime('%Y%m%d')}_{self.end_date.strftime('%Y%m%d')}"
            pieces.append(span)

        return "_".join(pieces) + ".nc"
    
    def _get_file_description(self, var_name: str) -> str:
        """Return a one-line description of what a variable's output file contains."""
        # A variable counts as derived if it is a secondary-variable key or
        # one of the secondary variables' declared output names
        derived_names = {info['name'] for info in SECONDARY_VARIABLES.values()}
        if var_name in SECONDARY_VARIABLES or var_name in derived_names:
            return f"Derived ERA5-Land data for variable '{var_name}'"
        return f"Raw ERA5-Land hourly data for variable '{var_name}'"
    
    def _get_chunk_strategy(self, variable: str, target_periods: List[pd.Period], existing_variable_key: Optional[str] = None) -> Dict[str, Any]:
        """
        Analyze existing files for a variable and determine the best strategy for each chunk.
        
        For each requested month, the first existing file whose temporal range
        fully covers the month (and whose bounding box contains the requested
        area, when both are known) is reused — possibly after spatial and/or
        temporal clipping. Months with no covering file are queued for download.
        
        Parameters
        ----------
        variable : str
            Variable name to analyze
        target_periods : List[pd.Period]
            List of monthly periods to obtain
        existing_variable_key : Optional[str]
            The actual variable name used in existing files (for name mapping)
            
        Returns
        -------
        Dict[str, Any]
            Strategy dictionary with chunk actions:
            {
                'chunks_to_download': [(year, month), ...],
                'chunks_from_existing': [{'year': int, 'month': int, 'file_path': Path, 'action': str}, ...]
            }
        """
        strategy = {
            'chunks_to_download': [],
            'chunks_from_existing': []
        }
        
        # Existing files may be inventoried under the alternate (short/long) name.
        variable_key = existing_variable_key or variable
        existing_files = self.existing_files.get(variable_key, [])
        
        if not existing_files:
            # Nothing on disk: every requested month must be downloaded.
            strategy['chunks_to_download'] = [(p.year, p.month) for p in target_periods]
            return strategy
        
        print(f"Analyzing chunks for {variable}: {len(target_periods)} periods requested")
        
        for period in target_periods:
            year, month = period.year, period.month
            chunk_found = False
            
            for file_info in existing_files:
                temporal_range = file_info.get('temporal_range')
                file_bbox = file_info.get('bbox')
                
                if temporal_range:
                    file_start, file_end = temporal_range
                    # Full calendar month the chunk must cover.
                    chunk_start = pd.Timestamp(year, month, 1)
                    chunk_end = chunk_start + pd.offsets.MonthEnd(1)
                    
                    if file_start <= chunk_start and file_end >= chunk_end:
                        print(f"  Found chunk {year}-{month:02d} in {file_info['file_path'].name}")
                        
                        action = 'use'
                        if file_bbox and self.bbox:
                            # Bounding boxes are (North, West, South, East) in WGS84.
                            target_north, target_west, target_south, target_east = self.bbox
                            file_north, file_west, file_south, file_east = file_bbox
                            
                            # The file's extent must contain the requested area.
                            if (file_north >= target_north and file_west <= target_west and
                                file_south <= target_south and file_east >= target_east):
                                
                                # 0.01-degree tolerance to treat extents as identical
                                # (presumably absorbs float rounding — TODO confirm).
                                is_exact = (abs(file_north - target_north) < 0.01 and 
                                           abs(file_west - target_west) < 0.01 and
                                           abs(file_south - target_south) < 0.01 and 
                                           abs(file_east - target_east) < 0.01)
                                
                                if is_exact:
                                    action = 'use'
                                    print(f"    Exact spatial match")
                                else:
                                    action = 'clip_spatial'
                                    print(f"    Needs spatial clipping")
                            else:
                                # File does not cover the requested area; try the next one.
                                print(f"    File too small for requested area, skipping")
                                continue
                        
                        # A file larger than the month also needs temporal clipping.
                        if file_start < chunk_start or file_end > chunk_end:
                            action = 'clip_temporal' if action == 'use' else 'clip_both'
                            print(f"    Needs temporal clipping")
                        
                        strategy['chunks_from_existing'].append({
                            'year': year,
                            'month': month,
                            'file_path': file_info['file_path'],
                            'action': action,
                            'temporal_range': temporal_range
                        })
                        chunk_found = True
                        break  # first covering file wins
            
            if not chunk_found:
                print(f"  Chunk {year}-{month:02d} needs to be downloaded")
                strategy['chunks_to_download'].append((year, month))
        
        print(f"Strategy: {len(strategy['chunks_to_download'])} to download, "
              f"{len(strategy['chunks_from_existing'])} from existing files")
        
        return strategy
    
    def _get_all_possible_variable_names(self) -> List[str]:
        """
        Get all possible variable names (short and long) for all requested variables.
        
        Includes both temporal and invariant variables, handling ERA5-Land naming
        conventions: each requested primary/secondary variable is expanded with
        its short/long counterpart from the config mapping, then with a set of
        common ERA5 short-name aliases.
        
        Returns
        -------
        List[str]
            Complete sorted list of variable names to search for in existing files
        """
        all_names = set()
        
        short_to_long = get_short_to_long_name_mapping(ERA5_LAND)
        long_to_short = {v: k for k, v in short_to_long.items()}
        
        def _add_with_mapping(name: str) -> None:
            # Record the name plus its short/long counterpart, if any.
            all_names.add(name)
            if name in long_to_short:
                all_names.add(long_to_short[name])
            if name in short_to_long:
                all_names.add(short_to_long[name])
        
        for variable in self.processed_variables:
            _add_with_mapping(variable)
        for sec_var in self.secondary_vars_requested:
            _add_with_mapping(sec_var)
        
        # Common ERA5 aliases not necessarily covered by the config mapping.
        era5_aliases = {
            't2m': '2m_temperature',
            'd2m': '2m_dewpoint_temperature',
            'u10': '10m_u_component_of_wind',
            'v10': '10m_v_component_of_wind',
            'tp': 'total_precipitation',
            'sp': 'surface_pressure',
            'ssrd': 'surface_solar_radiation_downwards'
        }
        # Reverse lookup avoids re-scanning the alias table for every name.
        full_to_alias = {full: alias for alias, full in era5_aliases.items()}
        
        for name in list(all_names):  # snapshot: the set grows while we expand
            if name in era5_aliases:
                all_names.add(era5_aliases[name])
            elif name in full_to_alias:
                all_names.add(full_to_alias[name])
        
        result = sorted(all_names)
        print(f"Variable names to search: {result}")
        return result
    
    def _find_existing_variable_key(self, variable: str) -> Optional[str]:
        """
        Find the actual variable name used in existing files inventory.
        
        Handles ERA5-Land naming conventions with short/long name mapping:
        the exact name is tried first, then its long->short and short->long
        counterparts, in that order.
        
        Parameters
        ----------
        variable : str
            Variable name to search for
            
        Returns
        -------
        Optional[str]
            The actual variable key used in existing files, or None if not found
        """
        if variable in self.existing_files:
            return variable
        
        short_to_long = get_short_to_long_name_mapping(ERA5_LAND)
        long_to_short = {long_name: short_name
                         for short_name, long_name in short_to_long.items()}
        
        # Try both naming conventions, long->short before short->long.
        for mapping in (long_to_short, short_to_long):
            alternate = mapping.get(variable)
            if alternate is not None and alternate in self.existing_files:
                return alternate
        
        return None
    
    def _download_chunk(self, variable: str, year: int, month: int) -> Optional[pathlib.Path]:
        """
        Download one monthly chunk of data.
        
        Requests every day and hour of the month from the CDS
        'reanalysis-era5-land' dataset, writes the result to the temporary
        directory, and validates it by reloading. A valid chunk already on
        disk is reused. Areas split into several segments (anti-meridian
        case) are downloaded per segment and merged. Failures are retried
        with exponential backoff.
        
        Parameters
        ----------
        variable : str
            CDS variable name.
        year : int
            Year to download.
        month : int
            Month to download.
            
        Returns
        -------
        Optional[pathlib.Path]
            Path to downloaded chunk file or None if failed.
        """
        # Request hash keeps chunk filenames unique to this request's parameters.
        request_hash = self._generate_request_hash()
        filename = f"{variable}_{year}{month:02d}_{request_hash}.nc"
        output_path = self.temp_dir / filename
        
        # Check if chunk already exists and is valid
        if output_path.exists():
            try:
                test_ds = geo.load(output_path)
                if test_ds is not None:
                    return output_path
            except (IOError, OSError, ValueError) as e:
                print(f"File validation failed: {e}")
                try:
                    output_path.unlink()  # Remove corrupted file
                except OSError:
                    pass
        
        full_variable_name = variable
        
        # Generate request parameters for the entire month
        days_in_month = pd.Period(year=year, month=month, freq="M").days_in_month
        day_list = [f"{d:02d}" for d in range(1, days_in_month + 1)]
        time_list = [f"{h:02d}:00" for h in range(24)]  # All 24 hours
        
        # One segment for a normal bbox; several when the area was split
        # (e.g. across the anti-meridian) by earlier processing.
        segments = getattr(self, 'area_segments', [self.bbox])
        
        # Download with retry logic
        for attempt in range(self.retry_attempts):
            try:
                if len(segments) == 1:  # Single area request
                    request = {
                        'variable': [full_variable_name],
                        'year': str(year),
                        'month': f'{month:02d}',
                        'day': day_list,
                        'time': time_list,
                        'area': list(segments[0]),  # [North, West, South, East]
                        'data_format': 'netcdf',
                        'download_format': 'unarchived'
                    }
                    print(f"Downloading {variable} {year}-{month:02d} (attempt {attempt + 1})")
                    result = self.client.retrieve('reanalysis-era5-land', request)
                    result.download(str(output_path))
                    
                    # Accept the download only if it is non-empty and reloadable.
                    if output_path.exists() and output_path.stat().st_size > 0:
                        test_ds = geo.load(output_path)
                        if test_ds is not None:
                            time.sleep(1)  # short pause — presumably to throttle CDS requests; confirm
                            return output_path
                    raise Exception("Downloaded file is invalid")
                else:  # Multi-segment request for anti-meridian areas
                    seg_paths: List[pathlib.Path] = []
                    for idx, seg in enumerate(segments):
                        seg_path = self.temp_dir / f"{variable}_{year}{month:02d}_seg{idx}.nc"
                        request = {
                            'variable': [full_variable_name],
                            'year': str(year),
                            'month': f'{month:02d}',
                            'day': day_list,
                            'time': time_list,
                            'area': list(seg),
                            'data_format': 'netcdf',
                            'download_format': 'unarchived'
                        }
                        print(f"Downloading {variable} {year}-{month:02d} segment {idx+1}/{len(segments)} (attempt {attempt + 1})")
                        result = self.client.retrieve('reanalysis-era5-land', request)
                        result.download(str(seg_path))
                        if not (seg_path.exists() and seg_path.stat().st_size > 0):
                            raise Exception(f"Segment {idx} download is invalid")
                        seg_paths.append(seg_path)
                    
                    # Merge the segments back together
                    try:
                        if len(seg_paths) > 4:
                            print(f"Warning: Loading {len(seg_paths)} segments simultaneously. "
                                  "Consider reducing area size if memory issues occur.")
                        datasets = [geo.load(p) for p in seg_paths]
                        try:
                            merged_seg = xr.combine_by_coords(datasets)  # Automatic coordinate-based merge
                        except Exception:
                            # Fallback: manual longitude concatenation
                            lon_dim = None
                            for cand in ('longitude', 'lon', 'x'):
                                if cand in datasets[0].dims:
                                    lon_dim = cand
                                    break
                            if lon_dim:
                                merged_seg = xr.concat(datasets, dim=lon_dim).sortby(lon_dim)
                            else:
                                # No recognizable longitude dimension: last-resort merge.
                                merged_seg = xr.merge(datasets)
                        
                        geo.export(merged_seg, output_path)
                        
                        # Same validation as the single-segment path.
                        if output_path.exists() and output_path.stat().st_size > 0:
                            test_ds = geo.load(output_path)
                            if test_ds is not None:
                                for p in seg_paths:
                                    try:
                                        p.unlink(missing_ok=True)
                                    except Exception:
                                        pass
                                time.sleep(1)  # short pause — presumably to throttle CDS requests; confirm
                                return output_path
                        raise Exception("Merged segment file is invalid")
                    finally:
                        # Always remove per-segment files (no-op if already unlinked above).
                        for p in seg_paths:
                            try:
                                p.unlink(missing_ok=True)
                            except Exception:
                                pass
            except Exception as e:
                print(f"Download attempt {attempt + 1} failed: {e}")
                # Drop any partial output before retrying.
                if output_path.exists():
                    try:
                        output_path.unlink()
                    except Exception:
                        pass
                if attempt < self.retry_attempts - 1:
                    wait_time = self.retry_delay * (2 ** attempt)  # Exponential backoff
                    print(f"Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                else:
                    print(f"All attempts failed for {variable} {year}-{month:02d}")
                    return None
        
        return None
    
    def _compute_secondary_variables(self, primary_files: Dict[str, pathlib.Path]) -> Dict[str, pathlib.Path]:
        """
        Compute secondary variables from downloaded primary variables.
        
        Supported derivations:
        
        - ``wind_speed``: from the 10 m U and V wind components
        - ``relative_humidity``: from 2 m temperature, 2 m dewpoint and
          surface pressure
        - ``ET0`` / ``EW0``: reference evapotranspiration / open water
          evaporation derived from potential evaporation
        
        A derivation is skipped when its required primary variables are not
        all present in ``primary_files``; computation errors are reported but
        do not abort the remaining variables.
        
        Parameters
        ----------
        primary_files : Dict[str, pathlib.Path]
            Dictionary mapping primary variable names to file paths.
            
        Returns
        -------
        Dict[str, pathlib.Path]
            Dictionary mapping secondary variable names to computed file paths.
        """
        secondary_files: Dict[str, pathlib.Path] = {}
        
        def _finalize(data, var_name: str) -> None:
            # Shared tail of every derivation: wrap a bare DataArray in a
            # Dataset, add 'computed' provenance metadata, save under a
            # descriptive filename, and register the output path.
            if isinstance(data, xr.DataArray):
                data = xr.Dataset({var_name: data})
            output_file = self.output_dir / self._generate_descriptive_filename(var_name)
            data_with_provenance = self._add_provenance_metadata(data, 'computed')
            self._save_secondary_variable(data_with_provenance, var_name, output_file)
            secondary_files[var_name] = output_file
        
        for secondary_var in self.secondary_vars_requested:
            print(f"Computing {secondary_var}...")
            
            # Compute wind speed from U and V components
            if secondary_var == 'wind_speed':
                if ('10m_u_component_of_wind' in primary_files
                        and '10m_v_component_of_wind' in primary_files):
                    try:
                        u10_ds = geo.load(primary_files['10m_u_component_of_wind'])
                        v10_ds = geo.load(primary_files['10m_v_component_of_wind'])
                        _finalize(geo.compute_wind_speed(u10_ds, v10_ds), 'wind_speed')
                    except Exception as e:
                        print(f"  Error computing wind_speed: {e}")
            
            # Compute relative humidity from temperature, dewpoint, and pressure
            elif secondary_var == 'relative_humidity':
                required_vars = ['2m_temperature', '2m_dewpoint_temperature', 'surface_pressure']
                if all(var in primary_files for var in required_vars):
                    try:
                        t2m_ds = geo.load(primary_files['2m_temperature'])
                        d2m_ds = geo.load(primary_files['2m_dewpoint_temperature'])
                        sp_ds = geo.load(primary_files['surface_pressure'])
                        
                        rh_ds = geo.compute_relative_humidity(
                            dewpoint_input_file=d2m_ds['2m_dewpoint_temperature'],
                            temperature_input_file=t2m_ds['2m_temperature'],
                            pressure_input_file=sp_ds['surface_pressure']
                        )
                        _finalize(rh_ds, 'relative_humidity')
                    except Exception as e:
                        print(f"  Error computing relative_humidity: {e}")
            
            # Compute reference evapotranspiration (ET0) or open water evaporation (EW0)
            elif secondary_var in ['ET0', 'EW0']:
                if 'potential_evaporation' in primary_files:
                    try:
                        pev_ds = geo.load(primary_files['potential_evaporation'])
                        
                        # Both reference values come from a single conversion;
                        # keep only the one that was requested.
                        et0_ds, ew0_ds = geo.compute_Erefs_from_Epan(pev_ds['potential_evaporation'])
                        _finalize(et0_ds if secondary_var == 'ET0' else ew0_ds, secondary_var)
                    except Exception as e:
                        print(f"  Error computing {secondary_var}: {e}")
        
        return secondary_files

    def _save_secondary_variable(self, data, var_name: str, output_file: pathlib.Path):
        """Persist a computed variable, promoting a bare array to a Dataset first."""
        to_export = data if isinstance(data, xr.Dataset) else xr.Dataset({var_name: data})
        geo.export(to_export, output_file)
        print(f"  Saved: {output_file.name}")

    def _apply_spatial_clipping(self, dataset: xr.Dataset) -> xr.Dataset:
        """Subset *dataset* to the downloader's bounding box (North, West, South, East).

        The dataset is returned unchanged when no spatial coordinates can be
        identified. Slice bounds are ordered to match the dataset's own
        coordinate direction (ascending or descending).
        """
        north, west, south, east = self.bbox
        x_var, y_var = geo.main_space_dims(dataset)[0]

        # Without identifiable spatial coordinates there is nothing to clip.
        if not (x_var and y_var and x_var in dataset.coords and y_var in dataset.coords):
            return dataset

        lons = dataset[x_var].values
        lats = dataset[y_var].values

        # slice() bounds must follow coordinate ordering to select anything.
        ascending_lon = lons[0] < lons[-1]
        lon_slice = slice(west, east) if ascending_lon else slice(east, west)

        descending_lat = lats[0] > lats[-1]  # common in climate data
        lat_slice = slice(north, south) if descending_lat else slice(south, north)

        return dataset.sel({x_var: lon_slice, y_var: lat_slice})

    def _memory_efficient_clipping(self, dataset: xr.Dataset, action: str, 
                                  time_dim: str, chunk_start: pd.Timestamp, 
                                  chunk_end: pd.Timestamp) -> xr.Dataset:
        """
        Perform memory-efficient clipping using lazy loading and chunked operations.
        
        Dispatches on ``action``: spatial clipping only, temporal clipping
        only, or both (temporal first, then spatial on the subset).
        
        Parameters
        ----------
        dataset : xr.Dataset
            Input dataset (opened with lazy loading)
        action : str
            Type of clipping ('clip_spatial', 'clip_temporal', 'clip_both')
        time_dim : str
            Name of time dimension
        chunk_start : pd.Timestamp
            Start of temporal range
        chunk_end : pd.Timestamp
            End of temporal range
            
        Returns
        -------
        xr.Dataset
            Clipped dataset (still lazy until actually used)
        """
        if action == 'clip_spatial':
            return self._apply_spatial_clipping(dataset)

        # Both remaining actions start from a temporal subset.
        time_clipped = dataset.sel({time_dim: slice(chunk_start, chunk_end)})

        if action == 'clip_temporal':
            return time_clipped

        # 'clip_both': apply spatial clipping on top of the temporal subset.
        return self._apply_spatial_clipping(time_clipped)

    def _update_grib_metadata_after_clipping(self, dataset: xr.Dataset) -> xr.Dataset:
        """
        Update GRIB metadata attributes to match the actual clipped data dimensions.
        
        After spatial clipping, GRIB grid attributes inherited from the
        original download (grid size, first/last grid point, increments,
        point count) no longer describe the data. They are recomputed from
        the clipped coordinates. Only attributes already present on a data
        variable are overwritten; variables without them are left untouched.
        
        Parameters
        ----------
        dataset : xr.Dataset
            Dataset that has been spatially clipped
            
        Returns
        -------
        xr.Dataset
            Dataset with corrected GRIB metadata
        """
        updated_dataset = dataset.copy()
        
        x_var, y_var = geo.main_space_dims(updated_dataset)[0]
        if not x_var or not y_var:
            return updated_dataset
        if x_var not in updated_dataset.coords or y_var not in updated_dataset.coords:
            return updated_dataset
        
        lon_coords = updated_dataset[x_var].values
        lat_coords = updated_dataset[y_var].values
        
        actual_nx = len(lon_coords)
        actual_ny = len(lat_coords)
        
        # Increments from the first coordinate step; 0.1 degrees is the
        # fallback for a degenerate single-point axis (presumably the native
        # ERA5-Land grid spacing — TODO confirm).
        lon_increment = abs(float(lon_coords[1] - lon_coords[0])) if actual_nx > 1 else 0.1
        lat_increment = abs(float(lat_coords[1] - lat_coords[0])) if actual_ny > 1 else 0.1
        
        # New values for every GRIB grid attribute this method corrects.
        grib_updates = {
            'GRIB_Nx': actual_nx,
            'GRIB_Ny': actual_ny,
            'GRIB_longitudeOfFirstGridPointInDegrees': float(lon_coords[0]),
            'GRIB_longitudeOfLastGridPointInDegrees': float(lon_coords[-1]),
            'GRIB_latitudeOfFirstGridPointInDegrees': float(lat_coords[0]),
            'GRIB_latitudeOfLastGridPointInDegrees': float(lat_coords[-1]),
            'GRIB_iDirectionIncrementInDegrees': lon_increment,
            'GRIB_jDirectionIncrementInDegrees': lat_increment,
            'GRIB_numberOfPoints': actual_nx * actual_ny,
        }
        
        for data_var in updated_dataset.data_vars.values():
            if not hasattr(data_var, 'attrs'):
                continue
            for key, value in grib_updates.items():
                # Only overwrite attributes the variable already carries.
                if key in data_var.attrs:
                    data_var.attrs[key] = value
        
        return updated_dataset

    def _add_provenance_metadata(self, dataset: xr.Dataset, operation_type: str) -> xr.Dataset:
        """
        Add provenance metadata to track data sources and processing history.
        
        Parameters
        ----------
        dataset : xr.Dataset
            Dataset to add metadata to (modified in place)
        operation_type : str
            Type of operation: 'download' for direct CDS downloads, 
            'clip' for clipped data, 'computed' for derived variables.
            Any other value leaves the attributes unchanged.
            
        Returns
        -------
        xr.Dataset
            Dataset with added provenance metadata
        """
        current_time = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S UTC')
        
        # One attribute bundle per operation type; unknown types add nothing.
        provenance_by_operation = {
            'download': {
                'data_source': 'Copernicus Climate Data Store (CDS) - direct download',
                'download_date': current_time,
                'data_provenance': 'original',
            },
            'clip': {
                'data_source': 'Copernicus Climate Data Store (CDS) - clipped from existing file',
                'clip_date': current_time,
                'data_provenance': 'clipped (original download date unknown)',
            },
            'computed': {
                'data_source': 'Computed from Copernicus Climate Data Store (CDS) primary variables',
                'computation_date': current_time,
                'data_provenance': 'computed/derived from primary variables',
            },
        }
        
        dataset.attrs.update(provenance_by_operation.get(operation_type, {}))
        return dataset

    def _cleanup_temp_files(self):
        """
        Clean up temporary chunk files if not preserved.
        
        Removes all NetCDF files from the temporary directory, then the
        directory itself, unless ``keep_chunks`` is True. Cleanup is
        best-effort: any OS-level failure (missing directory, leftover
        non-NetCDF files preventing ``rmdir``) is silently ignored.
        """
        # Honor the keep_chunks option documented on the class: leave the
        # intermediate monthly chunks in place for later reuse.
        if getattr(self, 'keep_chunks', False):
            return
        try:
            for file in self.temp_dir.glob('*.nc'):
                file.unlink()
            # rmdir only succeeds on an empty directory; any other content
            # turns this into a no-op via the except below.
            self.temp_dir.rmdir()
        except OSError:
            # FileNotFoundError is a subclass of OSError, so this also
            # covers a directory that never existed.
            pass


def download_era5_land(
    area: Union[str, pathlib.Path, tuple, gpd.GeoDataFrame],
    variables: Union[str, List[str]],
    start_date: Union[str, datetime],
    end_date: Union[str, datetime],
    output_dir: Optional[Union[str, pathlib.Path]] = None,
    n_parallel: int = 1,
    keep_chunks: bool = False,
    force_overwrite: bool = False,
    output_prefix: str = 'ERA5LAND',
) -> Dict[str, pathlib.Path]:
    """
    Download ERA5-Land data with simplified interface.

    Thin convenience wrapper: builds an :class:`ERA5LandDownloader` with the
    given options and runs its ``download()`` method.

    Parameters
    ----------
    area : Union[str, pathlib.Path, tuple, gpd.GeoDataFrame]
        Area specification (bounding box, file path, or GeoDataFrame).
    variables : Union[str, List[str]]
        Variables to download (names, categories, or profiles).
    start_date : Union[str, datetime]
        Start date for download period. Supports partial dates:
        - '2025-04-23' -> exact date
        - '2025-04' -> first day of month (2025-04-01)
        - '2025' -> first day of year (2025-01-01)
    end_date : Union[str, datetime]
        End date for download period. Supports partial dates:
        - '2025-04-23' -> exact date
        - '2025-04' -> last day of month (2025-04-30)
        - '2025' -> last day of year (2025-12-31)
    output_dir : Optional[Union[str, pathlib.Path]], default None
        Output directory for downloaded files.
    n_parallel : int, default 1
        Number of parallel download threads.
    keep_chunks : bool, default False
        Whether to keep intermediate monthly chunk files.
    force_overwrite : bool, default False
        Whether to overwrite existing files.
    output_prefix : str, default 'ERA5LAND'
        Prefix for output filenames.

    Returns
    -------
    Dict[str, pathlib.Path]
        Dictionary mapping variable names to output file paths.
    """
    downloader = ERA5LandDownloader(
        area=area,
        variables=variables,
        start_date=start_date,
        end_date=end_date,
        output_dir=output_dir,
        n_parallel=n_parallel,
        keep_chunks=keep_chunks,
        force_overwrite=force_overwrite,
        output_prefix=output_prefix,
    )
    return downloader.download()