Skip to content

Commit e937d29

Browse files
Add a laplace smoothing to preserve coastal fills and restrict second pass to interior
1 parent ea04a76 commit e937d29

1 file changed

Lines changed: 106 additions & 0 deletions

File tree

external_tidal_generation/generate_bottom_roughness_intermediate_woa.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,112 @@ def gatherv_indexed(
381381
return out1d
382382

383383

384+
def fill_missing_data_laplace(
385+
field: np.ndarray, mask: np.ndarray, periodic_lon: bool = True
386+
) -> np.ndarray:
387+
"""
388+
Fill nans smoothly by solving a discrete Laplace problem over the wet domain.
389+
390+
This is adapted from https://github.com/ACCESS-NRI/om3-scripts/blob/main/chlorophyll/chl_climatology_and_fill.py,
391+
which originally adapted from https://github.com/adcroft/interp_and_fill/blob/main/Interpolate%20and%20fill%20SeaWIFS.ipynb
392+
393+
This implementation otherwise assumes a regular lat/lon grid tri-polar (WOA),
394+
hence tripolar topology is intentionally not handled here.
395+
396+
Periodic boundary conditions are supported in longitude only (global configuration).
397+
For regional configurations, set periodic_lon=False is not implemented yet.
398+
"""
399+
nj, ni = field.shape
400+
# Find the missing points to fill (nan in field but mask > 0)
401+
missing_mask = np.isnan(field) & (mask > 0)
402+
if not np.any(missing_mask):
403+
# no missing data to fill but also guarantee nans on dry cells
404+
return np.where(mask > 0, field, np.nan)
405+
406+
# change nan to 0 for the sparse matrix construction
407+
work = np.where(np.isnan(field), 0.0, field)
408+
missing_j, missing_i = np.where(missing_mask)
409+
n_missing = missing_j.size
410+
ind = np.full((nj, ni), -1, dtype=np.int64)
411+
ind[missing_j, missing_i] = np.arange(n_missing)
412+
413+
# Sparse matrix
414+
A = sp.lil_matrix((n_missing, n_missing))
415+
b = np.zeros(n_missing)
416+
ld = np.zeros(n_missing)
417+
418+
def _process_neighbour(n: int, jn: int, in_: int) -> None:
419+
"""Process neighbour at (jn, in_) for row n."""
420+
if mask[jn, in_] <= 0:
421+
return
422+
423+
ld[n] -= 1
424+
idx = ind[jn, in_]
425+
426+
if idx >= 0:
427+
A[n, idx] = 1
428+
else:
429+
b[n] -= work[jn, in_]
430+
431+
for n in range(n_missing):
432+
j = missing_j[n]
433+
i = missing_i[n]
434+
435+
if periodic_lon:
436+
im1 = (i - 1) % ni # west
437+
ip1 = (i + 1) % ni # east
438+
_process_neighbour(n, j, im1)
439+
_process_neighbour(n, j, ip1)
440+
else:
441+
# TODO handle non-periodic case if needed
442+
raise NotImplementedError(
443+
"Non-periodic longitude is not implemented yet. "
444+
"Set periodic_lon=True for global grids."
445+
)
446+
447+
if j > 0:
448+
_process_neighbour(n, j - 1, i) # south
449+
if j < nj - 1:
450+
_process_neighbour(n, j + 1, i) # north
451+
452+
stabilizer = 1e-14 # prevent singular matrix
453+
A[np.arange(n_missing), np.arange(n_missing)] = ld - stabilizer
454+
x = spla.spsolve(A.tocsr(), b)
455+
work[missing_j, missing_i] = x
456+
work = np.where(mask > 0, work, np.nan)
457+
return work
458+
459+
460+
def laplace_smooth(
461+
field: np.ndarray, mask: np.ndarray, erosion_iters: int = 2
462+
) -> np.ndarray:
463+
"""
464+
Smooth field over wet cells only by applying a Laplace smoother iteratively.
465+
stage1: Fill missing values over the original wet mask using a Laplacian solver.
466+
This ensures coastal nans gets filled and the field is continuous across the entire ocean domain.
467+
stage2: Erode the wet mask inward by `erosion_iters` number of grid cells, and apply the
468+
Laplacian solver again over this reduced (interior) region. This step smooths the interior.
469+
470+
erosion_iters: Number of grid cells by which to shrink the wet mask before the 2nd smoothing stage.
471+
Larger values reduce coastal influence more strongly.
472+
"""
473+
wet = mask > 0
474+
stage1 = fill_missing_data_laplace(field, mask=wet, periodic_lon=True)
475+
476+
if erosion_iters <= 0:
477+
return np.where(wet, stage1, np.nan)
478+
479+
eroded_mask = ndimage.binary_erosion(wet, iterations=erosion_iters)
480+
481+
stage2_interior = fill_missing_data_laplace(
482+
stage1, mask=eroded_mask, periodic_lon=True
483+
)
484+
485+
out = stage1.copy()
486+
out[eroded_mask] = stage2_interior[eroded_mask] # only overwrite interior
487+
return np.where(wet, out, np.nan)
488+
489+
384490
def compute_mean_depth_and_var_points(
385491
lon_np: np.ndarray,
386492
lat_np: np.ndarray,

0 commit comments

Comments
 (0)