Skip to content

Commit c8152a7

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 506a664 commit c8152a7

1 file changed

Lines changed: 19 additions & 24 deletions

File tree

nbs/240508-inference-naip.ipynb

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"outputs": [],
99
"source": [
1010
"import sys\n",
11+
"\n",
1112
"sys.path.append(\"..\")"
1213
]
1314
},
@@ -18,21 +19,16 @@
1819
"metadata": {},
1920
"outputs": [],
2021
"source": [
22+
"import os\n",
23+
"\n",
2124
"import matplotlib.pyplot as plt\n",
2225
"import numpy as np\n",
23-
"from einops import rearrange\n",
24-
"import torch\n",
25-
"import stackstac\n",
26-
"from pystac_client import Client\n",
27-
"import boto3\n",
28-
"import xarray as xr\n",
29-
"import numpy as np\n",
30-
"import os\n",
3126
"import rioxarray\n",
32-
"from box import Box\n",
27+
"import torch\n",
3328
"import yaml\n",
29+
"from box import Box\n",
30+
"from pystac_client import Client\n",
3431
"\n",
35-
"from src.datamodule import ClayDataModule\n",
3632
"from src.model_clay_v1 import ClayMAEModule"
3733
]
3834
},
@@ -44,11 +40,10 @@
4440
"outputs": [],
4541
"source": [
4642
"def plot_rgb(stack):\n",
47-
" stack.sel(band=[1, 2, 3]).plot.imshow(\n",
48-
" rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
49-
" )\n",
43+
" stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
5044
" plt.show()\n",
51-
" \n",
45+
"\n",
46+
"\n",
5247
"def normalize_latlon(lat, lon):\n",
5348
" lat = lat * np.pi / 180\n",
5449
" lon = lon * np.pi / 180\n",
@@ -124,7 +119,7 @@
124119
"\n",
125120
" # The first embedding is the class token, which is the\n",
126121
" # overall single embedding. We extract that for PCA below.\n",
127-
" return unmsk_patch[:, 0, :].cpu().numpy()\n"
122+
" return unmsk_patch[:, 0, :].cpu().numpy()"
128123
]
129124
},
130125
{
@@ -61598,11 +61593,10 @@
6159861593
" assets = item.assets\n",
6159961594
" dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
6160061595
" print(\"dataset: \", dataset)\n",
61601-
" granule_name = item.assets[\"image\"].href.split('/')[-1]\n",
61596+
" granule_name = item.assets[\"image\"].href.split(\"/\")[-1]\n",
6160261597
" stackstac_datasets.append(dataset)\n",
6160361598
" granule_names.append(granule_name)\n",
61604-
" \n",
61605-
" \n",
61599+
"\n",
6160661600
"\n",
6160761601
"# Function to tile dataset into 256x256 image chips and drop any excess border regions\n",
6160861602
"def tile_dataset(dataset, granule_name):\n",
@@ -61628,26 +61622,27 @@
6162861622
" y_end = y_start + 256\n",
6162961623
"\n",
6163061624
" # Extract the tile from the cropped dataset\n",
61631-
" tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
61625+
" tile = cropped_dataset.isel(\n",
61626+
" x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
61627+
" )\n",
6163261628
" print(tile.shape)\n",
6163361629
"\n",
6163461630
" # Save the tile as a GeoTIFF\n",
6163561631
" tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
61636-
" #tile.rio.to_raster(tile_path)\n",
61632+
" # tile.rio.to_raster(tile_path)\n",
6163761633
" print(tile)\n",
6163861634
" tiles.append(tile)\n",
61639-
" \n",
61635+
"\n",
6164061636
" return tiles\n",
61641-
" \n",
6164261637
"\n",
6164361638
"\n",
6164461639
"# Tile each dataset\n",
6164561640
"for dataset, granule_name in zip(stackstac_datasets[0:2], granule_names[0:2]):\n",
6164661641
" tiles = tile_dataset(dataset, granule_name)\n",
61647-
" \n",
61642+
"\n",
6164861643
"tile_0 = tiles[0]\n",
6164961644
"\n",
61650-
"plot_rgb(tile_0)\n"
61645+
"plot_rgb(tile_0)"
6165161646
]
6165261647
},
6165361648
{

0 commit comments

Comments
 (0)