diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml new file mode 100644 index 0000000..05d0497 --- /dev/null +++ b/.github/workflows/unittest.yaml @@ -0,0 +1,30 @@ +name: Python Unit Test + +on: [push] + +jobs: + build: + name: GitHub Action for Pytest + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Test with unittest + run: | + pytest ./odata/tests diff --git a/odata/connection.py b/odata/connection.py index 995557c..2d3b9e6 100644 --- a/odata/connection.py +++ b/odata/connection.py @@ -24,7 +24,6 @@ def inner(*args, **kwargs): class ODataConnection(object): base_headers = { - 'Accept': 'application/json', 'OData-Version': '4.0', 'User-Agent': 'python-odata {0}'.format(version), } @@ -38,6 +37,9 @@ def __init__(self, session=None, auth=None): self.auth = auth self.log = logging.getLogger('odata.connection') + def __del__(self): + self.session.close() + def _apply_options(self, kwargs): kwargs['timeout'] = self.timeout @@ -75,6 +77,7 @@ def _handle_odata_error(self, response): response_ct = response.headers.get('content-type', '') if 'application/json' in response_ct: + self.log.debug(u'JSON: {0}'.format(response.json())) errordata = response.json() if 'error' in errordata: @@ -88,6 +91,9 @@ def _handle_odata_error(self, response): ie = odata_error['innererror'] detailed_message = ie.get('message') or detailed_message + response.close() + self.log.info(u'Closed response on failure with HTTP status {0}'.format(code)) + msg = ' | '.join([status_code, code, message, detailed_message]) err = ODataError(msg) err.status_code = status_code @@ -97,64 +103,104 @@ def _handle_odata_error(self, response): raise err def execute_get(self, url, params=None): - headers = {} - headers.update(self.base_headers) - - self.log.info(u'GET {0}'.format(url)) - if params: - self.log.info(u'Query: {0}'.format(params)) - - response = self._do_get(url, params=params, headers=headers) - self._handle_odata_error(response) - response_ct = response.headers.get('content-type', '') - if response.status_code == requests.codes.no_content: - return - if 'application/json' in response_ct: - data = response.json() - return data - else: - msg = u'Unsupported response Content-Type: {0}'.format(response_ct) - raise ODataError(msg) + try: + response = None + headers = {} + headers.update(self.base_headers) + + self.log.info(u'GET {0}'.format(url)) + if params: + self.log.info(u'Query: {0}'.format(params)) + + response = self._do_get(url, params=params, headers=headers) + self._handle_odata_error(response) + response_ct = response.headers.get('content-type', '') + if response.status_code == requests.codes.no_content: + return + if 'application/json' in response_ct: + self.log.debug(u'JSON: {0}'.format(response.json())) + return response.json() + else: + msg = u'Unsupported response Content-Type: {0}'.format(response_ct) + raise ODataError(msg) + except: + raise + finally: + if response: + response.close() + self.log.info(u'Closed GET response for {0}'.format(url)) def execute_post(self, url, data, params=None): - headers = { - 'Content-Type': 'application/json', - } - headers.update(self.base_headers) + try: + response = None + headers = { + 'Content-Type': 'application/json', + } + headers.update(self.base_headers) - data = json.dumps(data) + data = json.dumps(data) - self.log.info(u'POST {0}'.format(url)) - self.log.info(u'Payload: {0}'.format(data)) + self.log.info(u'POST {0}'.format(url)) + self.log.info(u'Payload: {0}'.format(data)) - response = self._do_post(url, data=data, headers=headers, params=params) - self._handle_odata_error(response) - response_ct = response.headers.get('content-type', '') - if response.status_code == requests.codes.no_content: - return - if 'application/json' in response_ct: - return response.json() - # no exceptions here, POSTing to Actions may not return data + response = self._do_post(url, data=data, headers=headers, params=params) + self._handle_odata_error(response) + response_ct = response.headers.get('content-type', '') + if response.status_code == requests.codes.no_content: + return + if 'application/json' in response_ct: + self.log.debug(u'JSON: {0}'.format(response.json())) + return response.json() + # no exceptions here, POSTing to Actions may not return data + except: + raise + finally: + if response: + response.close() + self.log.info(u'Closed POST response for {0}'.format(url)) def execute_patch(self, url, data): - headers = { - 'Content-Type': 'application/json', - } - headers.update(self.base_headers) + try: + response = None + headers = { + 'Content-Type': 'application/json', + } + headers.update(self.base_headers) - data = json.dumps(data) + data = json.dumps(data) - self.log.info(u'PATCH {0}'.format(url)) - self.log.info(u'Payload: {0}'.format(data)) + self.log.info(u'PATCH {0}'.format(url)) + self.log.info(u'Payload: {0}'.format(data)) - response = self._do_patch(url, data=data, headers=headers) - self._handle_odata_error(response) + response = self._do_patch(url, data=data, headers=headers) + self._handle_odata_error(response) + response_ct = response.headers.get('content-type', '') + if 'application/json' in response_ct: + self.log.debug(u'JSON: {0}'.format(response.json())) + return response.json() + except: + raise + finally: + if response: + response.close() + self.log.info(u'Closed PATCH response for {0}'.format(url)) def execute_delete(self, url): - headers = {} - headers.update(self.base_headers) + try: + response = None + headers = {} + headers.update(self.base_headers) - self.log.info(u'DELETE {0}'.format(url)) + self.log.info(u'DELETE {0}'.format(url)) - response = self._do_delete(url, headers=headers) - self._handle_odata_error(response) + response = self._do_delete(url, headers=headers) + self._handle_odata_error(response) + response_ct = response.headers.get('content-type', '') + if 'application/json' in response_ct: + self.log.debug(u'JSON: {0}'.format(response.json())) + except: + raise + finally: + if response: + response.close() + self.log.info(u'Closed DELETE response for {0}'.format(url)) diff --git a/odata/context.py b/odata/context.py index d85e47d..38407e4 100644 --- a/odata/context.py +++ b/odata/context.py @@ -13,8 +13,8 @@ def __init__(self, session=None, auth=None): self.log = logging.getLogger('odata.context') self.connection = ODataConnection(session=session, auth=auth) - def query(self, entitycls): - q = Query(entitycls, connection=self.connection) + def query(self, entitycls, options=None): + q = Query(entitycls, connection=self.connection, options=options) return q def call(self, action_or_function, **parameters): @@ -44,11 +44,29 @@ def delete(self, entity): :type entity: EntityBase :raises ODataConnectionError: Delete not allowed or a serverside error. Server returned an HTTP error code """ - self.log.info(u'Deleting entity: {0}'.format(entity)) + self.log.debug(u'Deleting entity: {0}'.format(entity)) url = entity.__odata__.instance_url self.connection.execute_delete(url) entity.__odata__.persisted = False - self.log.info(u'Success') + entity.__odata__.persisted_id = None + self.log.debug(u'Success') + + def get(self, entity): + """ + Creates a GET call to the service, fetching the entity + + :type entity: EntityBase + """ + self.log.debug(u'Fetching entity: {0}'.format(entity)) + url = entity.__odata__.instance_url + data = self.connection.execute_get(url) + entity.__odata__.reset() + if data is not None: + entity.__odata__.update(data) + entity.__odata__.persisted = True + entity.__odata__.persisted_id = entity.__odata__.id + self.log.debug(u'Success') + return entity def save(self, entity, force_refresh=True): """ @@ -76,12 +94,16 @@ def _insert_new(self, entity): :type entity: EntityBase """ - url = entity.__odata_url__() + + if entity.__odata__.odata_scope: + url = entity.__odata__.odata_scope + else: + url = entity.__odata_url__() if url is None: msg = 'Cannot insert Entity that does not belong to EntitySet: {0}'.format(entity) raise ODataError(msg) - self.log.info(u'Saving new entity') + self.log.debug(u'Saving new entity') es = entity.__odata__ insert_data = es.data_for_insert() @@ -92,8 +114,9 @@ def _insert_new(self, entity): if saved_data is not None: es.update(saved_data) + es.persisted_id = es.id - self.log.info(u'Success') + self.log.debug(u'Success') def _update_existing(self, entity, force_refresh=True): """ @@ -112,7 +135,7 @@ def _update_existing(self, entity, force_refresh=True): self.log.debug(u'Nothing to update: {0}'.format(entity)) return - self.log.info(u'Updating existing entity: {0}'.format(entity)) + self.log.debug(u'Updating existing entity: {0}'.format(entity)) url = es.instance_url @@ -120,10 +143,10 @@ def _update_existing(self, entity, force_refresh=True): es.reset() if saved_data is None and force_refresh: - self.log.info(u'Reloading entity from service') + self.log.debug(u'Reloading entity from service') saved_data = self.connection.execute_get(url) if saved_data is not None: entity.__odata__.update(saved_data) - self.log.info(u'Success') + self.log.debug(u'Success') diff --git a/odata/entity.py b/odata/entity.py index feab6bb..8388bde 100644 --- a/odata/entity.py +++ b/odata/entity.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from typing import Dict """ Entity classes @@ -92,36 +93,80 @@ class EntityBase(object): __odata_type__ = 'ODataSchema.Entity' __odata_singleton__ = False __odata_schema__ = None + __odata_scope__ = None @classmethod def __odata_url__(cls): # used by Query if cls.__odata_collection__: - return urljoin(cls.__odata_service__.url, cls.__odata_collection__) + if cls.__odata_scope__: + if callable(cls.__odata_scope__): + return "/".join([ + urljoin( + cls.__odata_service__.url, + cls.__odata_scope__() + ), + cls.__odata_collection__ + + ]) + else: + return "/".join([ + urljoin( + cls.__odata_service__.url, + cls.__odata_scope__ + ), + cls.__odata_collection__ + + ]) + else: + return urljoin(cls.__odata_service__.url, cls.__odata_collection__) def __new__(cls, *args, **kwargs): i = super(EntityBase, cls).__new__(cls) i.__odata__ = es = EntityState(i) + if len(args) > 0: + data = args[0] + else: + data = {} if 'from_data' in kwargs: raw_data = kwargs.pop('from_data') - # check for values from $expand - for prop_name, prop in es.navigation_properties: - if prop.name in raw_data: - expanded_data = raw_data.pop(prop.name) - if prop.is_collection: - es.nav_cache[prop.name] = dict(collection=prop.instances_from_data(expanded_data)) - else: - es.nav_cache[prop.name] = dict(single=prop.instances_from_data(expanded_data)) + if raw_data: + for k, v in raw_data.items(): + if '@odata' in k: + i.__odata__[k] = v - for prop_name, prop in es.properties: - i.__odata__[prop.name] = raw_data.get(prop.name) + # check for values from $expand + for prop_name, prop in es.navigation_properties: + if prop.name in raw_data: + expanded_data = raw_data.pop(prop.name) + if prop.is_collection: + es.nav_cache[prop.name] = dict(collection=prop.instances_from_data(expanded_data)) + else: + es.nav_cache[prop.name] = dict(single=prop.instances_from_data(expanded_data)) + + for prop_name, prop in es.properties: + i.__odata__[prop.name] = raw_data.get(prop.name) i.__odata__.persisted = True + i.__odata__.persisted_id = i.__odata__.id else: + for prop_name, prop in es.navigation_properties: + if prop_name in data.keys(): + if prop.is_collection: + es.nav_cache[prop.name] = dict(collection=prop.instances_from_data(data[prop_name])) + else: + if isinstance(data[prop_name], Dict): + es.nav_cache[prop.name] = dict(single=prop.instances_from_data(data[prop_name])) + else: + es.nav_cache[prop.name] = dict(single=data[prop_name]) + for prop_name, prop in es.properties: - i.__odata__[prop.name] = None + if prop_name in data.keys(): + i.__odata__[prop.name] = data[prop_name] + else: + i.__odata__[prop.name] = None return i diff --git a/odata/metadata.py b/odata/metadata.py index 5b76be8..ec22ee0 100644 --- a/odata/metadata.py +++ b/odata/metadata.py @@ -20,6 +20,8 @@ class MetaData(object): + cached_entity_sets = {} + cached_metadata = {} log = logging.getLogger('odata.metadata') namespaces = { 'edm': 'http://docs.oasis-open.org/odata/ns/edm', @@ -45,6 +47,11 @@ def __init__(self, service): self.connection = service.default_context.connection self.service = service + @classmethod + def flush_cache(cls): + cls.cached_entity_sets = {} + cls.cached_metadata = {} + def property_type_to_python(self, edm_type): return self.property_types.get(edm_type, StringProperty) @@ -210,7 +217,12 @@ def _create_functions(self, all_types, functions, get_entity_or_prop_from_type): self.service.functions[function['name']] = function_class() def get_entity_sets(self, base=None): - document = self.load_document() + if base not in MetaData.cached_entity_sets: + MetaData.cached_entity_sets[base] = self._get_entity_sets(base) + return MetaData.cached_entity_sets[base] + + def _get_entity_sets(self, base=None): + document = self.load_document(self.url) schemas, entity_sets, actions, functions = self.parse_document(document) base_class = base or declarative_base() @@ -249,10 +261,12 @@ def get_entity_or_prop_from_type(typename): self.log.info('Loaded {0} entity sets, total {1} types'.format(len(sets), len(all_types))) return base_class, sets, all_types - def load_document(self): - self.log.info('Loading metadata document: {0}'.format(self.url)) - response = self.connection._do_get(self.url) - return ET.fromstring(response.content) + def load_document(self, url): + if url not in MetaData.cached_metadata: + self.log.info('Loading metadata document: {0}'.format(url)) + response = self.connection._do_get(url) + MetaData.cached_metadata[url] = ET.fromstring(response.content) + return MetaData.cached_metadata[url] def _parse_action(self, xmlq, action_element, schema_name): action = { diff --git a/odata/navproperty.py b/odata/navproperty.py index 4c5527c..2fae216 100644 --- a/odata/navproperty.py +++ b/odata/navproperty.py @@ -47,6 +47,7 @@ def __init__(self, name, entitycls, collection=False, foreign_key=None): self.name = name self.entitycls = entitycls self.is_collection = collection + self.is_computed_value = False if isinstance(foreign_key, PropertyBase): self.foreign_key = foreign_key.name else: @@ -110,6 +111,7 @@ def __get__(self, instance, owner): cache['collection'] = self.instances_from_data(raw_data['value']) else: cache['collection'] = [] + [c.__odata__.set_scope(url) for c in cache['collection'] if c] return cache['collection'] else: if 'single' not in cache: @@ -118,4 +120,6 @@ def __get__(self, instance, owner): cache['single'] = self.instances_from_data(raw_data) else: cache['single'] = None + if cache['single']: + cache['single'].__odata__.set_scope(url) return cache['single'] diff --git a/odata/query.py b/odata/query.py index 4d276ad..d0aa6ca 100644 --- a/odata/query.py +++ b/odata/query.py @@ -61,6 +61,7 @@ class Query(object): def __init__(self, entitycls, connection=None, options=None): self.entity = entitycls self.options = options or dict() + self.default_opts = options self.connection = connection def __iter__(self): @@ -98,7 +99,7 @@ def _get_options(self): Format current query options to a dict that can be passed to requests :return: Dictionary """ - options = dict() + options = self.options _top = self.options.get('$top') if _top is not None: @@ -149,7 +150,7 @@ def _new_query(self): :return: Query instance """ - o = dict() + o = self.default_opts or dict() o['$top'] = self.options.get('$top', None) o['$skip'] = self.options.get('$skip', None) o['$select'] = self.options.get('$select', [])[:] diff --git a/odata/service.py b/odata/service.py index 6238462..a07b9b9 100644 --- a/odata/service.py +++ b/odata/service.py @@ -158,6 +158,10 @@ def create_context(self, auth=None, session=None): """ return Context(auth=auth, session=session) + @classmethod + def flush_cache(cls): + MetaData.flush_cache() + def describe(self, entity): """ Print a debug screen of an entity instance @@ -170,14 +174,22 @@ def is_entity_saved(self, entity): """Returns boolean indicating entity's status""" return self.default_context.is_entity_saved(entity) - def query(self, entitycls): + def query(self, entitycls, options=None): """ Start a new query for given entity class :param entitycls: Entity to query :return: Query object """ - return self.default_context.query(entitycls) + return self.default_context.query(entitycls, options=options) + + def get(self, entity): + """ + Creates a GET call to the service, fetching the entity + + :type entity: EntityBase + """ + return self.default_context.get(entity) def delete(self, entity): """ diff --git a/odata/state.py b/odata/state.py index c493c7e..ce70e75 100644 --- a/odata/state.py +++ b/odata/state.py @@ -1,213 +1,303 @@ -# -*- coding: utf-8 -*- - -from __future__ import print_function -import os -import inspect -from collections import OrderedDict - -from odata.property import PropertyBase, NavigationProperty - - -class EntityState(object): - - def __init__(self, entity): - """:type entity: EntityBase """ - self.entity = entity - self.dirty = [] - self.nav_cache = {} - self.data = {} - self.connection = None - # does this object exist serverside - self.persisted = False - - # dictionary access - def __getitem__(self, item): - return self.data[item] - - def __setitem__(self, key, value): - self.data[key] = value - - def __contains__(self, item): - return item in self.data - - def get(self, key, default): - return self.data.get(key, default=default) - - def update(self, other): - self.data.update(other) - # /dictionary access - - def __repr__(self): - return self.data.__repr__() - - def describe(self): - rows = [ - u'EntitySet: {0}'.format(self.entity.__odata_collection__), - u'Type: {0}'.format(self.entity.__odata_type__), - u'URL: {0}'.format(self.instance_url or self.entity.__odata_url__()), - u'', - u'Properties', - u'-' * 40, - ] - - for _, prop in self.properties: - name = prop.name - if prop.primary_key: - name += '*' - if prop.name in self.dirty: - name += ' (dirty)' - rows.append(name) - - rows.append(u'') - rows.append(u'Navigation Properties') - rows.append(u'-' * 40) - - for _, prop in self.navigation_properties: - rows.append(prop.name) - - rows = os.linesep.join(rows) - print(rows) - - def reset(self): - self.dirty = [] - self.nav_cache = {} - - @property - def id(self): - ids = [] - entity_name = self.entity.__odata_collection__ - if entity_name is None: - return - - for prop_name, prop in self.primary_key_properties: - value = self.data.get(prop.name) - if value: - ids.append((prop, str(prop.escape_value(value)))) - if len(ids) == 1: - key_value = ids[0][1] - return u'{0}({1})'.format(entity_name, - key_value) - if len(ids) > 1: - key_ids = [] - for prop, key_value in ids: - key_ids.append('{0}={1}'.format(prop.name, key_value)) - return u'{0}({1})'.format(entity_name, ','.join(key_ids)) - - @property - def instance_url(self): - if self.id: - return self.entity.__odata_url_base__ + self.id - - @property - def properties(self): - props = [] - cls = self.entity.__class__ - for key, value in inspect.getmembers(cls): - if isinstance(value, PropertyBase): - props.append((key, value)) - return props - - @property - def primary_key_properties(self): - pks = [] - for prop_name, prop in self.properties: - if prop.primary_key is True: - pks.append((prop_name, prop)) - return pks - - @property - def navigation_properties(self): - props = [] - cls = self.entity.__class__ - for key, value in inspect.getmembers(cls): - if isinstance(value, NavigationProperty): - props.append((key, value)) - return props - - @property - def dirty_properties(self): - rv = [] - for prop_name, prop in self.properties: - if prop.name in self.dirty: - rv.append((prop_name, prop)) - return rv - - def set_property_dirty(self, prop): - if prop.name not in self.dirty: - self.dirty.append(prop.name) - - def data_for_insert(self): - return self._clean_new_entity(self.entity) - - def data_for_update(self): - update_data = OrderedDict() - update_data['@odata.type'] = self.entity.__odata_type__ - - for _, prop in self.dirty_properties: - if prop.is_computed_value: - continue - - update_data[prop.name] = self.data[prop.name] - - for prop_name, prop in self.navigation_properties: - if prop.name in self.dirty: - value = getattr(self.entity, prop_name, None) # get the related object - """:type : None | odata.entity.EntityBase | list[odata.entity.EntityBase]""" - if value is not None: - key = '{0}@odata.bind'.format(prop.name) - if prop.is_collection: - update_data[key] = [i.__odata__.id for i in value] - else: - update_data[key] = value.__odata__.id - return update_data - - def _clean_new_entity(self, entity): - """:type entity: odata.entity.EntityBase """ - insert_data = OrderedDict() - insert_data['@odata.type'] = entity.__odata_type__ - - es = entity.__odata__ - for _, prop in es.properties: - if prop.is_computed_value: - continue - - insert_data[prop.name] = es[prop.name] - - # Allow pk properties only if they have values - for _, pk_prop in es.primary_key_properties: - if insert_data[pk_prop.name] is None: - insert_data.pop(pk_prop.name) - - # Deep insert from nav properties - for prop_name, prop in es.navigation_properties: - if prop.foreign_key: - insert_data.pop(prop.foreign_key, None) - - value = getattr(entity, prop_name, None) - """:type : None | odata.entity.EntityBase | list[odata.entity.EntityBase]""" - if value is not None: - - if prop.is_collection: - binds = [] - - # binds must be added first - for i in [i for i in value if i.__odata__.id]: - binds.append(i.__odata__.id) - - if len(binds): - insert_data['{0}@odata.bind'.format(prop.name)] = binds - - new_entities = [] - for i in [i for i in value if i.__odata__.id is None]: - new_entities.append(self._clean_new_entity(i)) - - if len(new_entities): - insert_data[prop.name] = new_entities - - else: - if value.__odata__.id: - insert_data['{0}@odata.bind'.format(prop.name)] = value.__odata__.id - else: - insert_data[prop.name] = self._clean_new_entity(value) - - return insert_data +# -*- coding: utf-8 -*- + +from __future__ import print_function +import os +import inspect +import logging +import re +from collections import OrderedDict + +try: + # noinspection PyUnresolvedReferences + from urllib.parse import urljoin, urlparse +except ImportError: + # noinspection PyUnresolvedReferences + from urlparse import urljoin, urlparse + +from odata.property import PropertyBase, NavigationProperty +import odata + + +class EntityState(object): + + def __init__(self, entity): + self.log = logging.getLogger('odata.state') + """:type entity: EntityBase """ + self.entity = entity + self.dirty = [] + self.nav_cache = {} + self.data = OrderedDict() + self.connection = None + # does this object exist serverside + self.persisted = False + self.persisted_id = None + self.odata_scope = None + + # dictionary access + def __getitem__(self, item): + return self.data[item] + + def __setitem__(self, key, value): + self.data[key] = value + + def __contains__(self, item): + return item in self.data + + def get(self, key, default): + return self.data.get(key, default) + + def update(self, other): + self.data.update(other) + # /dictionary access + + def __repr__(self): + return self.data.__repr__() + + def describe(self): + rows = [ + u'EntitySet: {0}'.format(self.entity.__odata_collection__), + u'Type: {0}'.format(self.entity.__odata_type__), + u'URL: {0}'.format(self.instance_url or self.entity.__odata_url__()), + u'', + u'Properties', + u'-' * 40, + ] + + for _, prop in self.properties: + name = prop.name + if prop.primary_key: + name += '*' + if prop.name in self.dirty: + name += ' (dirty)' + rows.append(name) + + rows.append(u'') + rows.append(u'Navigation Properties') + rows.append(u'-' * 40) + + for _, prop in self.navigation_properties: + rows.append(prop.name) + + rows = os.linesep.join(rows) + print(rows) + + def reset(self): + self.dirty = [] + self.nav_cache = {} + self.persisted_id = None + + @property + def id(self): + if self.persisted and self.persisted_id: + return self.persisted_id + ids = [] + entity_name = self.entity.__odata_collection__ + if entity_name is None: + return + + for prop_name, prop in self.primary_key_properties: + value = self.data.get(prop.name) + if value is not None: + if isinstance(value, str): + ids.append((prop, str(prop.escape_value(value)))) + else: + ids.append((prop, value)) + else: + return + + if len(ids) == 1: + key_value = ids[0][1] + return u'{0}({1})'.format(entity_name, + key_value) + if len(ids) > 1: + key_ids = [] + for prop, key_value in ids: + key_ids.append('{0}={1}'.format(prop.name, key_value)) + return u'{0}({1})'.format(entity_name, ','.join(key_ids)) + + @property + def instance_url(self): + odata_id = self.get('@odata.id', None) + if self.id: + if self.odata_scope: + if self.odata_scope.endswith(self.entity.__odata_collection__): + url = re.sub(self.entity.__odata_collection__, '', self.odata_scope) + return urljoin(url, self.id) + else: + return self.odata_scope + elif odata_id and self.id in odata_id: + url = re.sub(self.entity.__odata_collection__, '', self.entity.__odata_url__()) + odata_id = odata_id.split('/')[-1] + return urljoin(url, odata_id) + else: + url = re.sub(self.entity.__odata_collection__, '', self.entity.__odata_url__()) + return urljoin(url, self.id) + elif odata_id: + url = re.sub(self.entity.__odata_collection__, '', self.entity.__odata_url__()) + odata_id = odata_id.split('/')[-1] + return urljoin(url, odata_id) + + @property + def properties(self): + props = [] + cls = self.entity.__class__ + for key, value in inspect.getmembers(cls): + if isinstance(value, PropertyBase): + props.append((key, value)) + return props + + @property + def primary_key_properties(self): + pks = [] + for prop_name, prop in self.properties: + if prop.primary_key is True: + pks.append((prop_name, prop)) + return pks + + @property + def navigation_properties(self): + props = [] + cls = self.entity.__class__ + for key, value in inspect.getmembers(cls): + if isinstance(value, NavigationProperty): + props.append((key, value)) + return props + + @property + def dirty_properties(self): + rv = [] + for prop_name, prop in self.properties: + if prop.name in self.dirty: + rv.append((prop_name, prop)) + for prop_name, prop in self.navigation_properties: + if prop.name in self.dirty: + rv.append((prop_name, prop)) + return rv + + def set_property_dirty(self, prop): + if prop.name not in self.dirty: + self.dirty.append(prop.name) + + def data_for_insert(self): + return self._new_entity(self.entity) + + def data_for_update(self): + return self._updated_entity(self.entity) + + def set_scope(self, odata_scope): + if odata_scope: + self.odata_scope = odata_scope + + def _new_entity(self, entity): + """:type entity: odata.entity.EntityBase """ + insert_data = OrderedDict() + insert_data['@odata.type'] = entity.__odata_type__ + + es = entity.__odata__ + + for _, prop in es.properties: + if prop.is_computed_value: + continue + + insert_data[prop.name] = es[prop.name] + + # Allow pk properties only if they have values + for _, pk_prop in es.primary_key_properties: + if insert_data[pk_prop.name] is None: + insert_data.pop(pk_prop.name) + + # Deep insert from nav properties + for prop_name, prop in es.navigation_properties: + value = getattr(entity, prop_name, None) + """:type : None | odata.entity.EntityBase | list[odata.entity.EntityBase]""" + insert_data = self._add_or_update_associated(insert_data, prop, value) + + for _, prop in es.properties: + if prop.name in insert_data: + if not insert_data[prop.name]: + insert_data.pop(prop.name) + + return insert_data + + def _updated_entity(self, entity): + update_data = OrderedDict() + update_data['@odata.type'] = self.entity.__odata_type__ + + es = entity.__odata__ + + for _, pk_prop in es.primary_key_properties: + update_data[pk_prop.name] = es[pk_prop.name] + + if '@odata.etag' in es: + update_data['@odata.etag'] = es['@odata.etag'] + + for _, prop in es.dirty_properties: + if prop.is_computed_value: + continue + if prop.name in dict(es.navigation_properties).keys(): + continue + + update_data[prop.name] = es[prop.name] + + for prop_name, prop in es.navigation_properties: + if prop.name in es.dirty: + value = getattr(entity, prop_name, None) # get the related object + """:type : None | odata.entity.EntityBase | list[odata.entity.EntityBase]""" + update_data = self._add_or_update_associated(update_data, prop, value) + + return update_data + + def _add_or_update_associated(self, data, prop, value): + if value is None: + return data + if prop.is_collection: + data = self._add_or_update_associated_collection(data, prop, value) + else: + data = self._add_or_update_associated_instance(data, prop, value) + return data + + def _add_or_update_associated_collection(self, data, prop, value): + + def is_new(entity): + if entity.__odata__.id is None: + return True + return False + + def is_dirty(entity): + if is_new(entity): + return False + elif hasattr(entity.__odata__, 'dirty') and entity.__odata__.dirty: + return True + return False + + def is_persisted(entity): + return (not is_new(entity) and not is_dirty(entity)) + + ids = ['/' + i.__odata__.id for i in value if is_persisted(i)] + if ids: + data['{0}@odata.bind'.format(prop.name)] = ids + + upd_objs = [self._updated_entity(i) for i in value if is_dirty(i)] + + new_objs = [self._new_entity(i) for i in value if is_new(i)] + + if upd_objs or new_objs: + data[prop.name] = upd_objs + new_objs + + return data + + def _add_or_update_associated_instance(self, data, prop, value): + if isinstance(value, odata.entity.EntityBase): + if value.persisted is False: + data[prop.name] = self._new_entity(value) + + elif value.__odata__.id: + data['{0}@odata.bind'.format(prop.name)] = '/' + value.__odata__.id + + elif value.dirty: + data[prop.name] = self._updated_entity(value) + + elif value.__odata__.id: + data['{0}@odata.bind'.format(prop.name)] = '/' + value.__odata__.id + + return data diff --git a/odata/tests/test_metadata.py b/odata/tests/test_metadata.py index 8d6029b..a5e39ba 100644 --- a/odata/tests/test_metadata.py +++ b/odata/tests/test_metadata.py @@ -9,6 +9,7 @@ from odata import ODataService from odata.entity import EntityBase +from odata.metadata import MetaData path = os.path.join(os.path.dirname(__file__), 'demo_metadata.xml') with open(path, mode='rb') as f: @@ -18,6 +19,7 @@ class TestMetadataImport(TestCase): def test_read(self): + MetaData.flush_cache() with responses.RequestsMock() as rsps: rsps.add(rsps.GET, 'http://demo.local/odata/$metadata/', body=metadata_xml, content_type='text/xml') @@ -46,6 +48,7 @@ def test_read(self): self.assertIn('DemoUnboundAction', Service.actions) def test_computed_value_in_insert(self): + MetaData.flush_cache() with responses.RequestsMock() as rsps: rsps.add(rsps.GET, 'http://demo.local/odata/$metadata/', body=metadata_xml, content_type='text/xml') diff --git a/odata/tests/test_state.py b/odata/tests/test_state.py new file mode 100644 index 0000000..a173259 --- /dev/null +++ b/odata/tests/test_state.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +import unittest + +from odata.state import EntityState +from odata.tests import Product, ProductPart + + +class TestSate(unittest.TestCase): + + def test_new_entity(self): + uuid = '3d46cd74-a3af-4afd-af94-512b5cee1ef0' + + product = Product() + product.id = uuid + product.name = u'Defender' + product.category = u'Cars' + product.price = 40000.00 + + state = EntityState(product) + + data = dict(state.data_for_insert()) + + assert data['ProductID'] == uuid + assert data['ProductName'] == 'Defender' + assert data['Category'] == 'Cars' + assert data['Price'] == 40000.00 + + assert state.dirty == [] + + product.name = 'Toyota Carola' + product.price = 32500.00 + + data = dict(state.data_for_update()) + + assert data['ProductName'] == 'Toyota Carola' + assert data['Category'] == 'Cars' + assert data['Price'] == 32500.00 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..af4cb20 Binary files /dev/null and b/requirements.txt differ