In [1]:
# import warnings
# warnings.simplefilter(action='ignore', category=FutureWarning)
import xarray as xr
import rioxarray as rio
from rasterio.warp import reproject, Resampling
import pandas as pd
import numpy as np
import datetime
import os

In [2]:
# ------------------------------
# Core Data Processing Functions
# ------------------------------

def aggregate_time(da):
    """
    Collapse the "time" dimension of an xarray Dataset or DataArray.

    Values below zero are masked to NaN first. The total is then computed as
    (mean over time, NaN skipped) x (number of time steps), so masked or
    missing steps contribute the mean of the valid ones rather than zero.

    - Dataset: every variable is returned as that total, except 'se_root',
      which is returned as the plain mean.
    - DataArray: the mean is returned when the array is named 'se_root',
      otherwise the total.
    """
    valid = da.where(da >= 0, np.nan)
    n_steps = len(valid.time)
    mean_over_time = valid.mean(dim="time", skipna=True)

    # DataArray case: decide on the array's own name.
    if not isinstance(valid, xr.Dataset):
        if valid.name == "se_root":
            return mean_over_time
        return mean_over_time * n_steps

    # Dataset case: scale everything, then put the unscaled mean back for 'se_root'.
    totals = mean_over_time * n_steps
    if "se_root" in totals:
        totals["se_root"] = mean_over_time["se_root"]
    return totals

def restore_metadata(ds, orig_encoding, temp_res):
    """
    Put saved encodings back on the variables of ds and rewrite their attributes.

    ds: aggregated dataset; its variables are modified in place and ds is returned.
    orig_encoding: mapping of variable name -> encoding dict captured before
        aggregation; variables missing from it keep their current encoding.
    temp_res: label of the new temporal resolution (the callers in this file
        pass "Dekadal", "Monthly" or "Seasonal").
    """
    for name in ds.data_vars:
        variable = ds[name]
        if name in orig_encoding:
            variable.encoding.update(orig_encoding[name])

        meta = variable.attrs
        previous_long_name = meta.get("long_name", "")
        previous_units = meta.get("units", "")
        meta.update(
            {
                # "Daily" becomes the new resolution label; any "mm" is stripped.
                "long_name": previous_long_name.replace("Daily", temp_res).replace("mm", ""),
                "source_data": "Aggregated from ET_Look model output",
                # The unit suffix is temp_res minus its last two characters,
                # e.g. "Dekadal" -> "<units>/Dekad", "Monthly" -> "<units>/Month".
                "units": f"{previous_units}/{temp_res[:-2]}",
                "temporal_resolution": temp_res,
            }
        )
    return ds

# ------------------------------
# Temporal Aggregation Functions
# ------------------------------

def dekadal_sum(ds):
    """
    Aggregates daily data into dekadal (10-day) intervals.

    Every time step is relabelled with the first day of its dekad (day 1, 11
    or 21 of the month; the third dekad runs to the end of the month), and the
    steps sharing a label are reduced with aggregate_time. The relabelling is
    applied to a new object, so the dataset passed in keeps its original time
    coordinate and can be reused for other aggregations.

    ds: dataset with a "time" dimension of datetime values.
    Returns the aggregated dataset with encodings and attributes updated by
    restore_metadata (temporal resolution "Dekadal").
    """
    orig_encoding = {var: ds[var].encoding for var in ds.data_vars}

    # Days elapsed since the start of the dekad. The clip keeps day 31 in the
    # third dekad instead of opening a fourth one.
    day = ds.time.dt.day
    offset = day - np.clip((day - 1) // 10, 0, 2) * 10 - 1
    dekad_start = ds.time.values - np.array(offset, dtype="timedelta64[D]")

    # assign_coords returns a new dataset; writing ds['time'] directly would
    # overwrite the time axis of the caller's dataset.
    relabelled = ds.assign_coords(time=("time", dekad_start))
    ds_dk = relabelled.groupby("time").map(aggregate_time)

    return restore_metadata(ds_dk, orig_encoding, "Dekadal")


def monthly_sum(ds):
    """
    Aggregates daily data into calendar-month intervals.

    Time steps are grouped with the month-end frequency alias "ME" and each
    group is reduced with aggregate_time; encodings and attributes are then
    updated by restore_metadata (temporal resolution "Monthly").
    """
    # Remember each variable's encoding so it can be reapplied after aggregation.
    encodings = {}
    for var in ds.data_vars:
        encodings[var] = ds[var].encoding

    month_groups = ds.resample(time="ME")
    monthly = month_groups.map(aggregate_time)
    return restore_metadata(monthly, encodings, "Monthly")

def select_season(ds, season_start_date, season_end_date):
    """
    Subsets the dataset to a given seasonal window.

    season_start_date, season_end_date: anything pandas.to_datetime accepts
        (e.g. "2023-03-01"); both ends are passed to a label-based time slice.
    Prints a warning if the season starts before the first time step or ends
    after the last one; the selection is still performed on whatever overlaps.
    The range check reads the first and last time values, so it assumes the
    time coordinate is sorted in ascending order.
    """
    sos = pd.to_datetime(season_start_date)
    eos = pd.to_datetime(season_end_date)
    ds_start = pd.to_datetime(ds.time.data[0])
    ds_end = pd.to_datetime(ds.time.data[-1])

    if sos < ds_start or eos > ds_end:
        print("Warning: Season dates partially/completely outside dataset range")
        # Report the dataset end on its own (previously a (eos, ds_end) tuple was printed).
        print(f"sos:{sos}, eos:{eos}, ds_strt:{ds_start}, ds_end:{ds_end}")
    return ds.sel(time=slice(sos, eos))

def seasonal_sum(ds, season_start_date, season_end_date):
    """
    Aggregates data over a user-defined seasonal period.

    The dataset is cut to the season with select_season, each variable is
    reduced with aggregate_time, and metadata is updated by restore_metadata
    (temporal resolution "Seasonal"). The season bounds are stored, as given,
    under the "sos" and "eos" keys of every variable's encoding.
    """
    encodings = {name: ds[name].encoding for name in ds.data_vars}

    in_season = select_season(ds, season_start_date, season_end_date)
    totals = restore_metadata(in_season.map(aggregate_time), encodings, "Seasonal")

    # Record the season bounds on every variable's encoding.
    season_bounds = {"sos": season_start_date, "eos": season_end_date}
    for name in totals.data_vars:
        totals[name].encoding.update(season_bounds)
    return totals

# ------------------------------
# Spatial Processing Functions
# ------------------------------

def reproject_ds(ds, to_crs):
    """
    Reprojects the dataset to a target CRS.

    Accepts either:
    - An EPSG code string (e.g., "EPSG:32643")
    - A file path to a template raster with valid CRS information.

    Returns the reprojected object. This is best effort: if anything goes
    wrong (invalid CRS, missing template, template without CRS, reprojection
    error) a message is printed and the input is returned unchanged.
    """
    try:
        if isinstance(to_crs, str) and "EPSG" in to_crs:
            return ds.rio.reproject(to_crs, nodata=np.nan)
        if not os.path.exists(to_crs):
            raise ValueError("Invalid CRS or template path provided")
        # Open the template in a context manager so its file handle is
        # released once the match is done (it used to stay open).
        with rio.open_rasterio(to_crs) as template:
            if template.rio.crs is None:
                raise ValueError("Template raster lacks CRS info")
            return ds.rio.reproject_match(template)
    except Exception as e:
        print(f"Reprojection failed: {e}")
        return ds

# ------------------------------
# Output Generation Functions
# ------------------------------

# Lookup table of standardized short codes used when naming outputs.
# The first group maps variable prefixes, the second maps lower-case
# temporal resolutions.
switcher = dict(
    [
        # Variable prefixes
        ("aeti", "AETI"),  # Actual Evapotranspiration
        ("e", "E"),        # Evaporation
        ("int", "I"),      # Interception
        ("npp", "NPP"),    # Net Primary Productivity
        ("t", "T"),        # Transpiration
        ("se", "RSM"),     # Root Soil Moisture
        # Temporal resolutions
        ("dekadal", "D"),
        ("monthly", "M"),
        ("seasonal", "S"),
    ]
)

def write_file(da, to_crs, fname, encoding, date, attrs):
    """
    Writes one data array to "<fname>.tif" as an LZW-compressed GeoTIFF.

    da: array to write; it is reprojected with reproject_ds when to_crs is truthy.
    to_crs: target CRS (or template path) forwarded to reproject_ds, or a falsy
        value to keep the array's current projection.
    fname: output path without the ".tif" extension.
    encoding: encoding dict assigned to the array before writing.
    date: value stored in the array's "date" attribute.
    attrs: accepted for interface compatibility; not used by this function.
    """
    raster = reproject_ds(da, to_crs) if to_crs else da
    raster.attrs["date"] = date
    raster = raster.round(2)
    # Any '_FillValue' attribute is dropped; the fill value travels in `encoding`.
    raster.attrs.pop('_FillValue', None)
    raster.encoding = encoding
    raster.rio.to_raster(f"{fname}.tif", driver="GTiff", compress="LZW")

def write2gtiff(ds, temporal_res, dir_out, to_crs=None):
    """
    Exports every variable of an aggregated dataset to GeoTIFF files.

    Files go to <dir_out>/<temporal_res>/pywapor_<VAR>_<T>/, where <VAR> and <T>
    are the codes looked up in `switcher` (falling back to the raw variable
    prefix / temporal_res when there is no entry). The folder is created if
    it does not exist.

    - Seasonal data: one file per variable, suffixed with the "sos" and "eos"
      values read from the variable's encoding.
    - Any other resolution: one file per time step, suffixed with the date
      formatted as YYYY-MM-DD.

    to_crs: optional target CRS or template path handed on to write_file.
    """
    # String labels for the time steps (only available when a time dimension exists).
    if 'time' in ds.dims:
        date_strs = pd.to_datetime(ds.time.data).strftime("%Y-%m-%d")
    else:
        date_strs = None

    for var in ds.data_vars:
        # Standardized "<variable code>_<time code>" label
        res_key = temporal_res.lower()
        var_prefix = var.split("_")[0]
        var_name = f"{switcher.get(var_prefix, var_prefix)}_{switcher.get(res_key, temporal_res)}"

        # Output folder and the common stem of every file written into it
        output_dir = os.path.join(dir_out, temporal_res, f"pywapor_{var_name}")
        os.makedirs(output_dir, exist_ok=True)
        stem = os.path.join(output_dir, f"pywapor_{var_name}")

        # Copy of the variable's encoding with the storage settings overridden
        encoding = {
            **ds[var].encoding,
            "dtype": "float32",
            "scale_factor": 1.0,
            "_FillValue": np.nan,
        }
        attrs = ds[var].attrs

        if res_key == "seasonal":
            season_date = f"{encoding.get('sos')}_{encoding.get('eos')}"
            write_file(ds[var], to_crs, f"{stem}_{season_date}", encoding, season_date, attrs)
            continue

        for idx, date in enumerate(date_strs):
            # Each slice is written without its scalar time coordinate
            da_slice = ds[var].isel(time=idx).drop_vars("time")
            write_file(da_slice, to_crs, f"{stem}_{date}", encoding, date, attrs)


#### Step 1: Read pywapor output

In [3]:
# Path to the et_look_out.nc file
path_et_look_out = r"d:\pywapor\Kenya_2023\et_look_out.nc"
dir_out = r"d:\pywapor\Kenya_2023\et_look_out_tiffs"  # folder to save the GeoTIFF files
# Keep variable attributes through xarray operations
xr.set_options(keep_attrs=True)
ds = xr.open_dataset(path_et_look_out, decode_coords="all")
# Use "time" as the name of the time dimension when the file calls it "time_bins"
if 'time_bins' in ds.dims:
    ds = ds.rename({'time_bins': 'time'})

# Select the period of interest, which may be a season.
# Keep exactly one `season` line active.
# season = ["2022-11-01", "2023-05-31"]  # Iraq
# season = ["2023-03-21", "2023-09-11"]  # Jordan
season = ["2023-03-01", "2023-12-31"]  # Kenya
ds = ds.sel(time=slice(season[0], season[1]))
# Replace negative values with 0.0; missing values (NaN) are left as they are
ds = ds.where((ds >= 0.0) | ds.isnull(), 0.0)
# Default target CRS: the UTM CRS estimated from the dataset
to_crs = f'EPSG:{ds.rio.estimate_utm_crs().to_epsg()}'

#### Step 2: Aggregate to the required timestep (dekadal, monthly or seasonal) and write the result to individual geotiff files per time step
The ET_Look output is in EPSG:4326. To reproject the dataset to another projection, such as a UTM zone, provide the required EPSG code or the path to a raster template file with valid CRS information. The default is a UTM CRS estimated from the dataset. To change it, provide the CRS in the following style: to_crs = f"EPSG:{epsg code}"

In [4]:
# Aggregate the daily data to dekads and write one GeoTIFF per dekad and variable
temporal_res = 'dekadal'
ds_dk = dekadal_sum(ds)
write2gtiff(ds_dk, temporal_res, dir_out, to_crs=to_crs)

In [22]:
# Aggregate the daily data to months and write one GeoTIFF per month and variable
temporal_res = 'monthly'
ds_mn = monthly_sum(ds)
write2gtiff(ds_mn, temporal_res, dir_out, to_crs=to_crs)

In [21]:
# Aggregate the selected season into a single layer and write one GeoTIFF per variable
temporal_res = 'seasonal'
ds_sn = seasonal_sum(ds, season[0], season[1])
write2gtiff(ds_sn, temporal_res, dir_out, to_crs=to_crs)