Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/.backend_git_ref
Original file line number Diff line number Diff line change
@@ -1 +1 @@
7aadfa383e6eee63442e366890dfb1160114caed
9fcd0e8d520b3e7679d29c969263345ea190ec46
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,7 +34,15 @@
"import ipyvuetify as vue\n",
"import numpy as np\n",
"import xarray as xr\n",
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType\n",
"from geoengine_openapi_client.models import (\n",
" MlModelInputNoDataHandling,\n",
" MlModelInputNoDataHandlingVariant,\n",
" MlModelMetadata,\n",
" MlModelOutputNoDataHandling,\n",
" MlModelOutputNoDataHandlingVariant,\n",
" MlTensorShape3D,\n",
" RasterDataType,\n",
")\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib.patches import Circle\n",
"from onnx.checker import check_model\n",
Expand Down Expand Up @@ -335,11 +343,18 @@
" onnx_model=onnx_clf,\n",
" model_config=ge.ml.MlModelConfig(\n",
" name=model_name,\n",
" file_name=\"model.onnx\",\n",
" metadata=MlModelMetadata(\n",
" file_name=\"model.onnx\",\n",
" input_type=RasterDataType.F32,\n",
" num_input_bands=4,\n",
" output_type=RasterDataType.U8,\n",
" inputType=RasterDataType.F32,\n",
" outputType=RasterDataType.U8,\n",
" inputShape=MlTensorShape3D(x=1, y=1, bands=4),\n",
" outputShape=MlTensorShape3D(x=1, y=1, bands=1),\n",
" inputNoDataHandling=MlModelInputNoDataHandling(\n",
" variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA\n",
" ),\n",
" outputNoDataHandling=MlModelOutputNoDataHandling(\n",
" variant=MlModelOutputNoDataHandlingVariant.NANISNODATA\n",
" ),\n",
" ),\n",
" display_name=\"Decision Tree\",\n",
" description=\"A simple decision tree model\",\n",
Expand Down Expand Up @@ -813,7 +828,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
36 changes: 28 additions & 8 deletions examples/ml_pipeline.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pathlib import Path

import geoengine_openapi_client
import geoengine_openapi_client.models
from geoengine_openapi_client.models import MlModel, MlModelMetadata, MlTensorShape3D, RasterDataType
from onnx import ModelProto, TensorProto, TypeProto
from onnx.helper import tensor_dtype_to_string
Expand All @@ -24,6 +23,7 @@ class MlModelConfig:
"""Configuration for an ml model"""

name: str
file_name: str
metadata: MlModelMetadata
display_name: str = "My Ml Model"
description: str = "My Ml Model Description"
Expand All @@ -47,7 +47,7 @@ def register_ml_model(

with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
with tempfile.TemporaryDirectory() as temp_dir:
file_name = Path(temp_dir) / model_config.metadata.file_name
file_name = Path(temp_dir) / model_config.file_name

with open(file_name, "wb") as file:
file.write(onnx_model.SerializeToString())
Expand All @@ -61,6 +61,7 @@ def register_ml_model(

model = MlModel(
name=model_config.name,
file_name=model_config.file_name,
upload=str(upload_id),
metadata=model_config.metadata,
display_name=model_config.display_name,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ readme = { file = "README.md", content-type = "text/markdown" }
license-files = ["LICENSE"]
requires-python = ">=3.10"
dependencies = [
"geoengine-openapi-client == 0.0.25",
"geoengine-openapi-client == 0.0.26",
"geopandas >=1.0,<2.0",
"matplotlib >=3.5,<3.11",
"numpy >=1.21,<2.4",
Expand Down
74 changes: 53 additions & 21 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import unittest

import numpy as np
from geoengine_openapi_client.models import MlModelMetadata, MlTensorShape3D, RasterDataType
from geoengine_openapi_client.models import (
MlModelInputNoDataHandling,
MlModelInputNoDataHandlingVariant,
MlModelMetadata,
MlModelOutputNoDataHandling,
MlModelOutputNoDataHandlingVariant,
MlTensorShape3D,
RasterDataType,
)
from onnx import TensorShapeProto as TSP
from skl2onnx import to_onnx
from sklearn.ensemble import RandomForestClassifier
Expand Down Expand Up @@ -84,12 +92,18 @@ def test_uploading_onnx_model(self):
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name=model_name,
file_name="model.onnx",
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
output_type=RasterDataType.I64,
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
inputType=RasterDataType.F32,
outputType=RasterDataType.I64,
inputShape=MlTensorShape3D(y=1, x=1, bands=2),
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
inputNoDataHandling=MlModelInputNoDataHandling(
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
),
outputNoDataHandling=MlModelOutputNoDataHandling(
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
),
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down Expand Up @@ -120,12 +134,18 @@ def test_uploading_onnx_model(self):
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name=model_name,
file_name="model.onnx",
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
output_type=RasterDataType.I64,
input_shape=MlTensorShape3D(y=1, x=1, bands=4),
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
inputType=RasterDataType.F32,
outputType=RasterDataType.I64,
inputShape=MlTensorShape3D(y=1, x=1, bands=4),
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
inputNoDataHandling=MlModelInputNoDataHandling(
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
),
outputNoDataHandling=MlModelOutputNoDataHandling(
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
),
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand All @@ -140,12 +160,18 @@ def test_uploading_onnx_model(self):
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name=model_name,
file_name="model.onnx",
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F64,
output_type=RasterDataType.I64,
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
inputType=RasterDataType.F64,
outputType=RasterDataType.I64,
inputShape=MlTensorShape3D(y=1, x=1, bands=2),
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
inputNoDataHandling=MlModelInputNoDataHandling(
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
),
outputNoDataHandling=MlModelOutputNoDataHandling(
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
),
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand All @@ -161,12 +187,18 @@ def test_uploading_onnx_model(self):
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name="foo",
file_name="model.onnx",
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
output_type=RasterDataType.I32,
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
inputType=RasterDataType.F32,
outputType=RasterDataType.I32,
inputShape=MlTensorShape3D(y=1, x=1, bands=2),
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
inputNoDataHandling=MlModelInputNoDataHandling(
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
),
outputNoDataHandling=MlModelOutputNoDataHandling(
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
),
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down