Draft
Changes from all commits
22 commits
001851f
first partial commit of dataset registry and caching
Feb 23, 2026
76177ed
chore: allow arbitrary dhis2eo download funcs, get dataset cache info…
Feb 23, 2026
b3d66bf
chore: define datasets router prefix centrally
Feb 23, 2026
8d1b3af
complete aggregate pipeline with unit conversion, hacky constants module
Feb 23, 2026
8e99b86
make aggregation statistics dynamic and not hardcoded for each dataset
Feb 24, 2026
073a6a1
reorganize api so time and space aggregation happens hierarchically b…
Feb 24, 2026
937a277
move some funcs to utils.py
Feb 24, 2026
b48bd80
add raster with time period aggregation download endpoint as part of t…
Feb 24, 2026
7752544
add dummy tiles endpoint
Feb 24, 2026
21c6d00
apply unit conversion earlier so it also affects raster download
Feb 24, 2026
297c6e7
switch to proper logging
Feb 24, 2026
7890d1f
check and raise dataset id errors more gracefully
Feb 24, 2026
b9d27c2
validate target period type and skip temporal aggregation if same per…
Feb 24, 2026
d95c1d6
fix misc dataset metadata
Feb 24, 2026
e243b8d
fix period type validation bug, redo json serialization and show corr…
Feb 24, 2026
ccca427
easier to switch out to different org units, cache builds to latest d…
Feb 24, 2026
0e56e82
switch to more reliable cache background worker, fix dynamic bbox error
Feb 24, 2026
825e3ea
speedup unit conversion by doing it on fewer values, stabilize intern…
Feb 24, 2026
134070e
add cache optimization which builds zarr archive, dataset openers use…
Feb 24, 2026
367b06d
improve read and write of zarr cache
Feb 25, 2026
62c33db
fix same period type aggregation error, clarify array subsetting in t…
Feb 25, 2026
0c426e2
smaller chunk size for improved read speed, memory, and optimize buil…
Feb 25, 2026
2 changes: 2 additions & 0 deletions .gitignore
@@ -1 +1,3 @@
__pycache__/

datasets/cache
Empty file added __init__.py
Empty file.
5,577 changes: 5,577 additions & 0 deletions brazil-municipalities.geojson

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions brazil-regions.geojson

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions constants.py
@@ -0,0 +1,19 @@
import json
import geopandas as gpd

# constants for org units bbox and country code (hacky hardcoded for now)
# TODO: these should be defined differently or retrieved from DHIS2 connection

# sierra leone
# GEOJSON_FILE = 'sierra-leone-districts.geojson'
# COUNTRY_CODE = 'SLE'
# CACHE_OVERRIDE = r'C:\Users\karimba\Documents\Github\eo-api\datasets\cache\SLE'

# brazil
GEOJSON_FILE = 'brazil-regions.geojson'
COUNTRY_CODE = 'BRA'
CACHE_OVERRIDE = None

#################################################
ORG_UNITS_GEOJSON = json.load(open(GEOJSON_FILE))
BBOX = list(map(float, gpd.read_file(GEOJSON_FILE).total_bounds))
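A hypothetical sanity check of the derived constants; the feature count and bounds depend on the configured GeoJSON file and are illustrative only:

import constants

print(len(constants.ORG_UNITS_GEOJSON['features']))  # number of org unit features in the configured file
print(constants.BBOX)  # [xmin, ymin, xmax, ymax] from geopandas total_bounds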
Empty file added datasets/__init__.py
Empty file.
112 changes: 112 additions & 0 deletions datasets/api.py
@@ -0,0 +1,112 @@

from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask

import constants
from . import registry
from . import cache
from . import raster
from . import units
from . import serialize

router = APIRouter()

@router.get("/")
def list_datasets():
"""
Return the list of available datasets from the registry.
"""
datasets = registry.list_datasets()
return datasets

def get_dataset_or_404(dataset_id: str):
dataset = registry.get_dataset(dataset_id)
if not dataset:
raise HTTPException(status_code=404, detail=f"Dataset '{dataset_id}' not found")
return dataset

@router.get("/{dataset_id}", response_model=dict)
def get_dataset(dataset_id: str):
"""
Get a single dataset by ID.
"""
dataset = get_dataset_or_404(dataset_id)
cache_info = cache.get_cache_info(dataset)
dataset.update(cache_info)
return dataset

@router.get("/{dataset_id}/build_cache", response_model=dict)
def build_dataset_cache(dataset_id: str, start: str, end: str | None = None, overwrite: bool = False, background_tasks: BackgroundTasks = None):
"""
Download and cache the dataset as local netcdf files directly from the source.
"""
dataset = get_dataset_or_404(dataset_id)
cache.build_dataset_cache(dataset, start=start, end=end, overwrite=overwrite, background_tasks=background_tasks)
return {'status': 'Dataset caching request submitted for processing'}

@router.get("/{dataset_id}/optimize_cache", response_model=dict)
def optimize_dataset_cache(dataset_id: str, background_tasks: BackgroundTasks = None):
"""
Optimize the dataset cache by collecting all cache files into a single zarr archive.
"""
dataset = get_dataset_or_404(dataset_id)
background_tasks.add_task(cache.optimize_dataset_cache, dataset)
return {'status': 'Dataset cache optimization submitted for processing'}

@router.get("/{dataset_id}/{period_type}/orgunits", response_model=list)
def get_dataset_period_type_org_units(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str, spatial_aggregation: str):
"""
Get a dataset dynamically aggregated to a given period type and org units, and return JSON values.
"""
# get dataset metadata
dataset = get_dataset_or_404(dataset_id)

# get raster data
ds = raster.get_data(dataset, start, end)

# aggregate to period type
ds = raster.to_timeperiod(ds, dataset, period_type, statistic=temporal_aggregation)

# aggregate to geojson features
df = raster.to_features(ds, dataset, features=constants.ORG_UNITS_GEOJSON, statistic=spatial_aggregation)

# convert units if needed (inplace)
# NOTE: here we do it after aggregation to dataframe to speed up computation
units.convert_pandas_units(df, dataset)

# serialize to json
data = serialize.dataframe_to_json_data(df, dataset, period_type)
return data

@router.get("/{dataset_id}/{period_type}/raster")
def get_dataset_period_type_raster(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str):
"""
Get a dataset dynamically aggregated to a given period type and return it as a downloadable raster file.
"""
# get dataset metadata
dataset = get_dataset_or_404(dataset_id)

# get raster data
ds = raster.get_data(dataset, start, end)

# aggregate to period type
ds = raster.to_timeperiod(ds, dataset, period_type, statistic=temporal_aggregation)

# convert units if needed (inplace)
units.convert_xarray_units(ds, dataset)

# serialize to temporary netcdf
file_path = serialize.xarray_to_temporary_netcdf(ds)

# return as streaming file and delete after completion
return FileResponse(
file_path,
media_type="application/x-netcdf",
filename='eo-api-raster-download.nc',
background=BackgroundTask(serialize.cleanup_file, file_path)
)

@router.get("/{dataset_id}/{period_type}/tiles")
def get_dataset_period_type_tiles(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str):
pass
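The endpoints above can be exercised with any HTTP client. A minimal sketch, assuming the router is mounted under a /datasets prefix on a local development server and that a dataset with id era5_2m_temperature exists; the host, dataset id, and parameter values are illustrative only:

import requests

# request monthly means per org unit for a hypothetical dataset id
resp = requests.get(
    "http://localhost:8000/datasets/era5_2m_temperature/monthly/orgunits",
    params={
        "start": "2025-01-01",
        "end": "2025-03-31",
        "temporal_aggregation": "mean",  # statistic passed to raster.to_timeperiod
        "spatial_aggregation": "mean",   # statistic passed to raster.to_features
    },
)
resp.raise_for_status()
values = resp.json()  # JSON values built by serialize.dataframe_to_json_data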
179 changes: 179 additions & 0 deletions datasets/cache.py
@@ -0,0 +1,179 @@
import atexit
import importlib
import inspect
import logging
import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

import xarray as xr
import numpy as np

from . import registry
from .utils import get_time_dim, get_lon_lat_dims, numpy_period_string
from constants import BBOX, COUNTRY_CODE, CACHE_OVERRIDE

# logger
logger = logging.getLogger(__name__)

# paths
SCRIPT_DIR = Path(__file__).parent.resolve()
CACHE_DIR = SCRIPT_DIR / 'cache'
if CACHE_OVERRIDE:
CACHE_DIR = Path(CACHE_OVERRIDE)

def build_dataset_cache(dataset, start, end, overwrite, background_tasks):
# get download function
cache_info = dataset['cacheInfo']
eo_download_func_path = cache_info['eoFunction']
eo_download_func = get_dynamic_function(eo_download_func_path)
#logger.info(eo_download_func_path, eo_download_func)

# construct standard params
params = cache_info['defaultParams']
params.update({
'start': start,
'end': end or datetime.date.today().isoformat(), # today's date if empty
'dirname': CACHE_DIR,
'prefix': get_cache_prefix(dataset),
'overwrite': overwrite,
})

# add in varying spatial args
sig = inspect.signature(eo_download_func)
if 'bbox' in sig.parameters.keys():
params['bbox'] = BBOX
elif 'country_code' in sig.parameters.keys():
params['country_code'] = COUNTRY_CODE

# execute the download
background_tasks.add_task(eo_download_func, **params)

def optimize_dataset_cache(dataset):
logger.info(f'Optimizing cache for dataset {dataset["id"]}')

# open all cache files as xarray
files = get_cache_files(dataset)
logger.info(f'Opening {len(files)} files from cache')
ds = xr.open_mfdataset(files)

# trim to only minimal vars and coords
logger.info('Trimming unnecessary variables and coordinates')
varname = dataset['variable']
ds = ds[[varname]]
keep_coords = [get_time_dim(ds)] + list(get_lon_lat_dims(ds))
drop_coords = [
c for c in ds.coords
if c not in keep_coords
]
ds = ds.drop_vars(drop_coords)

# determine optimal chunk sizes
logger.info(f'Determining optimal chunk size for zarr archive')
ds_autochunk = ds.chunk('auto').unify_chunks()
# extract the first chunk size for each dimension to force uniformity
uniform_chunks = {dim: ds_autochunk.chunks[dim][0] for dim in ds_autochunk.dims}
# override with time space chunks
time_space_chunks = compute_time_space_chunks(ds, dataset)
uniform_chunks.update( time_space_chunks )
logger.info(f'--> {uniform_chunks}')

# save as zarr
logger.info(f'Saving to optimized zarr file')
zarr_path = CACHE_DIR / f'{get_cache_prefix(dataset)}.zarr'
ds_chunked = ds.chunk(uniform_chunks)
ds_chunked.to_zarr(zarr_path, mode='w')
ds_chunked.close()

logger.info('Finished cache optimization')

def compute_time_space_chunks(ds, dataset, max_spatial_chunk=256):
chunks = {}

# time
# set to common access patterns depending on original dataset period
# TODO: could potentially allow this to be customized in the dataset yaml file
dim = get_time_dim(ds)
period_type = dataset['periodType']
if period_type == 'hourly':
chunks[dim] = 24 * 7
elif period_type == 'daily':
chunks[dim] = 30
elif period_type == 'monthly':
chunks[dim] = 12
elif period_type == 'yearly':
chunks[dim] = 1

# space
lon_dim,lat_dim = get_lon_lat_dims(ds)
chunks[lon_dim] = min(ds.sizes[lon_dim], max_spatial_chunk)
chunks[lat_dim] = min(ds.sizes[lat_dim], max_spatial_chunk)

return chunks

def get_cache_info(dataset):
# find all files with cache prefix
files = get_cache_files(dataset)
if not files:
cache_info = dict(
temporal_coverage = None,
spatial_coverage = None,
)
return cache_info

# open first of sorted filenames, should be sufficient to get earliest time period
ds = xr.open_dataset(sorted(files)[0])

# get dim names
time_dim = get_time_dim(ds)
lon_dim, lat_dim = get_lon_lat_dims(ds)

# get start time
start = numpy_period_string(ds[time_dim].min().values, dataset['periodType'])

# get space scope
xmin,xmax = ds[lon_dim].min().item(), ds[lon_dim].max().item()
ymin,ymax = ds[lat_dim].min().item(), ds[lat_dim].max().item()

# open last of sorted filenames, should be sufficient to get latest time period
ds = xr.open_dataset(sorted(files)[-1])

# get end time
end = numpy_period_string(ds[time_dim].max().values, dataset['periodType'])

# cache info
cache_info = dict(
coverage=dict(
temporal = {'start': start, 'end': end},
spatial = {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax},
)
)
return cache_info

def get_cache_prefix(dataset):
prefix = dataset['id']
return prefix

def get_cache_files(dataset):
# TODO: this is not bulletproof, e.g. the prefix 2m_temperature would also match a dataset named 2m_temperature_modified
# ...probably need a delimiter to mark the end of the dataset name...
prefix = get_cache_prefix(dataset)
files = list(CACHE_DIR.glob(f'{prefix}*.nc'))
return files

def get_zarr_path(dataset):
prefix = get_cache_prefix(dataset)
optimized = CACHE_DIR / f'{prefix}.zarr'
if optimized.exists():
return optimized

def get_dynamic_function(full_path):
# Split the path into: 'dhis2eo.data.cds.era5_land.hourly' and 'function'
parts = full_path.split('.')
module_path = ".".join(parts[:-1])
function_name = parts[-1]

# This handles all the intermediate sub-package imports automatically
module = importlib.import_module(module_path)

return getattr(module, function_name)
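A minimal sketch of how a dataset opener might use the helpers above, preferring the optimized zarr archive and falling back to the raw netcdf cache files; this is an illustration only, not the actual raster module, and the open_cached name is made up:

import xarray as xr

from datasets import cache

def open_cached(dataset):
    # prefer the optimized zarr archive when it has been built
    zarr_path = cache.get_zarr_path(dataset)  # returns None when no .zarr archive exists
    if zarr_path is not None:
        return xr.open_zarr(zarr_path)
    # otherwise fall back to the individual netcdf cache files
    files = cache.get_cache_files(dataset)
    if not files:
        raise FileNotFoundError(f"no cached files for dataset {dataset['id']}")
    return xr.open_mfdataset(sorted(files))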
28 changes: 28 additions & 0 deletions datasets/preprocess.py
@@ -0,0 +1,28 @@
import logging

import xarray as xr


# logger
logger = logging.getLogger(__name__)


def deaccumulate_era5(ds_cumul):
'''Convert ERA5 cumulative hourly data to incremental hourly data'''

logger.info('Deaccumulating ERA5 dataset')
# NOTE: this is hardcoded to era5 specific cumulative patterns and varnames

# shift all values to previous hour, so the values don't spill over to the next day
ds_cumul = ds_cumul.shift(valid_time=-1)

# convert cumulative to diffs
ds_diffs = ds_cumul.diff(dim='valid_time')
ds_diffs = ds_diffs.reindex(valid_time=ds_cumul.valid_time)

# use cumul values where accumulation resets (00:00) and diff everywhere else
is_reset = ds_cumul['valid_time'].dt.hour == 0
ds_hourly = xr.where(is_reset, ds_cumul, ds_diffs)

# return
return ds_hourly
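A small synthetic check of the deaccumulation logic above, assuming the module is importable from the repository root; the variable name tp and the four-step cumulative series are made up for illustration:

import numpy as np
import pandas as pd
import xarray as xr

from datasets.preprocess import deaccumulate_era5

times = pd.date_range('2026-01-01 00:00', periods=4, freq='h')
# cumulative totals 0.0, 0.5, 1.5, 3.0 -> hourly increments 0.5, 1.0, 1.5
cumul = xr.Dataset(
    {'tp': ('valid_time', np.array([0.0, 0.5, 1.5, 3.0]))},
    coords={'valid_time': times},
)

hourly = deaccumulate_era5(cumul)
print(hourly['tp'].values)  # [0.5 1.0 1.5 nan]: each step now holds the increment of the following hour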