88 "outputs": [],
99 "source": [
1010 "import sys\n",
11+ "\n",
1112 "sys.path.append(\"..\")"
1213 ]
1314 },
1819 "metadata": {},
1920 "outputs": [],
2021 "source": [
22+ "import os\n",
23+ "\n",
2124 "import matplotlib.pyplot as plt\n",
2225 "import numpy as np\n",
23- "from einops import rearrange\n",
24- "import torch\n",
25- "import stackstac\n",
26- "from pystac_client import Client\n",
27- "import boto3\n",
28- "import xarray as xr\n",
29- "import numpy as np\n",
30- "import os\n",
3126 "import rioxarray\n",
32- "from box import Box \n",
27+ "import torch \n",
3328 "import yaml\n",
29+ "from box import Box\n",
30+ "from pystac_client import Client\n",
3431 "\n",
35- "from src.datamodule import ClayDataModule\n",
3632 "from src.model_clay_v1 import ClayMAEModule"
3733 ]
3834 },
4440 "outputs": [],
4541 "source": [
4642 "def plot_rgb(stack):\n",
47- " stack.sel(band=[1, 2, 3]).plot.imshow(\n",
48- " rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n",
49- " )\n",
43+ " stack.sel(band=[1, 2, 3]).plot.imshow(rgb=\"band\", vmin=0, vmax=2000, col_wrap=6)\n",
5044 " plt.show()\n",
51- " \n",
45+ "\n",
46+ "\n",
5247 "def normalize_latlon(lat, lon):\n",
5348 " lat = lat * np.pi / 180\n",
5449 " lon = lon * np.pi / 180\n",
124119 "\n",
125120 " # The first embedding is the class token, which is the\n",
126121 " # overall single embedding. We extract that for PCA below.\n",
127- " return unmsk_patch[:, 0, :].cpu().numpy()\n "
122+ " return unmsk_patch[:, 0, :].cpu().numpy()"
128123 ]
129124 },
130125 {
@@ -61598,11 +61593,10 @@
6159861593 " assets = item.assets\n",
6159961594 " dataset = rioxarray.open_rasterio(item.assets[\"image\"].href).sel(band=[1, 2, 3, 4])\n",
6160061595 " print(\"dataset: \", dataset)\n",
61601- " granule_name = item.assets[\"image\"].href.split('/' )[-1]\n",
61596+ " granule_name = item.assets[\"image\"].href.split(\"/\" )[-1]\n",
6160261597 " stackstac_datasets.append(dataset)\n",
6160361598 " granule_names.append(granule_name)\n",
61604- " \n",
61605- " \n",
61599+ "\n",
6160661600 "\n",
6160761601 "# Function to tile dataset into 256x256 image chips and drop any excess border regions\n",
6160861602 "def tile_dataset(dataset, granule_name):\n",
@@ -61628,26 +61622,27 @@
6162861622 " y_end = y_start + 256\n",
6162961623 "\n",
6163061624 " # Extract the tile from the cropped dataset\n",
61631- " tile = cropped_dataset.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))\n",
61625+ " tile = cropped_dataset.isel(\n",
61626+ " x=slice(x_start, x_end), y=slice(y_start, y_end)\n",
61627+ " )\n",
6163261628 " print(tile.shape)\n",
6163361629 "\n",
6163461630 " # Save the tile as a GeoTIFF\n",
6163561631 " tile_path = f\"{save_dir}/{granule_name[:-4]}_{x_idx}_{y_idx}.tif\"\n",
61636- " #tile.rio.to_raster(tile_path)\n",
61632+ " # tile.rio.to_raster(tile_path)\n",
6163761633 " print(tile)\n",
6163861634 " tiles.append(tile)\n",
61639- " \n",
61635+ "\n",
6164061636 " return tiles\n",
61641- " \n",
6164261637 "\n",
6164361638 "\n",
6164461639 "# Tile each dataset\n",
6164561640 "for dataset, granule_name in zip(stackstac_datasets[0:2], granule_names[0:2]):\n",
6164661641 " tiles = tile_dataset(dataset, granule_name)\n",
61647- " \n",
61642+ "\n",
6164861643 "tile_0 = tiles[0]\n",
6164961644 "\n",
61650- "plot_rgb(tile_0)\n "
61645+ "plot_rgb(tile_0)"
6165161646 ]
6165261647 },
6165361648 {
0 commit comments