diff --git a/hashtopolis/__init__.py b/hashtopolis/__init__.py index f10fc27..35f30a9 100644 --- a/hashtopolis/__init__.py +++ b/hashtopolis/__init__.py @@ -15,6 +15,7 @@ # models from .hashtopolis import ( + ApiToken, AccessGroup, Agent, AgentStat, diff --git a/hashtopolis/hashtopolis.py b/hashtopolis/hashtopolis.py index 8d5b223..38cca7a 100644 --- a/hashtopolis/hashtopolis.py +++ b/hashtopolis/hashtopolis.py @@ -60,6 +60,16 @@ def __init__(self): self.username = self._cfg['username'] self.password = self._cfg['password'] + @classmethod + def with_credentials(cls, uri, username, password): + """Create a config with explicit credentials instead of reading from a config file.""" + config = cls.__new__(cls) + config._hashtopolis_uri = uri + config._api_endpoint = uri + '/api/v2' + config.username = username + config.password = password + return config + class HashtopolisResponseError(HashtopolisError): pass @@ -106,22 +116,34 @@ def __init__(self, model_uri, config): self._hashtopolis_uri = config._hashtopolis_uri self.config = config - def authenticate(self): - if self._api_endpoint not in HashtopolisConnector.token: - # Request access TOKEN, used throughout the test + def authenticate(self, auth=None): + """ + Authenticate with the API and store the token for future requests. - logger.info("Start authentication") + Args: + auth: Authentication object understood by requests, typically a + ``(username, password)`` tuple. Is only used for one off authentication + that differ from the config. This authentication is not cached. + """ + if auth is not None: + logger.info("Start authentication with provided credentials") auth_uri = self._api_endpoint + '/auth/token' - auth = (self.config.username, self.config.password) r = requests.post(auth_uri, auth=auth) self.validate_status_code(r, [201], "Authentication failed") - r_json = self.resp_to_json(r) - HashtopolisConnector.token[self._api_endpoint] = r_json['token'] - HashtopolisConnector.token_expires[self._api_endpoint] = r_json['token'] - - self._token = HashtopolisConnector.token[self._api_endpoint] - self._token_expires = HashtopolisConnector.token_expires[self._api_endpoint] + self._token = r_json['token'] + self._token_expires = r_json['token'] + else: + if self._api_endpoint not in HashtopolisConnector.token: + logger.info("Start authentication") + auth_uri = self._api_endpoint + '/auth/token' + r = requests.post(auth_uri, auth=(self.config.username, self.config.password)) + self.validate_status_code(r, [201], "Authentication failed") + r_json = self.resp_to_json(r) + HashtopolisConnector.token[self._api_endpoint] = r_json['token'] + HashtopolisConnector.token_expires[self._api_endpoint] = r_json['token'] + self._token = HashtopolisConnector.token[self._api_endpoint] + self._token_expires = HashtopolisConnector.token_expires[self._api_endpoint] self._headers = { 'Authorization': 'Bearer ' + self._token @@ -190,9 +212,9 @@ def validate_status_code(self, r, expected_status_code, error_msg): # query_params = urllib.parse.parse_qs(urllib.parse.urlparse(links["last"]).query) # TODO not really a straightforward way to validate the last link - def get_single_page(self, page, filter): + def get_single_page(self, page, filter, auth=None): """Gets a single page by using the page parameters""" - self.authenticate() + self.authenticate(auth=auth) headers = self._headers request_uri = self._api_endpoint + self._model_uri payload = {} @@ -215,8 +237,8 @@ def get_single_page(self, page, filter): return response["data"] # todo refactor start_offset into page variable - def filter(self, include, ordering, filter, start_offset): - self.authenticate() + def filter(self, include, ordering, filter, start_offset, auth=None): + self.authenticate(auth=auth) headers = self._headers after_dict = {"primary": {"id": start_offset}} @@ -253,8 +275,8 @@ def filter(self, include, ordering, filter, start_offset): break request_uri = response['links']['next'] - def get_one(self, pk, include): - self.authenticate() + def get_one(self, pk, include, auth=None): + self.authenticate(auth=auth) uri = self._api_endpoint + self._model_uri + f'/{pk}' headers = self._headers @@ -266,8 +288,8 @@ def get_one(self, pk, include): self.validate_status_code(r, [200], "Get single object failed") return self.resp_to_json(r) - def delete_many(self, objects): - self.authenticate() + def delete_many(self, objects, auth=None): + self.authenticate(auth=auth) uri = self._api_endpoint + self._model_uri headers = self._headers headers['Content-Type'] = 'application/json' @@ -282,7 +304,7 @@ def delete_many(self, objects): r = requests.delete(uri, headers=headers, data=json.dumps(payload)) self.validate_status_code(r, [204], "deleting failed") - def patch_many(self, objects, attributes, field): + def patch_many(self, objects, attributes, field, auth=None): """ Used to test PATCH many endpoint. @@ -293,7 +315,7 @@ def patch_many(self, objects, attributes, field): patched with attributes[0] on the set field """ assert len(objects) == len(attributes) - self.authenticate() + self.authenticate(auth=auth) uri = self._api_endpoint + self._model_uri headers = self._headers headers['Content-Type'] = 'application/json' @@ -302,12 +324,12 @@ def patch_many(self, objects, attributes, field): r = requests.patch(uri, headers=headers, data=json.dumps(payload)) self.validate_status_code(r, [200], "Patching failed") - def patch_one(self, obj): + def patch_one(self, obj, auth=None): if not obj.has_changed(): logger.debug("Object '%s' has not changed, no PATCH required", obj) return - self.authenticate() + self.authenticate(auth=auth) uri = self._hashtopolis_uri + obj.uri headers = self._headers headers['Content-Type'] = 'application/json' @@ -325,15 +347,15 @@ def patch_one(self, obj): # TODO: Validate if return objects matches digital twin obj.set_initial(self.resp_to_json(r)['data'].copy()) - def send_patch(self, uri, data): - self.authenticate() + def send_patch(self, uri, data, auth=None): + self.authenticate(auth=auth) headers = self._headers headers['Content-Type'] = 'application/json' logger.debug("Sending PATCH payload: %s to %s", json.dumps(data), uri) r = requests.patch(uri, headers=headers, data=json.dumps(data)) self.validate_status_code(r, [204], "Patching failed") - def patch_to_many_relationships(self, obj): + def patch_to_many_relationships(self, obj, auth=None): for k, v in obj.diff_includes().items(): attributes = [] logger.debug("Going to patch object '%s' property '%s' from '%s' to '%s'", obj, k, v[0], v[1]) @@ -341,13 +363,13 @@ def patch_to_many_relationships(self, obj): attributes.append({"type": k, "id": include_id}) data = {"data": attributes} uri = self._hashtopolis_uri + obj.uri + "/relationships/" + k - self.send_patch(uri, data) + self.send_patch(uri, data, auth=auth) - def create(self, obj): + def create(self, obj, auth=None): # Check if object to be created is new assert obj._new_model is True - self.authenticate() + self.authenticate(auth=auth) uri = self._api_endpoint + self._model_uri headers = self._headers headers['Content-Type'] = 'application/json' @@ -362,12 +384,12 @@ def create(self, obj): # TODO: Validate if return objects matches digital twin obj.set_initial(self.resp_to_json(r)['data'].copy()) - def delete(self, obj): + def delete(self, obj, auth=None): """ Delete object from database """ # TODO: Check if object to be deleted actually exists assert obj._new_model is False - self.authenticate() + self.authenticate(auth=auth) uri = self._hashtopolis_uri + obj.uri headers = self._headers payload = {} @@ -377,8 +399,8 @@ def delete(self, obj): # TODO: Cleanup object to allow re-creation - def count(self, filter): - self.authenticate() + def count(self, filter, auth=None): + self.authenticate(auth=auth) uri = self._api_endpoint + self._model_uri + "/count" headers = self._headers payload = {} @@ -394,12 +416,13 @@ def count(self, filter): # Build Django ORM style django.query interface class QuerySet(): - def __init__(self, cls, include=None, ordering=None, filters=None, pages=None): + def __init__(self, cls, include=None, ordering=None, filters=None, pages=None, auth=None): self.cls = cls self.include = include self.ordering = ordering self.filters = filters self.pages = pages + self.auth = auth def __iter__(self): yield from self.__getitem__(slice(None, None, 1)) @@ -431,7 +454,7 @@ def filter_(self, start, stop, step): filters['id'] = filters['pk'] del filters['pk'] - filter_generator = self.cls.get_conn().filter(self.include, self.ordering, filters, start_offset=cursor) + filter_generator = self.cls.get_conn().filter(self.include, self.ordering, filters, start_offset=cursor, auth=self.auth) while index < stop: # Fetch new entries in chunks default to server @@ -469,6 +492,10 @@ def page(self, **pages): def all(self): # yield from self return self + + def authenticate(self, auth): + self.auth = auth + return self def get(self, **filters): if filters: @@ -760,6 +787,10 @@ def uri(self): ## # Begin of API objects # +class ApiToken(Model, uri="/ui/apiTokens"): + pass + + class AccessGroup(Model, uri="/ui/accessgroups"): pass