11'''Tests ML functionality'''
22
3+ from typing import List
34import unittest
5+ from onnx import TensorShapeProto as TSP
46from sklearn .ensemble import RandomForestClassifier
57from skl2onnx import to_onnx
68import numpy as np
7- from geoengine_openapi_client .models import MlModelMetadata , RasterDataType
9+ from geoengine_openapi_client .models import MlModelMetadata , RasterDataType , MlTensorShape3D
810import geoengine as ge
11+ from geoengine .ml import model_dim_to_tensorshape
912from tests .ge_test import GeoEngineTestInstance
1013
1114
12- class WorkflowStorageTests (unittest .TestCase ):
13- '''Test methods for storing workflows as datasets '''
15+ class MlModelTests (unittest .TestCase ):
16+ '''Test methods for MlModels '''
1417
1518 def setUp (self ) -> None :
1619 ge .reset (False )
1720
21+ def test_model_dim_to_tensorshape (self ):
22+ ''' Test model_dim_to_tensorshape '''
23+
24+ dim_1d : List [TSP .Dimension ] = [TSP .Dimension (dim_value = 7 )]
25+ mts_1d = MlTensorShape3D (bands = 7 , y = 1 , x = 1 )
26+ self .assertEqual (model_dim_to_tensorshape (dim_1d ), mts_1d )
27+
28+ dim_1d_v : List [TSP .Dimension ] = [TSP .Dimension (dim_value = None ), TSP .Dimension (dim_value = 7 )]
29+ mts_1d_v = MlTensorShape3D (bands = 7 , y = 1 , x = 1 )
30+ self .assertEqual (model_dim_to_tensorshape (dim_1d_v ), mts_1d_v )
31+
32+ dim_2d_t : List [TSP .Dimension ] = [TSP .Dimension (dim_value = 512 ), TSP .Dimension (dim_value = 512 )]
33+ mts_2d_t = MlTensorShape3D (bands = 1 , y = 512 , x = 512 )
34+ self .assertEqual (model_dim_to_tensorshape (dim_2d_t ), mts_2d_t )
35+
36+ dim_2d_1 : List [TSP .Dimension ] = [TSP .Dimension (dim_value = 1 ), TSP .Dimension (dim_value = 7 )]
37+ mts_2d_1 = MlTensorShape3D (bands = 7 , y = 1 , x = 1 )
38+ self .assertEqual (model_dim_to_tensorshape (dim_2d_1 ), mts_2d_1 )
39+
40+ dim_3d_t : List [TSP .Dimension ] = [
41+ TSP .Dimension (dim_value = 512 ), TSP .Dimension (dim_value = 512 ), TSP .Dimension (dim_value = 7 )
42+ ]
43+ mts_3d_t = MlTensorShape3D (bands = 7 , y = 512 , x = 512 )
44+ self .assertEqual (model_dim_to_tensorshape (dim_3d_t ), mts_3d_t )
45+
46+ dim_3d_v : List [TSP .Dimension ] = [
47+ TSP .Dimension (dim_value = None ), TSP .Dimension (dim_value = 512 ), TSP .Dimension (dim_value = 512 )
48+ ]
49+ mts_3d_v = MlTensorShape3D (bands = 1 , y = 512 , x = 512 )
50+ self .assertEqual (model_dim_to_tensorshape (dim_3d_v ), mts_3d_v )
51+
52+ dim_4d_v : List [TSP .Dimension ] = [
53+ TSP .Dimension (dim_value = None ),
54+ TSP .Dimension (dim_value = 512 ),
55+ TSP .Dimension (dim_value = 512 ),
56+ TSP .Dimension (dim_value = 4 )
57+ ]
58+ mts_4d_v = MlTensorShape3D (bands = 4 , y = 512 , x = 512 )
59+ self .assertEqual (model_dim_to_tensorshape (dim_4d_v ), mts_4d_v )
60+
1861 def test_uploading_onnx_model (self ):
1962
2063 clf = RandomForestClassifier (random_state = 42 )
@@ -40,8 +83,9 @@ def test_uploading_onnx_model(self):
4083 metadata = MlModelMetadata (
4184 file_name = "model.onnx" ,
4285 input_type = RasterDataType .F32 ,
43- num_input_bands = 2 ,
4486 output_type = RasterDataType .I64 ,
87+ input_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 2 ),
88+ output_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 1 )
4589 ),
4690 display_name = "Decision Tree" ,
4791 description = "A simple decision tree model" ,
@@ -77,16 +121,17 @@ def test_uploading_onnx_model(self):
77121 metadata = MlModelMetadata (
78122 file_name = "model.onnx" ,
79123 input_type = RasterDataType .F32 ,
80- num_input_bands = 4 ,
81124 output_type = RasterDataType .I64 ,
125+ input_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 4 ),
126+ output_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 1 )
82127 ),
83128 display_name = "Decision Tree" ,
84129 description = "A simple decision tree model" ,
85130 )
86131 )
87132 self .assertEqual (
88133 str (exception .exception ),
89- 'Model input has 2 bands, but 4 bands are expected '
134+ 'Input shape bands=2 x=1 y=1 and metadata bands=4 x=1 y=1 not equal! '
90135 )
91136
92137 with self .assertRaises (ge .InputException ) as exception :
@@ -97,8 +142,9 @@ def test_uploading_onnx_model(self):
97142 metadata = MlModelMetadata (
98143 file_name = "model.onnx" ,
99144 input_type = RasterDataType .F64 ,
100- num_input_bands = 2 ,
101145 output_type = RasterDataType .I64 ,
146+ input_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 2 ),
147+ output_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 1 )
102148 ),
103149 display_name = "Decision Tree" ,
104150 description = "A simple decision tree model" ,
@@ -117,8 +163,9 @@ def test_uploading_onnx_model(self):
117163 metadata = MlModelMetadata (
118164 file_name = "model.onnx" ,
119165 input_type = RasterDataType .F32 ,
120- num_input_bands = 2 ,
121166 output_type = RasterDataType .I32 ,
167+ input_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 2 ),
168+ output_shape = MlTensorShape3D (y = 1 , x = 1 , bands = 1 )
122169 ),
123170 display_name = "Decision Tree" ,
124171 description = "A simple decision tree model" ,
0 commit comments