Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ dependencies = [
"GDAL",
"azure-storage-blob",
"overturemaps",
"async",
"uvloop",
"aiofiles",
"aiohttp",
"requests",
"azure-storage-file-share",
"asyncclick",
"rio-cogeo",
Expand All @@ -44,6 +45,7 @@ dependencies = [
"playwright",
"pystac_client",
"requests-oauthlib",
"python-dateutil",
"fiona",
"nest_asyncio",
"tensorflow==2.16.2",
Expand Down
3 changes: 3 additions & 0 deletions rapida/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import click
import nest_asyncio
nest_asyncio.apply()
import uvloop
import asyncio
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())



Expand Down
198 changes: 179 additions & 19 deletions rapida/components/landuse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import asyncio
import datetime
import rasterio
import logging
import os
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from concurrent.futures import as_completed
from typing import List
from osgeo import gdal
from affine import Affine
from rasterio.enums import Resampling
from osgeo import gdal, gdal_array
from osgeo_utils.gdal_calc import Calc
from rich.progress import Progress
from rich.progress import Progress, TimeElapsedColumn
import geopandas as gpd

from rapida.components.landuse.search_utils.s2item import Sentinel2Item
from rapida.components.landuse.constants import SENTINEL2_BAND_MAP
from rapida.components.landuse.download import download_stac, find_sentinel_imagery
from rapida.components.landuse.constants import STAC_MAP
from rapida.components.landuse.sentinel_item import SentinelItem
Expand All @@ -19,10 +26,24 @@
from rapida.stats.raster_zonal_stats import zst
from rapida.util import geo

from rapida.components.landuse.search_utils.search import fetch_s2_tiles



logger = logging.getLogger('rapida')


def gdal_rich_callback(complete, message, user_data):
    """GDAL progress callback that drives a rich ``Progress`` task.

    Parameters
    ----------
    complete : float
        Fraction of the operation completed, in [0, 1], supplied by GDAL.
    message : str
        Progress message from GDAL (unused here).
    user_data : tuple
        ``(progress, task, stop)`` — a rich ``Progress`` instance (or
        ``None``), its task id (or ``None``), and a ``threading.Event``
        (or ``None``) used for cooperative cancellation.

    Returns
    -------
    int
        ``0`` to ask GDAL to abort the running operation, ``1`` to continue.
    """
    progress, task, stop = user_data
    # Cooperative cancellation: a set stop event aborts the GDAL operation.
    if stop and stop.is_set():
        # Fixed: original used an f-string with no placeholders.
        logger.info('GDAL received timeout signal')
        return 0
    if progress is not None and task is not None:
        # Fixed: this fires on every progress tick — log lazily at DEBUG
        # instead of INFO so it does not flood output or fight the rich
        # progress renderer.
        logger.debug('%s', complete)
        progress.update(task, completed=int(complete * 100))
    return 1


class LanduseComponent(Component):

def __call__(self, variables: List[str], datetime_range: str=None, cloud_cover:int = None, **kwargs):
Expand Down Expand Up @@ -101,22 +122,24 @@ def prediction_output_image(self) -> str:


def __init__(self, **kwargs):
    """Initialise the variable and derive its local raster output path.

    NOTE(review): indentation reconstructed from a diff rendering; this is
    presumably a method of a pydantic-backed variable/component class —
    confirm against the full file.
    """
    # Inject an empty s2_tiles mapping before pydantic validation runs;
    # the original author notes this is the only way to patch the model.
    kwargs.update({'s2_tiles':{}}) # only way to patch because of pydantic
    super().__init__(**kwargs)
    # The project is resolved from the current working directory — assumes
    # the CLI always runs from inside a project folder (TODO confirm).
    project = Project(path=os.getcwd())
    geopackage_path = project.geopackage_file_path
    output_filename = f"{self.name}.tif"
    # Output raster lives next to the project geopackage, in a
    # per-component subdirectory.
    self.local_path = os.path.join(os.path.dirname(geopackage_path), self.component, output_filename)


def __call__(self, *args, **kwargs):
progress: Progress = kwargs.get('progress', None)

variable_task = None
if progress is not None:
variable_task = progress.add_task(
description=f'[blue] Assessing {self.component}->{self.name}', total=None)

try:
self.download(**kwargs)

self.compute(**kwargs)

if progress is not None and variable_task is not None:
Expand All @@ -136,27 +159,156 @@ def __call__(self, *args, **kwargs):
def download_new(self, force=False, **kwargs):

project = Project(os.getcwd())

start_date, end_date = self.datetime_range.split('/')
progress: Progress = kwargs.get('progress', None)
stop = Event()
total_n_files = 0
s2_images = {}



if force or not os.path.exists(self.prediction_output_image):
s2_items = find_sentinel_imagery( stac_url=self.stac_url,
collection_id=self.collection_id,
geopackage_file_path=project.geopackage_file_path,
polygons_layer_name=project.polygons_layer_name,
datetime_range=self.datetime_range,
cloud_cover=self.cloud_cover,
progress=progress)

s2_tiles_dict = fetch_s2_tiles(stac_url=self.stac_url, bbox=project.geobounds,
start_date=start_date, end_date=end_date,
max_cloud_cover=self.cloud_cover, progress=progress, prune=True,#filter_for_dev=['36MYB']
)


for k, v in s2_tiles_dict.items():
total_n_files+= len(v) * len(SENTINEL2_BAND_MAP)

output_dir = os.path.dirname(self.prediction_output_image)
os.makedirs(output_dir, exist_ok=True)
download_task = None
if progress:
download_task = progress.add_task(f"[cyan]Downloading {len(s2_items)} Sentinel2 items", total=len(s2_items))
for item in s2_items:
item.download_assets(download_dir=output_dir, progress=progress, force=force)
if progress and download_task:
progress.update(download_task, description=f"[green]Downloaded {item.id} ", advance=1)
if progress and download_task:
progress.update(download_task, description=f"[green]Downloaded {len(s2_items)} Sentinel2 items ")
download_task = progress.add_task(f"[cyan]Downloading {total_n_files} Sentinel2 images in {len(s2_tiles_dict)} grids ", total=len(s2_tiles_dict))

failed= {}
ndone = 0
downloaded = {}
with ThreadPoolExecutor(max_workers=5, ) as tpe:
jobs = dict()
for mgrs_grid_id, candidates in s2_tiles_dict.items():
#print(json.dumps(candidates[0].assets, indent=4))
s2i = Sentinel2Item(mgrs_grid=mgrs_grid_id, s2_tiles=candidates, workdir=output_dir, target_crs=project.projection)
self.s2_tiles[mgrs_grid_id] = s2i
jobs[tpe.submit(s2i.download, bands=s2i.bands, progress=progress, force=force)] = mgrs_grid_id

try:
for future in as_completed(jobs):
grid = jobs[future]
try:
downloaded_files = future.result()
ndone+=1
if progress is not None and download_task is not None:
progress.update(download_task,advance=1)
except Exception as e:
failed[grid] = e
if progress is not None and download_task is not None:
progress.update(download_task, advance=1)
raise
downloaded[grid] = downloaded_files
except KeyboardInterrupt:
stop.set()
for s2i in self.s2_tiles.values():
loop = getattr(s2i, "_loop", None)
task = getattr(s2i, "_task", None)
if loop and task and not task.done():
try:
loop.call_soon_threadsafe(task.cancel)
except RuntimeError:
pass
# this only cancels pending (not running) futures, still good to call:
tpe.shutdown(wait=False, cancel_futures=True)
raise

finally:

if progress is not None and download_task is not None:
progress.update(download_task,
#description=f'[red]Downloaded Sentinel2 imagery in {ndone} MGRS grids',
advance=1)

for grid, err in failed.items():
logger.error(f'Failed to download S2 imagery in {grid}, {err}')
# it is debatable if the error is to be swallowed or propagated but at least user should be aware of the
# fact the download failed in some grid/tiles


for mgrs_grid, s2itm in self.s2_tiles.items():
for band, vrt in s2itm.vrts.items():
if vrt is None:
continue
if band not in s2_images:s2_images[band] = []
s2_images[band].append(vrt)

vrts = []
gdal_cache_mb = 4096
env = rasterio.Env(
GDAL_NUM_THREADS="ALL_CPUS", # multithread warp + COG creation
GDAL_CACHEMAX=gdal_cache_mb, # warp cache
)
try:
for band, band_vrt_files in s2_images.items():

#band_vrt = os.path.join(output_dir, f'{band}.vrt')
band_cog = os.path.join(output_dir, f'{band}.tif')
#band_cog = band_vrt.replace('.vrt', '.tif')
band_vrt = f'/vsimem/{start_date}_{end_date}_{band}.vrt'

with env, gdal.BuildVRT(destName=band_vrt, srcDSOrSrcDSTab=band_vrt_files) as src_ds:
src_ds.FlushCache()
vrts.append(band_vrt)

# rio_copy(band_vrt, band_cog, driver='COG', COMPRESS='NONE', BLOCKSIZE=1024,NUM_THREADS='ALL_CPUS',
# RESAMPLING='cubic', OVERVIEW_LEVELS='NONE', callback=gdal_rich_callback,
# callback_data=(progress, cp_task, stop)
# )
src_band = src_ds.GetRasterBand(1)

transform = Affine.from_gdal(*src_ds.GetGeoTransform())
dtp = gdal_array.GDALTypeCodeToNumericTypeCode(src_band.DataType)
srs = src_ds.GetProjectionRef()

with rasterio.open(band_cog, mode='w', driver='COG', width=src_ds.RasterXSize, height=src_ds.RasterYSize, count=1, crs=srs,
transform=transform, dtype=dtp, nodata=src_band.GetNoDataValue(), blocksize=1024, num_threads='ALL_CPUS',
resampling=Resampling.cubic,overview_levels=None, compress=None) as dst:


wins = {blck: win for blck, win in dst.block_windows()}
cp_task = None
if progress is not None:
cp_task = progress.add_task(description=f'[red]Saving {band_vrt} to {band_cog}', total=len(wins))
for bl, win in wins.items():
if stop is not None and stop.is_set():
raise KeyboardInterrupt


src_data = src_ds.ReadAsArray(xoff=win.col_off, yoff=win.row_off,
xsize=win.width, ysize=win.height,band_list=[1]
)
dst.write(src_data, 1, window=win)

if progress is not None and cp_task is not None:
progress.update(cp_task, advance=1)



if progress is not None and cp_task is not None:
progress.remove_task(cp_task)

gdal.UnlinkBatch(vrts)

except KeyboardInterrupt:
stop.set()
raise








Expand Down Expand Up @@ -250,8 +402,16 @@ def _compute_affected_(self, **kwargs):

return affected_local_path


def compute(self, **kwargs):
    """Compute the land-use variable.

    NOTE(review): as visible in this diff the method only creates a
    progress task and reads ``self.target_band_value`` — the actual
    computation appears unfinished (compare ``compute_old`` below);
    confirm against the full file before relying on it.
    """
    progress = kwargs.get('progress', None)
    variable_task = None
    if progress:
        variable_task = progress.add_task(f"[red]Creating variable {self.name}", total=100)

    # Band value in the prediction raster that represents this variable.
    source_value = self.target_band_value


def compute_old(self, **kwargs):
progress = kwargs.get('progress', None)
variable_task = None
if progress:
Expand Down
4 changes: 3 additions & 1 deletion rapida/components/landuse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@
"cirrus": "B10", # cloud
"swir16": "B11", # land use, cloud
"swir22": "B12", # land use, cloud
}
}

SENTINEL2_BAND_MAP = {v:k for k,v in SENTINEL2_ASSET_MAP.items()}
12 changes: 8 additions & 4 deletions rapida/components/landuse/prediction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,11 @@ def predict(self,
bounds = rasterio.windows.bounds(window, dst_transform)
tile_geom = box(bounds[0], bounds[1], bounds[2], bounds[3])

if mask_union.intersects(tile_geom):
tile_jobs.append((row, col))
all_cols.append(col)
all_rows.append(row)
if mask_union is not None and not mask_union.intersects(tile_geom):
continue
tile_jobs.append((row, col))
all_cols.append(col)
all_rows.append(row)

# Determine bounding box
min_col = min(all_cols)
Expand Down Expand Up @@ -381,6 +382,7 @@ def predict(self,

if num_workers is None:
num_workers = psutil.cpu_count(logical=False)
logger.info(f'Going to use {num_workers} workers in {len(tile_jobs)} tiles')

with ProcessPoolExecutor(max_workers=num_workers) as executor:
job_iter = iter(tile_jobs)
Expand All @@ -397,6 +399,8 @@ def predict(self,
fut = executor.submit(self.process_tile, row, col, img_paths, min_resolution_path, mask_union)
running_futures[fut] = (row, col, task_id)

print(len(running_futures))

while running_futures:
# wait for any future to complete
done, _ = concurrent.futures.wait(running_futures, return_when=concurrent.futures.FIRST_COMPLETED)
Expand Down
2 changes: 2 additions & 0 deletions rapida/components/landuse/prediction/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@


Empty file.
Loading