From 58e7bcb28afd749647d06614455ddd2555cf97af Mon Sep 17 00:00:00 2001 From: "Patrick J. McNerthney" Date: Fri, 1 Aug 2025 18:43:24 -1000 Subject: [PATCH] Implement unit test framework with initial unit tests. --- .github/workflows/ci.yaml | 28 +-- README.md | 2 + .../inline/composition.yaml | 2 +- examples/helm-copy-secret/composition.yaml | 51 ++-- examples/helm-copy-secret/definition.yaml | 17 -- examples/helm-copy-secret/xr.yaml | 6 +- .../{inline => single-purpose}/functions.yaml | 0 examples/{inline => single-purpose}/render.sh | 0 examples/{inline => single-purpose}/xr.yaml | 0 function/composite.py | 10 +- function/fn.py | 12 +- function/protobuf.py | 123 ++++++---- pyproject.toml | 13 +- tests/fn_cases/bucket.yaml | 19 ++ tests/fn_cases/inline.yaml | 82 +++++++ tests/test_fn.py | 130 ++++------- tests/utils.py | 219 ++++++++++++++++++ 17 files changed, 517 insertions(+), 197 deletions(-) delete mode 100644 examples/helm-copy-secret/definition.yaml rename examples/{inline => single-purpose}/functions.yaml (100%) rename examples/{inline => single-purpose}/render.sh (100%) rename examples/{inline => single-purpose}/xr.yaml (100%) create mode 100644 tests/fn_cases/bucket.yaml create mode 100644 tests/fn_cases/inline.yaml create mode 100644 tests/utils.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 640e85e..91702c5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -48,27 +48,27 @@ jobs: # python-version: ${{ env.PYTHON_VERSION }} # - name: Setup Hatch - # run: pipx install hatch==1.7.0 + # run: pipx install hatch==1.14.1 # - name: Lint # run: hatch run lint:check - # unit-test: - # runs-on: ubuntu-24.04 - # steps: - # - name: Checkout - # uses: actions/checkout@v4 + unit-test: + runs-on: ubuntu-24.04 + steps: + - name: Checkout + uses: actions/checkout@v4 - # - name: Setup Python - # uses: actions/setup-python@v5 - # with: - # python-version: ${{ env.PYTHON_VERSION }} + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} - # - name: Setup Hatch - # run: pipx install hatch==1.7.0 + - name: Setup Hatch + run: pipx install hatch==1.14.1 - # - name: Run Unit Tests - # run: hatch run test:unit + - name: Run Unit Tests + run: hatch run test:unit # We want to build most packages for the amd64 and arm64 architectures. To # speed this up we build single-platform packages in parallel. We then upload diff --git a/README.md b/README.md index fc8e592..cae126b 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,8 @@ The following functions are provided to create Protobuf structures: | List | Create a new Protobuf list | | Yaml | Create a new Protobuf structure from a yaml string | | Json | Create a new Protobuf structure from a json string | +| Base64Encode | Encode a string into base 64 | +| Base64Decode | Decode a string from base 64 | The following items are supported in all the Protobuf Message wrapper classes: `bool`, `len`, `contains`, `iter`, `hash`, `==`, `str`, `format` diff --git a/examples/function-go-templating/inline/composition.yaml b/examples/function-go-templating/inline/composition.yaml index 739ca9e..2af167c 100644 --- a/examples/function-go-templating/inline/composition.yaml +++ b/examples/function-go-templating/inline/composition.yaml @@ -23,7 +23,7 @@ spec: user = f"test-user-{ix}" r = self.resources[user]('iam.aws.upbound.io/v1beta1', 'User') r.metadata.labels['testing.upbound.io/example-name'] = user - r.metadata.labels.dummy = r.observed.resource.metadata.labels.dummy or random.choice(['foo', 'bar', 'baz']) + r.metadata.labels.dummy = r.observed.metadata.labels.dummy or random.choice(['foo', 'bar', 'baz']) r = self.resources[f"sample-access-key-{ix}"]('iam.aws.upbound.io/v1beta1', 'AccessKey') r.spec.forProvider.userSelector.matchLabels['testing.upbound.io/example-name'] = user r.spec.writeConnectionSecretToRef.namespace = 'crossplane.system' diff --git a/examples/helm-copy-secret/composition.yaml b/examples/helm-copy-secret/composition.yaml index 07042d0..4a43145 100644 --- a/examples/helm-copy-secret/composition.yaml +++ b/examples/helm-copy-secret/composition.yaml @@ -1,11 +1,11 @@ apiVersion: apiextensions.crossplane.io/v1 kind: Composition metadata: - name: argocds.pythonic.fortra.com + name: xclusters.example.joebowbeer.com spec: compositeTypeRef: - apiVersion: pythonic.fortra.com/v1alpha1 - kind: ArgoCD + apiVersion: example.joebowbeer.com/v1alpha1 + kind: XCluster mode: Pipeline pipeline: - step: pythonic @@ -15,16 +15,41 @@ spec: apiVersion: pythonic.fn.fortra.com/v1alpha1 kind: Composite composite: | + def argcd_secret_config(secret): + config = Map() + config.tlsClientConfig.insecure = True + config.tlsClientConfig.caData = B64Decode(secret.data['certificate-authority']) + config.tlsClientConfig.certData = B64Decode(secret.data['client-certificate']) + config.tlsClientConfig.keyData = B64Decode(secret.data['client-key']) + return config + class Composite(BaseComposite): def compose(self): - argocd = self.resources.Release('helm.crossplane.io/v1beta1', 'Release') - argocd.externalName('argocd') - argocd.spec.forProvider.namespace = 'argocd' - argocd.spec.forProvider.chart.repository = 'https://argoproj.github.io/argo-helm' - argocd.spec.forProvider.chart.name = 'argo-cd' - argocd.spec.forProvider.chart.version = '8.0.7' + name = self.metadata.name + namespace = name + + release = self.resources.release('helm.crossplane.io/v1beta1', 'Release', name=name) + release.spec.rollbackLimit = 1 + release.spec.forProvider.chart.repository = 'https://charts.loft.sh' + release.spec.forProvider.chart.name = 'vcluster' + release.spec.forProvider.chart.version = '0.26.0' + release.spec.forProvider.namespace = namespace + release.spec.forProvider.values.controlPlane.proxy.extraSANs[0] = f'{name}.{namespace}' - # This will work once crossplane-sdk-python is updated to the V2 function api - #secret = self.requireds.Secret('v1', 'Secret', 'argocd', 'argocd-secret')[0] - secret = self.requireds.Secret('v1', 'Secret', labels={'app.kubernetes.io/name':'argocd-secret'})[0] - self.resources.Secret('v1', 'Secret', 'default', 'argocd-secret').data = secret.data + secret_name = f'vc-{name}' + # This will work once crossplane-sdk-python is updated to the v2 function api + #vcluster_secrets = self.requireds.Secret('v1', 'Secret', namespace, secret_name) + vcluster_secrets = self.requireds.Secret('v1', 'Secret', labels={'vcluster-name':name}) + for secret in vcluster_secrets: + if secret.metadata.name != secret_name: + continue + argocd_secret = self.resources.secret('v1', 'Secret', 'argocd', secret_name) + argocd_secret.metadata.labels['argocd.argoproj.io/secret-type'] = 'cluster' + argocd_secret.type = 'Opaque' + argocd_secret.data.name = B64Encode(name) + argocd_secret.data.server = B64Encode(f'https://{name}.{namespace}:443') + argocd_secret.data.config = B64Encode(format(argcd_secret_config(secret), 'json')) + argocd_secret.ready = argocd_secret.observed.data + break + else: + self.ready = False diff --git a/examples/helm-copy-secret/definition.yaml b/examples/helm-copy-secret/definition.yaml deleted file mode 100644 index ef6fd63..0000000 --- a/examples/helm-copy-secret/definition.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: apiextensions.crossplane.io/v1 -kind: CompositeResourceDefinition -metadata: - name: argocds.pythonic.fortra.com -spec: - group: pythonic.fortra.com - names: - kind: ArgoCD - plural: argocds - defaultCompositionRef: - name: argocds.pythonic.fortra.com - versions: - - name: v1alpha1 - served: true - referenceable: true - schema: - openAPIV3Schema: {} diff --git a/examples/helm-copy-secret/xr.yaml b/examples/helm-copy-secret/xr.yaml index 0523f27..2c97756 100644 --- a/examples/helm-copy-secret/xr.yaml +++ b/examples/helm-copy-secret/xr.yaml @@ -1,5 +1,5 @@ -apiVersion: pythonic.fortra.com/v1alpha1 -kind: ArgoCD +apiVersion: example.joebowbeer.com/v1alpha1 +kind: XCluster metadata: - name: argocd + name: xc1 spec: {} diff --git a/examples/inline/functions.yaml b/examples/single-purpose/functions.yaml similarity index 100% rename from examples/inline/functions.yaml rename to examples/single-purpose/functions.yaml diff --git a/examples/inline/render.sh b/examples/single-purpose/render.sh similarity index 100% rename from examples/inline/render.sh rename to examples/single-purpose/render.sh diff --git a/examples/inline/xr.yaml b/examples/single-purpose/xr.yaml similarity index 100% rename from examples/inline/xr.yaml rename to examples/single-purpose/xr.yaml diff --git a/function/composite.py b/function/composite.py index cc776fb..3813dce 100644 --- a/function/composite.py +++ b/function/composite.py @@ -54,7 +54,7 @@ def ready(self): def ready(self, ready): if ready: ready = fnv1.Ready.READY_TRUE - elif ready == None or (isinstance(ready, function.protobuf.Values) and ready._type == function.protobuf.Values.Type.UNKNOWN): + elif ready == None or (isinstance(ready, function.protobuf.Values) and ready._isUnknown): ready = fnv1.Ready.READY_UNSPECIFIED else: ready = fnv1.Ready.READY_FALSE @@ -209,7 +209,7 @@ def ready(self): def ready(self, ready): if ready: ready = fnv1.Ready.READY_TRUE - elif ready == None or (isinstance(ready, function.protobuf.Values) and ready._type == function.protobuf.Values.Type.UNKNOWN): + elif ready == None or (isinstance(ready, function.protobuf.Values) and ready._isUnknown): ready = fnv1.Ready.READY_UNSPECIFIED else: ready = fnv1.Ready.READY_FALSE @@ -415,7 +415,7 @@ def claim(self, claim): if bool(self): if claim: self._result.target = fnv1.Target.TARGET_COMPOSITE_AND_CLAIM - elif claim == None or (isinstance(claim, function.protobuf.Values) and claim._type == function.protobuf.Values.Type.UNKNOWN): + elif claim == None or (isinstance(claim, function.protobuf.Values) and claim._isUnknown): self._result.target = fnv1.Target.TARGET_UNSPECIFIED else: self._result.target = fnv1.Target.TARGET_COMPOSITE @@ -502,7 +502,7 @@ def status(self, status): condition.status = fnv1.Status.STATUS_CONDITION_TRUE elif status == None: condition.status = fnv1.Status.STATUS_CONDITION_UNKNOWN - elif isinstance(ready, function.protobuf.Values) and ready._type == function.protobuf.Values.Type.UNKNOWN: + elif isinstance(status, function.protobuf.Values) and status._isUnknown: condition.status = fnv1.Status.STATUS_CONDITION_UNSPECIFIED else: condition.status = fnv1.Status.STATUS_CONDITION_FALSE @@ -549,7 +549,7 @@ def claim(self, claim): condition = self._find_condition(True) if claim: condition.target = fnv1.Target.TARGET_COMPOSITE_AND_CLAIM - elif claim == None or (isinstance(claim, function.protobuf.Values) and claim._type == function.protobuf.Values.Type.UNKNOWN): + elif claim == None or (isinstance(claim, function.protobuf.Values) and claim._isUnknown): condition.target = fnv1.Target.TARGET_UNSPECIFIED else: condition.target = fnv1.Target.TARGET_COMPOSITE diff --git a/function/fn.py b/function/fn.py index f17df69..802afaf 100644 --- a/function/fn.py +++ b/function/fn.py @@ -1,6 +1,7 @@ """A Crossplane composition function.""" import asyncio +import base64 import inspect import grpc @@ -21,15 +22,8 @@ def __init__(self): self.modules = {} async def RunFunction( - self, request: fnv1.RunFunctionRequest, context: grpc.aio.ServicerContext + self, request: fnv1.RunFunctionRequest, _: grpc.aio.ServicerContext ) -> fnv1.RunFunctionResponse: - try: - return await self.run(request, context) - except Exception as e: - self.logger.exception('Error during RunFuction') - raise - - async def run(self, request, context): composite = request.observed.composite.resource logger = self.logger.bind( apiVersion=composite['apiVersion'], @@ -114,3 +108,5 @@ def __init__(self): self.List = function.protobuf.List self.Yaml = function.protobuf.Yaml self.Json = function.protobuf.Json + self.B64Encode = lambda s: base64.b64encode(s.encode('utf-8')).decode('utf-8') + self.B64Decode = lambda s: base64.b64decode(s.encode('utf-8')).decode('utf-8') diff --git a/function/protobuf.py b/function/protobuf.py index 4cb0a6f..ccd64a9 100644 --- a/function/protobuf.py +++ b/function/protobuf.py @@ -311,6 +311,8 @@ def __setitem__(self, key, message): self._messages = self._parent._create_child(self._key) if isinstance(message, Message): message = message._message + if isinstance(message, str) and self._field.type == self._field.TYPE_BYTES: + message = message.encode() self._messages[key] = message self._cache.pop(key, None) @@ -484,8 +486,8 @@ def __getitem__(self, key): self._cache[key] = value return value if isinstance(key, str): - if self._type != self.Type.MAP: - if self._type != self.Type.UNKNOWN: + if not self._isMap: + if not self._isUnknown: raise ValueError('Invalid key, must be a str for maps') self.__dict__['_type'] = self.Type.MAP if self._values is None or key not in self._values: @@ -493,8 +495,8 @@ def __getitem__(self, key): else: struct_value = self._values.fields[key] elif isinstance(key, int): - if self._type != self.Type.LIST: - if self._type != self.Type.UNKNOWN: + if not self._isList: + if not self._isUnknown: raise ValueError('Invalid key, must be an int for lists') self.__dict__['_type'] = self.Type.LIST if self._values is None or key >= len(self._values): @@ -536,24 +538,24 @@ def __len__(self): def __contains__(self, item): if self._values is not None: - if self._type == self.Type.MAP: + if self._isMap: return item in self._values or item in self._unknowns - if self._type == self.Type.LIST: + if self._isList: for value in self: if item == value: return True - if isinstance(item, Values) and item._type == Type.UNKNOWN: + if isinstance(item, Values) and item._isUnknown: return bool(self._unknowns) return False def __iter__(self): if self._values is not None: - if self._type == self.Type.MAP: + if self._isMap: for key in self._values: yield key, self[key] for key in self._unknowns: yield key, self[key] - elif self._type == self.Type.LIST: + elif self._isList: for ix in range(len(self._values)): yield self[ix] for ix in sorted(self._unknowns): @@ -562,9 +564,9 @@ def __iter__(self): def __hash__(self): if self._values is not None: - if self._type == self.Type.MAP: + if self._isMap: return hash(tuple(hash(item) for item in sorted(iter(self), key=lambda item: item[0]))) - if self._type == self.Type.LIST: + if self._isList: return hash(tuple(hash(item) for item in self)) return self._type @@ -579,13 +581,13 @@ def __eq__(self, other): return False if len(self) != len(other): return False - if self._type == self.Type.MAP: + if self._isMap: for key, value in self: if key not in other: return False if value != other[key]: return False - if self._type == self.Type.LIST: + if self._isList: for ix, value in enumerate(self): if value != other[ix]: return False @@ -604,23 +606,23 @@ def _fullName(self): parent = self._parent._fullName if parent: return f"{parent}.{self._key}" - return self._key + return str(self._key) return '' def _create_child(self, key, type): if self._readOnly: raise ValueError(f"{self._readOnly} is read only") if isinstance(key, str): - if self._type != self.Type.MAP: - if self._type != self.Type.UNKNOWN: + if not self._isMap: + if not self._isUnknown: raise ValueError('Invalid key, must be a str for maps') self.__dict__['_type'] = self.Type.MAP if self._values is None: self.__dict__['_values'] = self._parent._create_child(self._key, self._type) struct_value = self._values.fields[key] elif isinstance(key, int): - if self._type != self.Type.LIST: - if self._type != self.Type.UNKNOWN: + if not self._isList: + if not self._isUnknown: raise ValueError('Invalid key, must be an int for lists') self.__dict__['_type'] = self.Type.LIST if self._values is None: @@ -646,8 +648,8 @@ def __call__(self, *args, **kwargs): self._cache.clear() self._unknowns.clear() if len(kwargs): - if self._type != self.Type.MAP: - if self._type != self.Type.UNKNOWN: + if not self._isMap: + if not self._isUnknown: raise ValueError('Cannot specify kwargs on lists') self.__dict__['_type'] = self.Type.MAP if len(args): @@ -658,8 +660,8 @@ def __call__(self, *args, **kwargs): for key, value in kwargs.items(): self[key] = value elif len(args): - if self._type != self.Type.LIST: - if self._type != self.Type.UNKNOWN: + if not self._isList: + if not self._isUnknown: raise ValueError('Cannot specify args on maps') self.__dict__['_type'] = self.Type.LIST if len(kwargs): @@ -670,8 +672,8 @@ def __call__(self, *args, **kwargs): for key in range(len(args)): self[key] = args[key] else: - if self._type != self.Type.MAP: - if self._type != self.Type.UNKNOWN: + if not self._isMap: + if not self._isUnknown: self.__dict__['_type'] = self.Type.MAP # Assume a map is wanted if self._values is None: self.__dict__['_values'] = self._parent._create_child(self._key, self._type) @@ -685,16 +687,16 @@ def __setitem__(self, key, value): if self._readOnly: raise ValueError(f"{self._readOnly} is read only") if isinstance(key, str): - if self._type != self.Type.MAP: - if self._type != self.Type.UNKNOWN: + if not self._isMap: + if not self._isUnknown: raise ValueError('Invalid key, must be a str for maps') self.__dict__['_type'] = self.Type.MAP if self._values is None: self.__dict__['_values'] = self._parent._create_child(self._key, self._type) values = self._values.fields elif isinstance(key, int): - if self._type != self.Type.LIST: - if self._type != self.Type.UNKNOWN: + if not self._isList: + if not self._isUnknown: raise ValueError('Invalid key, must be an int for lists') self.__dict__['_type'] = self.Type.LIST if self._values is None: @@ -727,24 +729,28 @@ def __setitem__(self, key, value): for k, v in enumerate(value): sv[k] = v elif isinstance(value, Values): - if value._type == value.Type.MAP: + if value._isMap: values[key].struct_value.Clear() sv = self[key] for k, v in value: sv[k] = v - elif value._type == value.Type.LIST: + elif value._isList: values[key].list_value.Clear() sv = self[key] for k, v in enumerate(value): sv[k] = v else: self._unknowns.add(key) - if self._type == self.Type.MAP: + if self._isMap: if key in values: del values[key] - else: + elif self._isList: if key < len(values): - del values[key] + values[key].Clear() + for ix in reversed(range(len(values))): + if ix not in self._unknowns: + break + del values[ix] else: raise ValueError('Unexpected type') @@ -755,8 +761,8 @@ def __delitem__(self, key): if self._readOnly: raise ValueError(f"{self._readOnly} is read only") if isinstance(key, str): - if self._type != self.Type.MAP: - if self._type != self.Type.UNKNOWN: + if not self._isMap: + if not self._isUnknown: raise ValueError('Invalid key, must be a str for maps') self.__dict__['_type'] = self.Type.MAP if self._values is not None: @@ -765,27 +771,48 @@ def __delitem__(self, key): self._cache.pop(key, None) self._unknowns.discard(key) elif isinstance(key, int): - if self._type != self.Type.LIST: - if self._type != self.Type.UNKNOWN: + if not self._isList: + if not self._isUnknown: raise ValueError('Invalid key, must be an int for lists') self.__dict__['_type'] = self.Type.LIST if self._values is not None: if key < len(self._values): - self._values.values[key].Clear() + del self._values[key] self._cache.pop(key, None) self._unknowns.discard(key) + for ix in sorted([ix in self._unknowns]): + if ix > key: + self._cache.pop(ix, None) + self._unknowns.add(ix - 1) + self._unknowns.disacard(ix) + for ix in reversed(range(len(self._values))): + if ix not in self._unknowns: + break + del self._values[ix] else: raise ValueError('Unexpected key type') + @property + def _isUnknown(self): + return self._type == self.Type.UNKNOWN + + @property + def _isMap(self): + return self._type == self.Type.MAP + + @property + def _isList(self): + return self._type == self.Type.LIST + @property def _hasUnknowns(self): if self._unknowns: return True - if self._type == self.Type.MAP: + if self._isMap: for key, value in self: if isinstance(value, Values) and value._hasUnknowns: return True - elif self._type == self.Type.LIST: + elif self._isList: for value in self: if isinstance(value, Values) and value._hasUnknowns: return True @@ -794,13 +821,13 @@ def _hasUnknowns(self): def _patchUnknowns(self, patches): for key in [key for key in self._unknowns]: self[key] = patches[key] - if self._type == self.Type.MAP: + if self._isMap: for key, value in self: if isinstance(value, Values) and len(value): patch = patches[key] if isinstance(patch, Values) and patch._type == value._type and len(patch): value._patchUnknowns(patch) - elif self._type == self.Type.LIST: + elif self._isList: for ix, value in enumerate(self): if isinstance(value, Values) and len(value): patch = patches[ix] @@ -835,11 +862,11 @@ def default(self, object): return [value for value in object] return None if isinstance(object, Values): - if object._type == Values.Type.MAP: + if object._isMap: return {key: value for key, value in object} - if object._type == Values.Type.LIST: + if object._isList: return [value for value in object] - if object._type == Values.Type.UNKNOWN: + if object._isUnknown: return '<>' return '<>' if isinstance(object, datetime.datetime): @@ -859,11 +886,11 @@ def represent_message_list(self, messages): return self.represent_list([value for value in messages]) def represent_values(self, values): - if values._type == Values.Type.MAP: + if values._isMap: return self.represent_dict({key: value for key, value in values}) - if values._type == Values.Type.LIST: + if values._isList: return self.represent_list([value for value in values]) - if values._type == Values.Type.UNKNOWN: + if values._isUnknown: return self.represent_scalar('tag:yaml.org,2002:str', '<>') return self.represent_scalar('tag:yaml.org,2002:str', '<>') diff --git a/pyproject.toml b/pyproject.toml index e005c1e..e028a3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,25 +43,20 @@ validate-bump = false # Allow going from 0.0.0.dev0+x to 0.0.0.dev0+y type = "virtual" path = ".venv-default" dependencies = ["ipython==9.1.0"] - -[tool.hatch.envs.default.scripts] -development = "python function/main.py --insecure --debug" +scripts = { development = "python function/main.py --insecure --debug" } [tool.hatch.envs.lint] type = "virtual" detached = true path = ".venv-lint" dependencies = ["ruff==0.11.5"] - -[tool.hatch.envs.lint.scripts] -check = "ruff check function tests && ruff format --diff function tests" +scripts = { check = "ruff check function tests && ruff format --diff function tests" } [tool.hatch.envs.test] type = "virtual" path = ".venv-test" - -[tool.hatch.envs.test.scripts] -unit = "python -m unittest tests/*.py" +dependencies = ["pytest==8.4.1", "pytest-asyncio==1.1.0"] +scripts = { unit = "python -m pytest tests" } [tool.ruff] target-version = "py311" diff --git a/tests/fn_cases/bucket.yaml b/tests/fn_cases/bucket.yaml new file mode 100644 index 0000000..93841d3 --- /dev/null +++ b/tests/fn_cases/bucket.yaml @@ -0,0 +1,19 @@ +request: + input: + composite: | + class Composite(BaseComposite): + def compose(self): + self.resources.bucket.apiVersion = 's3.aws.upbound.io/v1beta2' + self.resources.bucket.kind = 'Bucket' + self.resources.bucket.spec.forProvider.region = 'us-east-1' + +response: + desired: + resources: + bucket: + resource: + apiVersion: s3.aws.upbound.io/v1beta2 + kind: Bucket + spec: + forProvider: + region: us-east-1 diff --git a/tests/fn_cases/inline.yaml b/tests/fn_cases/inline.yaml new file mode 100644 index 0000000..59bd195 --- /dev/null +++ b/tests/fn_cases/inline.yaml @@ -0,0 +1,82 @@ +request: + observed: + composite: + resource: + spec: + count: 2 + resources: + sample-access-key-0: + connection_details: + username: test-user-0 + password: hxt1nwu9vmt@qtb3HVW + input: + composite: | + class Composite(BaseComposite): + def compose(self): + for ix in range(self.spec.count): + user = f"test-user-{ix}" + r = self.resources[user]('iam.aws.upbound.io/v1beta1', 'User') + r.metadata.labels['testing.upbound.io/example-name'] = user + r.metadata.labels.dummy = r.observed.metadata.labels.dummy or ['foo', 'bar', 'baz'][ix] + r = self.resources[f"sample-access-key-{ix}"]('iam.aws.upbound.io/v1beta1', 'AccessKey') + r.spec.forProvider.userSelector.matchLabels['testing.upbound.io/example-name'] = user + r.spec.writeConnectionSecretToRef.namespace = 'crossplane.system' + r.spec.writeConnectionSecretToRef.name = f"sample-access-key-secret-{ix}" + connection = self.resources['sample-access-key-0'].connection + if connection: + self.connection.url = 'http://www.example.com' + self.connection.username = connection.username + self.connection.password = connection.password + self.status.dummy = 'cool-status' + +response: + desired: + composite: + resource: + status: + dummy: cool-status + connection_details: + url: http://www.example.com + username: test-user-0 + password: hxt1nwu9vmt@qtb3HVW + resources: + test-user-0: + resource: + apiVersion: iam.aws.upbound.io/v1beta1 + kind: User + metadata: + labels: + testing.upbound.io/example-name: test-user-0 + dummy: foo + sample-access-key-0: + resource: + apiVersion: iam.aws.upbound.io/v1beta1 + kind: AccessKey + spec: + forProvider: + userSelector: + matchLabels: + testing.upbound.io/example-name: test-user-0 + writeConnectionSecretToRef: + namespace: crossplane.system + name: sample-access-key-secret-0 + test-user-1: + resource: + apiVersion: iam.aws.upbound.io/v1beta1 + kind: User + metadata: + labels: + testing.upbound.io/example-name: test-user-1 + dummy: bar + sample-access-key-1: + resource: + apiVersion: iam.aws.upbound.io/v1beta1 + kind: AccessKey + spec: + forProvider: + userSelector: + matchLabels: + testing.upbound.io/example-name: test-user-1 + writeConnectionSecretToRef: + namespace: crossplane.system + name: sample-access-key-secret-1 diff --git a/tests/test_fn.py b/tests/test_fn.py index d0909fa..641b5c9 100644 --- a/tests/test_fn.py +++ b/tests/test_fn.py @@ -1,85 +1,57 @@ -import dataclasses -import unittest -from crossplane.function import logging, resource +import warnings +warnings.filterwarnings('ignore', module='^google[.]protobuf[.]runtime_version$', lineno=98) + +import pathlib +import pytest +import yaml from crossplane.function.proto.v1 import run_function_pb2 as fnv1 -from google.protobuf import duration_pb2 as durationpb from google.protobuf import json_format -from google.protobuf import struct_pb2 as structpb from function import fn - -composite = """ -class Composite(BaseComposite): - def compose(self): - self.resources.bucket.apiVersion = 's3.aws.upbound.io/v1beta2' - self.resources.bucket.kind = 'Bucket' - self.resources.bucket.spec.forProvider.region = 'us-east-1' -""" - - -class TestFunctionRunner(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: - # Allow larger diffs, since we diff large strings of JSON. - self.maxDiff = 2000 - logging.configure(level=logging.Level.DISABLED) - - async def test_run_function(self) -> None: - @dataclasses.dataclass - class TestCase: - reason: str - req: fnv1.RunFunctionRequest - want: fnv1.RunFunctionResponse - - cases = [ - TestCase( - reason="The function should return the input as a result.", - req=fnv1.RunFunctionRequest( - observed=fnv1.State( - composite=fnv1.Resource( - resource={ - 'apiVersion': 'unittest.crossplane.io/v1beta1', - 'kind': 'XR', - 'metadata': { - 'name': 'test', - }, - }, - ), - ), - input={"composite": composite} - ), - want=fnv1.RunFunctionResponse( - meta=fnv1.ResponseMeta(ttl=durationpb.Duration(seconds=60)), - desired=fnv1.State( - resources={ - "bucket": fnv1.Resource( - resource={ - "apiVersion": "s3.aws.upbound.io/v1beta2", - "kind": "Bucket", - "spec": { - "forProvider": { - "region": "us-east-1", - }, - }, - }, - ), - }, - ), - context={}, - ), +from tests import utils + + +@pytest.mark.parametrize( + 'fn_case', + [ + path + for path in (pathlib.Path(__file__).parent / 'fn_cases').iterdir() + if path.is_file() and path.suffix == '.yaml' + ], +) +@pytest.mark.asyncio +async def test_run_function(fn_case): + test = yaml.safe_load(fn_case.read_text()) + + request = fnv1.RunFunctionRequest( + observed=fnv1.State( + composite=fnv1.Resource( + resource={ + 'apiVersion': 'pythonic.fortra.com/v1alpha1', + 'kind': 'PyTest', + 'metadata': { + 'name': fn_case.stem, + }, + }, ), - ] - - runner = fn.FunctionRunner() - - for case in cases: - got = await runner.RunFunction(case.req, None) - self.assertEqual( - json_format.MessageToDict(case.want), - json_format.MessageToDict(got), - "-want, +got", - ) - - -if __name__ == "__main__": - unittest.main() + ), + ) + utils.message_merge(request, test['request']) + utils.map_defaults(test['response'], { + 'meta': { + 'ttl': { + 'seconds': 60, + }, + }, + 'context': {} + }) + + response = utils.message_dict( + await fn.FunctionRunner().RunFunction(request, None) + ) + + #print(yaml.dump(response)) + #assert False + + assert response == test['response'] diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..e7b1032 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,219 @@ + +from google.protobuf.struct_pb2 import Struct, ListValue + + +def message_merge(message, values): + for field, value in values.items(): + if isinstance(value, (dict, list, tuple)): + current = getattr(message, field) + if isinstance(value, dict): + if isinstance(current, Struct): + map_merge(current, value) + else: + descriptor = message.DESCRIPTOR.fields_by_name.get(field) + if descriptor.label == descriptor.LABEL_REPEATED: + if descriptor.message_type.GetOptions().map_entry: + message_map_merge(descriptor, current, value) + else: + message_list_merge(current, value) + else: + message_merge(current, value) + else: + list_merge(current, value) + continue + setattr(message, field, value) + +def message_map_merge(descriptor, message, values): + descriptor = descriptor.message_type.fields_by_name['value'] + for field, value in values.items(): + if isinstance(value, (dict, list, tuple)): + current = message[field] + if isinstance(current, Struct): + map_merge(current, value) + else: + message_merge(current, value) + else: + if isinstance(value, str) and descriptor.type == descriptor.TYPE_BYTES: + value = value.encode() + message[field] = value + +def message_list_merge(message, values): + for ix, value in enumerate(values): + if ix < len(message): + message_merge(message[ix], value) + else: + message_merge(message.add(), value) + +def map_merge(message, values): + for field, value in values.items(): + if isinstance(value, (dict, list, tuple)): + if field in message: + current = message[field] + if isinstance(value, dict): + map_merge(current, value) + else: + list_merge(current, value) + continue + message[field] = value + +def list_merge(message, values): + for ix, value in enumerate(values): + if ix < len(message): + if isinstance(value, (dict, list, tuple)): + current = message[ix] + if isinstance(value, dict): + map_merge(current, value) + else: + list_merge(current, value) + else: + message[ix] = value + else: + message.append(value) + + +def message_defaults(message, defaults): + for field, value in defaults.items(): + if isinstance(value, (dict, list, tuple)): + #if field in message: + current = getattr(message, field) + if isinstance(value, dict): + if isinstance(current, Struct): + map_defaults(current, value) + else: + descriptor = message.DESCRIPTOR.fields_by_name.get(field) + if descriptor.label == descriptor.LABEL_REPEATED: + if descriptor.message_type.GetOptions().map_entry: + message_map_defaults(descriptor, current, value) + else: + message_list_defaults(current, value) + else: + message_defaults(current, value) + else: + list_defaults(current, value) + else: + setattr(message, field, value) + +def message_map_defaults(descriptor, message, values): + descriptor = descriptor.message_type.fields_by_name['value'] + for field, value in values.items(): + if isinstance(value, (dict, list, tuple)): + current = message[field] + if isinstance(current, Struct): + map_defaults(current, value) + else: + message_defaults(current, value) + else: + if field not in message: + message[field] = value + +def message_list_defaults(message, values): + for ix, value in enumerate(values): + if ix < len(message): + message_defaults(message[ix], value) + else: + message_defaults(message.add(), value) + +def map_defaults(message, defaults): + for field, value in defaults.items(): + current = message.get(field, None) + if isinstance(value, (dict, list, tuple)): + if current is not None: + if isinstance(value, dict): + map_defaults(current, value) + else: + list_defaults(current, value) + continue + if current is None: + message[field] = value + +def list_defaults(message, defaults): + for ix, value in enumerate(defaults): + if ix < len(message): + if isinstance(value, (dict, list, tuple)): + current = message[ix] + if isinstance(value, dict): + map_defaults(current, value) + else: + list_defaults(current, value) + else: + message.append(value) + + +def message_dict(message): + result = {} + for field, value in message.ListFields(): + if field.type == field.TYPE_MESSAGE: + if field.message_type.name == 'Struct': + value = map_dict(value) + elif field.message_type.name == 'ListValue': + value = list_list(value) + elif field.label == field.LABEL_REPEATED: + if field.message_type.GetOptions().map_entry: + value = message_map_dict(field, value) + else: + value = message_list_list(field, value) + else: + value = message_dict(value) + result[field.name] = value + return result + +def message_map_dict(descriptor, message): + descriptor = descriptor.message_type.fields_by_name['value'] + result = {} + for field, value in message.items(): + if descriptor.type == descriptor.TYPE_MESSAGE: + if descriptor.message_type.name == 'Struct': + value = map_dict(value) + elif descriptor.message_type.name == 'ListValue': + value = list_list(value) + elif descriptor.label == descriptor.LABEL_REPEATED: + if descriptor.message_type.GetOptions().map_entry: + value = message_map_dict(value) + else: + value = message_list_list(value) + else: + value = message_dict(value) + elif descriptor.type == descriptor.TYPE_BYTES: + value = value.decode() + result[field] = value + return result + +def message_list_list(descriptor, message): + result = [] + for value in message: + if descriptor.type == descrptor.TYPE_MESSAGE: + if descriptor.message_type.name == 'Struct': + value = map_dict(value) + elif descriptor.message_type.name == 'ListValue': + value = list_list(value) + elif descriptor.label == descriptor.LABEL_REPEATED: + if descriptor.message_type.GetOptions().map_entry: + value = message_map_dict(value) + else: + value = message_list_list(value) + else: + value = message_dict(value) + elif descriptor.type == descriptor.TYPE_BYTES: + value = value.decode() + result.append(value) + return result + +def map_dict(message): + result = {} + for field, value in message.items(): + if isinstance(value, Struct): + value = map_dict(value) + elif isinstance(value, ListValue): + value = list_list(value) + result[field] = value + return result + +def list_list(message): + result = [] + for value in message: + if isinstance(value, Struct): + value = map_dict(value) + elif isinstance(value, ListValue): + value = list_list(value) + result.append(value) + return result