diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 87fc76bd5..000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,37 +0,0 @@ -version: 2 - -workflows: - version: 2 - workflow: - jobs: - - test-3.8 - - test-3.9 - - test-3.10 - -defaults: &defaults - working_directory: ~/code - steps: - - checkout - - run: - name: Install dependencies - command: pip install --user -r test-requirements.txt - - run: - name: Test - command: pytest tests/ - -jobs: - test-3.8: - <<: *defaults - docker: - - image: circleci/python:3.8 - - image: mongo:3.2.19 - test-3.9: - <<: *defaults - docker: - - image: circleci/python:3.9 - - image: mongo:3.2.19 - test-3.10: - <<: *defaults - docker: - - image: circleci/python:3.10 - - image: mongo:3.2.19 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..314a6dbd8 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,46 @@ +name: Test +'on': + push: + branches: + - master + pull_request: + branches: + - master + workflow_dispatch: {} +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - '3.11' + - '3.12' + - '3.13' + mongodb-version: + - 5.0.30 + - 6.0.19 + - 7.0.16 + pymongo-version: + - 3.13.0 + - 4.2.0 + - 4.6.3 + - 4.10.1 + services: + mongodb: + image: mongo:${{ matrix.mongodb-version }} + ports: + - 27017:27017 + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: 'python -m pip install --upgrade pip + + pip install "pymongo==${{ matrix.pymongo-version }}" + + pip install -r test-requirements.txt' + - name: Run tests + run: pytest tests/ diff --git a/README.md b/README.md index f18a89433..809db4d59 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ MongoMallard MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine. -* Repository: https://github.com/elasticsales/mongoengine +* Repository: https://github.com/closeio/mongoengine * See [README_MONGOENGINE](https://github.com/elasticsales/mongoengine/blob/master/README_MONGOENGINE.rst) for MongoEngine's README. * See [DIFFERENCES](https://github.com/elasticsales/mongoengine/blob/master/DIFFERENCES.md) for differences between MongoEngine and MongoMallard. @@ -11,75 +11,19 @@ MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine. Benchmarks ---------- -Sample run on a 2.7 GHz Intel Core i5 running OS X 10.8.3 +Sample run on a Apple M3 Max running Sonoma 14.6.1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
MongoEngine 0.8.2 (ede9fcf)MongoMallard (478062c)Speedup
Doc initialization52.494us25.195us2.08x
Doc getattr1.339us0.584us2.29x
Doc setattr3.064us2.550us1.20x
Doc to mongo49.415us26.497us1.86x
Load from SON61.475us4.510us13.63x
Save to database434.389us289.972us2.29x
Load from database558.178us480.690us1.16x
Save/delete big object to database98.838ms65.789ms1.50x
Serialize big object from database31.390ms20.265ms1.55x
Load big object from database41.159ms1.400ms29.40x
+| | MongoEngine | MongoMallard | Speedup | +|---|---|---|---| +| Doc initialization | 10.113us | 3.219us | 3.14x | +| Doc getattr | 0.086us | 0.086us | 1.00x | +| Doc setattr | 0.549us | 0.211us | 2.60x | +| Doc to mongo | 5.991us | 3.181us | 1.88x | +| Load from SON | 12.094us | 0.685us | 17.66x | +| Save to database | 259.094us | 218.945us | 1.18x | +| Load from database | 260.192us | 246.576us | 1.06x | +| Save/delete big object to database | 18.510ms | 8.925ms | 2.07x | +| Serialize big object from database | 4.058ms | 2.346ms | 1.73x | +| Load big object from database | 11.205ms | 0.655ms | 17.11x | See [tests/benchmark.py](https://github.com/elasticsales/mongoengine/blob/master/tests/benchmark.py) for source code. diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 56845d0c1..8d322de0d 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -14,7 +14,7 @@ __all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + list(queryset.__all__) + signals.__all__ + list(errors.__all__)) -VERSION = (0, 8, 2) +VERSION = (0, 8, 3) MALLARD = True diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 9a5d5b1ba..e1499b50a 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -1,4 +1,5 @@ import operator +import warnings from functools import partial import pymongo @@ -13,6 +14,7 @@ from mongoengine.base.common import get_document, ALLOW_INHERITANCE from mongoengine.base.datastructures import BaseDict, BaseList from mongoengine.base.fields import ComplexBaseField +from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') @@ -193,14 +195,39 @@ def validate(self, clean=True): message = "ValidationError (%s:%s) " % (self._class_name, pk) raise ValidationError(message, errors=errors) - def to_json(self): - """Converts a document to JSON""" - return json_util.dumps(self.to_mongo()) + def to_json(self, json_options=None): + """Convert this document to JSON.""" + if json_options is None: + warnings.warn( + "No 'json_options' are specified! Falling back to " + "LEGACY_JSON_OPTIONS with uuid_representation=PYTHON_LEGACY. " + "For use with other MongoDB drivers specify the UUID " + "representation to use. This will be changed to " + "uuid_representation=UNSPECIFIED in a future release.", + DeprecationWarning, + stacklevel=2, + ) + json_options = LEGACY_JSON_OPTIONS + return json_util.dumps(self.to_mongo(), json_options=json_options) @classmethod - def from_json(cls, json_data): - """Converts json data to an unsaved document instance""" - return cls._from_son(json_util.loads(json_data)) + def from_json(cls, json_data, json_options=None): + """Converts json data to a Document instance. + + :param str json_data: The json data to load into the Document. + """ + if json_options is None: + warnings.warn( + "No 'json_options' are specified! Falling back to " + "LEGACY_JSON_OPTIONS with uuid_representation=PYTHON_LEGACY. " + "For use with other MongoDB drivers specify the UUID " + "representation to use. This will be changed to " + "uuid_representation=UNSPECIFIED in a future release.", + DeprecationWarning, + stacklevel=2, + ) + json_options = LEGACY_JSON_OPTIONS + return cls._from_son(json_util.loads(json_data, json_options=json_options)) def __expand_dynamic_values(self, name, value): """expand any dynamic values to their correct types / values""" diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 13c4719ab..696f6ecb8 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,6 +1,7 @@ import pymongo -from pymongo import (MongoClient, MongoReplicaSetClient, ReadPreference, - uri_parser) +import warnings +from pymongo import MongoClient, ReadPreference, uri_parser +from pymongo.common import _UUID_REPRESENTATIONS __all__ = [ 'DEFAULT_CONNECTION_NAME', @@ -36,6 +37,7 @@ def register_connection( slaves=None, username=None, password=None, + uuidrepresentation=None, **kwargs ): """Add a connection. @@ -84,7 +86,22 @@ def register_connection( }) if "replicaSet" in host: conn_settings['replicaSet'] = True - + if "uuidrepresentation" in uri_dict: + uuidrepresentation = uri_dict.get('uuidrepresentation') + + if uuidrepresentation is None: + warnings.warn( + "No uuidrepresentation is specified! Falling back to " + "'pythonLegacy' which is the default for pymongo 3.x. " + "For compatibility with other MongoDB drivers this should be " + "specified as 'standard' or '{java,csharp}Legacy' to work with " + "older drivers in those languages. This will be changed to " + "'unspecified' in a future release.", + DeprecationWarning, + stacklevel=3, + ) + + conn_settings['uuidrepresentation'] = uuidrepresentation or 'pythonLegacy' conn_settings.update(kwargs) _connection_settings[alias] = conn_settings @@ -129,7 +146,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings['slaves'] = slaves conn_settings.pop('read_preference', None) - connection_class = MongoClient if 'replicaSet' in conn_settings: conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) # Discard port since it can't be used on MongoReplicaSetClient @@ -137,10 +153,9 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Discard replicaSet if not base string if not isinstance(conn_settings['replicaSet'], str): conn_settings.pop('replicaSet', None) - connection_class = MongoReplicaSetClient try: - _connections[alias] = connection_class(**conn_settings) + _connections[alias] = MongoClient(**conn_settings) except Exception as e: raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) return _connections[alias] @@ -155,10 +170,6 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn = get_connection(alias) conn_settings = _connection_settings[alias] db = conn[conn_settings['name']] - # Authenticate if necessary - if conn_settings['username'] and conn_settings['password']: - db.authenticate(conn_settings['username'], - conn_settings['password']) _dbs[alias] = db return _dbs[alias] diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 2d7a8e69c..a7bbbe6d8 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -175,14 +175,14 @@ def __init__(self): def __enter__(self): """ On every with block we need to drop the profile collection. """ - self.db.set_profiling_level(0) + self.db.command({"profile": 0}) self.db.system.profile.drop() - self.db.set_profiling_level(2) + self.db.command({"profile": 2}) return self def __exit__(self, t, value, traceback): """ Reset the profiling level. """ - self.db.set_profiling_level(0) + self.db.command({"profile": 0}) def __eq__(self, value): """ == Compare querycounter. """ @@ -220,7 +220,7 @@ def __repr__(self): def _get_count(self): """ Get the number of queries. """ ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} - count = self.db.system.profile.find(ignore_query).count() - self.counter + count = self.db.system.profile.count_documents(filter=ignore_query) - self.counter self.counter += 1 return count diff --git a/mongoengine/document.py b/mongoengine/document.py index 407409795..ce326ab2d 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -13,6 +13,7 @@ from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import (set_write_concern, switch_db, switch_collection) +from mongoengine.pymongo_support import list_collection_names __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', @@ -147,7 +148,9 @@ def _get_collection(cls): max_size = cls._meta['max_size'] or 10000000 # 10MB default max_documents = cls._meta['max_documents'] - if collection_name in db.collection_names(): + if collection_name in list_collection_names( + db, include_system_collections=True + ): cls._collection = db[collection_name] # The collection already exists, check if its capped # options match the specified capped options diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c9f7bc1f1..4d138edc6 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -28,6 +28,7 @@ from .queryset import DO_NOTHING, QuerySet from .document import Document, EmbeddedDocument from .connection import get_db, DEFAULT_CONNECTION_NAME +from pymongo import ReturnDocument try: from PIL import Image, ImageOps @@ -1464,8 +1465,9 @@ def __init__(self, collection_name=None, db_alias=None, sequence_name=None, self.collection_name = collection_name or self.COLLECTION_NAME self.db_alias = db_alias or DEFAULT_CONNECTION_NAME self.sequence_name = sequence_name - self.value_decorator = (callable(value_decorator) and - value_decorator or self.VALUE_DECORATOR) + self.value_decorator = ( + value_decorator if callable(value_decorator) else self.VALUE_DECORATOR + ) return super(SequenceField, self).__init__(*args, **kwargs) def generate(self): @@ -1475,22 +1477,28 @@ def generate(self): sequence_name = self.get_sequence_name() sequence_id = "%s.%s" % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] - counter = collection.find_and_modify(query={"_id": sequence_id}, - update={"$inc": {"next": 1}}, - new=True, - upsert=True) - return self.value_decorator(counter['next']) + + counter = collection.find_one_and_update( + filter={"_id": sequence_id}, + update={"$inc": {"next": 1}}, + return_document=ReturnDocument.AFTER, + upsert=True + ) + return self.value_decorator(counter["next"]) def set_next_value(self, value): """Helper method to set the next sequence value""" sequence_name = self.get_sequence_name() sequence_id = "%s.%s" % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] - counter = collection.find_and_modify(query={"_id": sequence_id}, - update={"$set": {"next": value}}, - new=True, - upsert=True) - return self.value_decorator(counter['next']) + + counter = collection.find_one_and_update( + filter={"_id": sequence_id}, + update={"$set": {"next": value}}, + return_document=ReturnDocument.AFTER, + upsert=True + ) + return self.value_decorator(counter["next"]) def get_next_value(self): """Helper method to get the next value for previewing. diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py new file mode 100644 index 000000000..f15a72c96 --- /dev/null +++ b/mongoengine/mongodb_support.py @@ -0,0 +1,24 @@ +""" +Helper functions, constants, and types to aid with MongoDB version support +""" + +from mongoengine.connection import get_connection + +# Constant that can be used to compare the version retrieved with +# get_mongodb_version() +MONGODB_34 = (3, 4) +MONGODB_36 = (3, 6) +MONGODB_42 = (4, 2) +MONGODB_44 = (4, 4) +MONGODB_50 = (5, 0) +MONGODB_60 = (6, 0) +MONGODB_70 = (7, 0) + + +def get_mongodb_version(): + """Return the version of the default connected mongoDB (first 2 digits) + + :return: tuple(int, int) + """ + version_list = get_connection().server_info()["versionArray"][:2] # e.g: (3, 2) + return tuple(version_list) diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py new file mode 100644 index 000000000..565f09a81 --- /dev/null +++ b/mongoengine/pymongo_support.py @@ -0,0 +1,30 @@ +""" +Helper functions, constants, and types to aid with PyMongo support. +""" + +import pymongo +from bson import binary, json_util + +PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) + +# This will be changed to UuidRepresentation.UNSPECIFIED in a future +# (breaking) release. +if PYMONGO_VERSION >= (4,): + LEGACY_JSON_OPTIONS = json_util.LEGACY_JSON_OPTIONS.with_options( + uuid_representation=binary.UuidRepresentation.PYTHON_LEGACY, + ) +else: + LEGACY_JSON_OPTIONS = json_util.DEFAULT_JSON_OPTIONS + + +def list_collection_names(db, include_system_collections=False): + """Pymongo>3.7 deprecates collection_names in favour of list_collection_names""" + if PYMONGO_VERSION >= (3, 7): + collections = db.list_collection_names() + else: + collections = db.collection_names() + + if not include_system_collections: + collections = [c for c in collections if not c.startswith("system.")] + + return collections diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 4f92b0a1e..19ee97cb6 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -6,7 +6,7 @@ import warnings import pymongo -from bson import json_util +from bson import SON, json_util from bson.code import Code from pymongo.collection import ReturnDocument from pymongo.common import validate_read_preference @@ -14,8 +14,10 @@ from mongoengine import signals from mongoengine.common import _import_class +from mongoengine.connection import get_db from mongoengine.context_managers import set_read_write_concern, set_write_concern from mongoengine.errors import InvalidQueryError, NotUniqueError, OperationError +from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode @@ -35,6 +37,7 @@ RE_TYPE = type(re.compile('')) +PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) class QuerySet(object): """A set of results returned from a query. Wraps a MongoDB cursor, @@ -79,6 +82,13 @@ def __init__(self, document, collection): self._hint = -1 # Using -1 as None is a valid value for hint self._batch_size = None + # Hack - As people expect cursor[5:5] to return + # an empty result set. It's hard to do that right, though, because the + # server uses limit(0) to mean 'no limit'. So we set _empty + # in that case and check for it when iterating. We also unset + # it anytime we change _limit. Inspired by how it is done in pymongo.Cursor + self._empty = False + def __call__(self, q_obj=None, class_check=True, slave_okay=False, read_preference=None, **query): """Filter the selected documents by calling the @@ -177,6 +187,7 @@ def __getitem__(self, key): """Support skip and limit using getitem and slicing syntax. """ queryset = self.clone() + queryset._empty = False # Slice provided if isinstance(key, slice): @@ -185,6 +196,8 @@ def __getitem__(self, key): queryset._skip, queryset._limit = key.start, key.stop if key.start and key.stop: queryset._limit = key.stop - key.start + if queryset._limit == 0: + queryset._empty = True except IndexError as err: # PyMongo raises an error if key.start == key.stop, catch it, # bin it, kill it. @@ -194,6 +207,7 @@ def __getitem__(self, key): queryset.limit(0) queryset._skip = key.start queryset._limit = key.stop - start + queryset._empty = True return queryset raise err # Allow further QuerySet modifications to be performed @@ -319,6 +333,9 @@ def first(self): """Retrieve the first object matching the query. """ queryset = self.clone() + if self._none or self._empty: + return None + try: result = queryset[0] except IndexError: @@ -415,17 +432,39 @@ def count(self, with_limit_and_skip=True): :meth:`skip` that has been applied to this cursor into account when getting the count """ - if self._limit == 0: - return 0 - - if self._none: + # mimic the fact that setting .limit(0) in pymongo sets no limit + # https://www.mongodb.com/docs/manual/reference/method/cursor.limit/#zero-value + if ( + self._limit == 0 + and with_limit_and_skip is False + or self._none + or self._empty + ): return 0 if with_limit_and_skip and self._len is not None: return self._len - count = self._cursor.count(with_limit_and_skip=with_limit_and_skip) + + # TODO: evaluate utility of `estimated_document_count` + + if PYMONGO_VERSION >= (3, 7): + options = {} + + if with_limit_and_skip: + if self._limit is not None: + options["limit"] = self._limit + if self._skip is not None: + options["skip"] = self._skip + if self._hint not in (-1, None): + options["hint"] = self._hint + + count = self._cursor.collection.count_documents(filter=self._query, **options) + else: + count = self._cursor.count(with_limit_and_skip=with_limit_and_skip) + if with_limit_and_skip: self._len = count + return count def delete(self, write_concern=None, _from_doc_delete=False): @@ -505,6 +544,8 @@ def update( if write_concern is None: write_concern = {} + if self._none or self._empty: + return 0 queryset = self.clone() query = queryset._query @@ -587,6 +628,9 @@ def modify(self, upsert=False, full_response=False, remove=False, new=False, **u raise OperationError( "No update parameters, must either update or remove") + if self._none or self._empty: + return None + queryset = self.clone() query = queryset._query if not remove: @@ -749,13 +793,16 @@ def limit(self, n): :param n: the maximum number of objects to return """ + # chesterton's fence: + if n == 0: n = 1 + queryset = self.clone() - if n == 0: - queryset._cursor.limit(1) - else: - queryset._cursor.limit(n) queryset._limit = n - # Return self to allow chaining + queryset._empty = False + + if queryset._cursor_obj: + queryset._cursor_obj.limit(n) + return queryset def skip(self, n): @@ -765,8 +812,11 @@ def skip(self, n): :param n: the number of objects to skip before returning results """ queryset = self.clone() - queryset._cursor.skip(n) queryset._skip = n + + if queryset._cursor_obj: + queryset._cursor_obj.skip(queryset._skip) + return queryset def hint(self, index=None): @@ -783,14 +833,22 @@ def hint(self, index=None): .. versionadded:: 0.5 """ queryset = self.clone() - queryset._cursor.hint(index) queryset._hint = index + + # If a cursor object has already been created, apply the hint to it. + if queryset._cursor_obj: + queryset._cursor_obj.hint(queryset._hint) + return queryset def batch_size(self, size): queryset = self.clone() - queryset._cursor.batch_size(size) queryset._batch_size = size + + # If a cursor object has already been created, apply the batch size to it. + if queryset._cursor_obj: + queryset._cursor_obj.batch_size(queryset._batch_size) + return queryset def distinct(self, field, dereference=True): @@ -1019,9 +1077,20 @@ def as_pymongo(self, coerce_types=False): # JSON Helpers - def to_json(self): + def to_json(self, *args, **kwargs): """Converts a queryset to JSON""" - return json_util.dumps(self.as_pymongo()) + if "json_options" not in kwargs: + warnings.warn( + "No 'json_options' are specified! Falling back to " + "LEGACY_JSON_OPTIONS with uuid_representation=PYTHON_LEGACY. " + "For use with other MongoDB drivers specify the UUID " + "representation to use. This will be changed to " + "uuid_representation=UNSPECIFIED in a future release.", + DeprecationWarning, + stacklevel=2, + ) + kwargs["json_options"] = LEGACY_JSON_OPTIONS + return json_util.dumps(self.as_pymongo(), *args, **kwargs) def from_json(self, json_data): """Converts json data to unsaved objects""" @@ -1080,14 +1149,14 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, if isinstance(map_f, Code): map_f_scope = map_f.scope map_f = str(map_f) - map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope or None) reduce_f_scope = {} if isinstance(reduce_f, Code): reduce_f_scope = reduce_f.scope reduce_f = str(reduce_f) reduce_f_code = queryset._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) + reduce_f = Code(reduce_f_code, reduce_f_scope or None) mr_args = {'query': queryset._query} @@ -1097,7 +1166,7 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, finalize_f_scope = finalize_f.scope finalize_f = str(finalize_f) finalize_f_code = queryset._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) + finalize_f = Code(finalize_f_code, finalize_f_scope or None) mr_args['finalize'] = finalize_f if scope: @@ -1107,16 +1176,57 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, mr_args['limit'] = limit if output == 'inline' and not queryset._ordering: - map_reduce_function = 'inline_map_reduce' + inline = True + mr_args['out'] = {'inline': 1} else: - map_reduce_function = 'map_reduce' - mr_args['out'] = output + inline = False + if isinstance(output, str): + mr_args['out'] = output + + elif isinstance(output, dict): + ordered_output = [] + + for part in ('replace', 'merge', 'reduce'): + value = output.get(part) + if value: + ordered_output.append((part, value)) + break + + else: + raise OperationError('actionData not specified for output') - results = getattr(queryset._collection, map_reduce_function)( - map_f, reduce_f, **mr_args) + db_alias = output.get('db_alias') + remaing_args = ['db', 'sharded', 'nonAtomic'] - if map_reduce_function == 'map_reduce': - results = results.find() + if db_alias: + ordered_output.append(('db', get_db(db_alias).name)) + del remaing_args[0] + + for part in remaing_args: + value = output.get(part) + if value: + ordered_output.append((part, value)) + + mr_args['out'] = SON(ordered_output) + + db = queryset._document._get_db() + result = db.command( + { + 'mapReduce': queryset._document._get_collection_name(), + 'map': map_f, + 'reduce': reduce_f, + **mr_args, + } + ) + + if inline: + results = result['results'] + else: + if isinstance(result['result'], str): + results = db[result['result']].find() + else: + info = result['result'] + results = db.client[info['db']][info['collection']].find() if queryset._ordering: results = results.sort(queryset._ordering) @@ -1167,7 +1277,7 @@ def exec_js(self, code, *fields, **options): code = Code(code, scope=scope) db = queryset._document._get_db() - return db.eval(code, *fields) + return db.command("eval", code, args=fields).get("retval") def where(self, where_clause): """Filter ``QuerySet`` results with a ``$where`` clause (a Javascript diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 6e8b2a12a..e9951bf46 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -230,8 +230,14 @@ def update(_doc_cls=None, **update): value = {key: value} elif op == 'addToSet' and isinstance(value, list): value = {key: {"$each": value}} + elif op == "pushAll": + op = 'push' # convert to non-deprecated keyword + if not isinstance(value, (set, tuple, list)): + value = [value] + value = {key: {'$each': value}} else: value = {key: value} + key = '$' + op if key not in mongo_update: diff --git a/requirements.txt b/requirements.txt index ddff5fb23..71bf46d3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -pymongo==3.12.3 +pymongo>=3.13 diff --git a/setup.cfg b/setup.cfg index c3df01801..5cc2b4752 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,10 +1,2 @@ -[nosetests] -verbosity = 3 -detailed-errors = 1 -#with-coverage = 1 -#cover-erase = 1 -#cover-html = 1 -#cover-html-dir = ../htmlcov -#cover-package = mongoengine -where = tests -#tests = document/__init__.py +[tool:pytest] +testpaths=tests diff --git a/setup.py b/setup.py index 0391ac00b..4a7086011 100644 --- a/setup.py +++ b/setup.py @@ -38,10 +38,12 @@ def get_version(version_tuple): 'Operating System :: OS Independent', 'Programming Language :: Python', "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.1", - "Programming Language :: Python :: 3.2", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", 'Topic :: Database', 'Topic :: Software Development :: Libraries :: Python Modules', @@ -65,7 +67,7 @@ def get_version(version_tuple): long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, - install_requires=['pymongo>=3.0,<3.14'], - test_suite='nose.collector', + install_requires=['pymongo>=3.13.0'], + tests_require=['pytest'], **extra_opts ) diff --git a/test-requirements.txt b/test-requirements.txt index 92ef961a4..aa732d848 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,4 @@ -r requirements.txt pytest -nose coverage blinker diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index c78703fce..ddba1d9ae 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -6,7 +6,7 @@ from mongoengine import * from mongoengine.queryset import NULLIFY, PULL -from mongoengine.connection import get_db +from mongoengine.pymongo_support import list_collection_names __all__ = ("ClassMethodsTest", ) @@ -28,9 +28,7 @@ class Person(Document): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_definition(self): @@ -67,10 +65,10 @@ def test_drop_collection(self): """ collection_name = 'person' self.Person(name='Test').save() - self.assertTrue(collection_name in self.db.collection_names()) + self.assertTrue(collection_name in list_collection_names(self.db)) self.Person.drop_collection() - self.assertFalse(collection_name in self.db.collection_names()) + self.assertFalse(collection_name in list_collection_names(self.db)) def test_register_delete_rule(self): """Ensure that register delete rule adds a delete rule to the document @@ -219,7 +217,7 @@ class Person(Document): meta = {'collection': collection_name} Person(name="Test User").save() - self.assertTrue(collection_name in self.db.collection_names()) + self.assertTrue(collection_name in list_collection_names(self.db)) user_obj = self.db[collection_name].find_one() self.assertEqual(user_obj['name'], "Test User") @@ -228,7 +226,7 @@ class Person(Document): self.assertEqual(user_obj.name, "Test User") Person.drop_collection() - self.assertFalse(collection_name in self.db.collection_names()) + self.assertFalse(collection_name in list_collection_names(self.db)) def test_collection_name_and_primary(self): """Ensure that a collection with a specified name may be used. diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 355717fb7..73d24c9ed 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -5,6 +5,7 @@ from mongoengine import * from mongoengine.connection import get_db +from mongoengine.pymongo_support import list_collection_names __all__ = ("DeltaTest",) @@ -26,9 +27,7 @@ class Person(Document): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_delta(self): diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 23e29078e..05a96d607 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -5,13 +5,15 @@ import os import pymongo +import pytest -from nose.plugins.skip import SkipTest from datetime import datetime from mongoengine import * from mongoengine.connection import get_db, get_connection from pymongo.errors import OperationFailure +from mongoengine.pymongo_support import list_collection_names +from mongoengine.mongodb_support import get_mongodb_version, MONGODB_42 __all__ = ("IndexesTest", ) @@ -33,9 +35,7 @@ class Person(Document): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_indexes_document(self): @@ -579,15 +579,6 @@ class Log(Document): } Log.drop_collection() - - if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3: - raise SkipTest('pymongo needs to be 2.3 or higher for this test') - - connection = get_connection() - version_array = connection.server_info()['versionArray'] - if version_array[0] < 2 and version_array[1] < 2: - raise SkipTest('MongoDB needs to be 2.2 or higher for this test') - Log.ensure_indexes() info = Log.objects._collection.index_information() self.assertEqual(3600, @@ -749,6 +740,47 @@ class TestChildDoc(TestDoc): } }) + def test_covered_index(self): + """Ensure that covered indexes can be used""" + + class Test(Document): + a = IntField() + b = IntField() + + meta = {"indexes": ["a"], "allow_inheritance": False, "auto_create_index": True} + + Test.drop_collection() + + obj = Test(a=1) + obj.save() + + # Need to be explicit about covered indexes as mongoDB doesn't know if + # the documents returned might have more keys in that here. + query_plan = Test.objects(id=obj.id).exclude("a").explain() + assert ( + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK" + ) + + query_plan = Test.objects(id=obj.id).only("id").explain() + assert ( + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK" + ) + + mongo_db = get_mongodb_version() + query_plan = Test.objects(a=1).only("a").exclude("id").explain() + assert ( + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN" + ) + + PROJECTION_STR = "PROJECTION" if mongo_db < MONGODB_42 else "PROJECTION_COVERED" + assert query_plan["queryPlanner"]["winningPlan"]["stage"] == PROJECTION_STR + + query_plan = Test.objects(a=1).explain() + assert ( + query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN" + ) + + assert query_plan.get("queryPlanner").get("winningPlan").get("stage") == "FETCH" if __name__ == '__main__': unittest.main() diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py index 8e8508c6c..f7dfb2851 100644 --- a/tests/document/test_inheritance.py +++ b/tests/document/test_inheritance.py @@ -12,6 +12,7 @@ from mongoengine.connection import get_db from mongoengine.fields import (BooleanField, GenericReferenceField, IntField, StringField) +from mongoengine.pymongo_support import list_collection_names __all__ = ('InheritanceTest', ) @@ -23,9 +24,7 @@ def setUp(self): self.db = get_db() def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_superclasses(self): diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index f47c01262..111c2ed10 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import sys + +from mongoengine.mongodb_support import MONGODB_34, get_mongodb_version sys.path[0:0] = [""] import bson @@ -20,6 +22,7 @@ from mongoengine.base import get_document from mongoengine.context_managers import switch_db, query_counter from mongoengine import signals +from mongoengine.pymongo_support import list_collection_names TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), '../fields/mongoengine.png') @@ -44,9 +47,7 @@ class Person(Document): self.Person = Person def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue + for collection in list_collection_names(self.db): self.db.drop_collection(collection) def test_capped_collection(self): @@ -382,10 +383,9 @@ class Animal(Document): query_op = q.db.system.profile.find_one({ 'ns': 'mongoenginetest.animal' }) - self.assertEqual( - set(query_op['query']['filter'].keys()), - set(['_id', 'superphylum']) - ) + # MongoDB 5.0+ uses 'command' instead of 'query' + cmd_query = query_op.get('command') or query_op.get('query') + assert set(cmd_query['filter'].keys()) == {'_id', 'superphylum'} Animal.drop_collection() @@ -402,13 +402,18 @@ class Animal(Document): doc = Animal(is_mammal=True, name='Dog') doc.save() + mongo_db = get_mongodb_version() + with query_counter() as q: doc.name = 'Cat' doc.save() query_op = q.db.system.profile.find({ 'ns': 'mongoenginetest.animal' })[0] - self.assertEqual(query_op['op'], 'update') - self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal'])) + # MongoDB 5.0+ uses 'command' instead of 'query' + if mongo_db <= MONGODB_34: + assert set(query_op['query'].keys()) == {'_id', 'is_mammal'}, str(cmd_query) + else: + assert set(query_op['command']['q'].keys()) == {'_id', 'is_mammal'}, str(cmd_query) Animal.drop_collection() @@ -1975,16 +1980,16 @@ class Person(Document): person = Person(name="name", age=10, job=job) from pymongo.collection import Collection - orig_update = Collection.update + orig_update_one = Collection.update_one try: - def fake_update(*args, **kwargs): + def fake_update_one(*args, **kwargs): self.fail("Unexpected update for %s" % args[0].name) - return orig_update(*args, **kwargs) + return orig_update_one(*args, **kwargs) - Collection.update = fake_update + Collection.update_one = fake_update_one person.save() finally: - Collection.update = orig_update + Collection.update_one = orig_update_one def test_db_alias_tests(self): """ DB Alias tests """ diff --git a/tests/document/test_json_serialisation.py b/tests/document/test_json_serialisation.py index 1f2d5c888..1e8a4a733 100644 --- a/tests/document/test_json_serialisation.py +++ b/tests/document/test_json_serialisation.py @@ -4,7 +4,6 @@ import unittest import uuid -from nose.plugins.skip import SkipTest from datetime import datetime from bson import ObjectId @@ -35,9 +34,6 @@ class Doc(Document): def test_json_complex(self): - if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: - raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") - class EmbeddedDoc(EmbeddedDocument): pass diff --git a/tests/fields/test_file.py b/tests/fields/test_file.py index 30dbf15b3..460392eda 100644 --- a/tests/fields/test_file.py +++ b/tests/fields/test_file.py @@ -9,16 +9,11 @@ import gridfs -from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db from mongoengine.python_support import PY3, b, StringIO +from tests.utils import requires_pil -try: - from PIL import Image - HAS_PIL = True -except ImportError: - HAS_PIL = False TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png') @@ -261,10 +256,8 @@ class TestFile(Document): test_file = TestFile() self.assertFalse(test_file.the_file in [{"test": 1}]) + @requires_pil def test_image_field(self): - if not HAS_PIL: - raise SkipTest('PIL not installed') - class TestImage(Document): image = ImageField() @@ -295,10 +288,8 @@ class TestImage(Document): t.image.delete() + @requires_pil def test_image_field_reassigning(self): - if not HAS_PIL: - raise SkipTest('PIL not installed') - class TestFile(Document): the_file = ImageField() TestFile.drop_collection() @@ -311,10 +302,8 @@ class TestFile(Document): test_file.save() self.assertEqual(test_file.the_file.size, (45, 101)) + @requires_pil def test_image_field_resize(self): - if not HAS_PIL: - raise SkipTest('PIL not installed') - class TestImage(Document): image = ImageField(size=(185, 37)) @@ -334,10 +323,8 @@ class TestImage(Document): t.image.delete() + @requires_pil def test_image_field_resize_force(self): - if not HAS_PIL: - raise SkipTest('PIL not installed') - class TestImage(Document): image = ImageField(size=(185, 37, True)) @@ -357,10 +344,8 @@ class TestImage(Document): t.image.delete() + @requires_pil def test_image_field_thumbnail(self): - if not HAS_PIL: - raise SkipTest('PIL not installed') - class TestImage(Document): image = ImageField(thumbnail_size=(92, 18)) @@ -433,11 +418,9 @@ class TestFile(Document): self.assertEqual(putfile, copy.copy(putfile)) self.assertEqual(putfile, copy.deepcopy(putfile)) + @requires_pil def test_get_image_by_grid_id(self): - if not HAS_PIL: - raise SkipTest('PIL not installed') - class TestImage(Document): image1 = ImageField() diff --git a/tests/queryset/test_modify.py b/tests/queryset/test_modify.py new file mode 100644 index 000000000..b726c7c8a --- /dev/null +++ b/tests/queryset/test_modify.py @@ -0,0 +1,146 @@ +import unittest +import pytest + +from mongoengine import ( + Document, + IntField, + ListField, + StringField, + connect, +) + + +class Doc(Document): + id = IntField(primary_key=True) + value = IntField() + + +class TestFindAndModify(unittest.TestCase): + def setUp(self): + connect(db="mongoenginetest") + Doc.drop_collection() + + def _assert_db_equal(self, docs): + assert list(Doc._collection.find().sort("id")) == docs + + def test_modify(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(set__value=-1) + assert old_doc.to_json() == doc.to_json() + self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_new(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + new_doc = Doc.objects(id=1).modify(set__value=-1, new=True) + doc.value = -1 + assert new_doc.to_json() == doc.to_json() + self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_not_existing(self): + Doc(id=0, value=0).save() + assert Doc.objects(id=1).modify(set__value=-1) is None + self._assert_db_equal([{"_id": 0, "value": 0}]) + + def test_modify_with_upsert(self): + Doc(id=0, value=0).save() + old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True) + assert old_doc is None + self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) + + def test_modify_with_upsert_existing(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True) + assert old_doc.to_json() == doc.to_json() + self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_upsert_with_new(self): + Doc(id=0, value=0).save() + new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1) + assert new_doc.to_mongo() == {"_id": 1, "value": 1} + self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}]) + + def test_modify_with_remove(self): + Doc(id=0, value=0).save() + doc = Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).modify(remove=True) + assert old_doc.to_json() == doc.to_json() + self._assert_db_equal([{"_id": 0, "value": 0}]) + + def test_find_and_modify_with_remove_not_existing(self): + Doc(id=0, value=0).save() + assert Doc.objects(id=1).modify(remove=True) is None + self._assert_db_equal([{"_id": 0, "value": 0}]) + + def test_modify_with_order_by(self): + Doc(id=0, value=3).save() + Doc(id=1, value=2).save() + Doc(id=2, value=1).save() + doc = Doc(id=3, value=0).save() + + old_doc = Doc.objects().order_by("-id").modify(set__value=-1) + assert old_doc.to_json() == doc.to_json() + self._assert_db_equal( + [ + {"_id": 0, "value": 3}, + {"_id": 1, "value": 2}, + {"_id": 2, "value": 1}, + {"_id": 3, "value": -1}, + ] + ) + + def test_modify_with_fields(self): + Doc(id=0, value=0).save() + Doc(id=1, value=1).save() + + old_doc = Doc.objects(id=1).only("id").modify(set__value=-1) + assert old_doc.to_mongo() == {"_id": 1} + self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) + + def test_modify_with_push(self): + class BlogPost(Document): + tags = ListField(StringField()) + + BlogPost.drop_collection() + + blog = BlogPost.objects.create() + + # Push a new tag via modify with new=False (default). + BlogPost(id=blog.id).modify(push__tags="code") + assert blog.tags == [] + blog.reload() + assert blog.tags == ["code"] + + # Push a new tag via modify with new=True. + blog = BlogPost.objects(id=blog.id).modify(push__tags="java", new=True) + assert blog.tags == ["code", "java"] + + + @pytest.mark.skip("op__x__n not supported") + def test_modify_with_push_and_index(self): + # This continues to be unsupported with mongomallard. + class BlogPost(Document): + tags = ListField(StringField()) + + BlogPost.drop_collection() + + blog = BlogPost.objects.create() + + # Push a new tag with a positional argument. + blog = BlogPost.objects(id=blog.id).modify(push__tags__0="python", new=True) + assert blog.tags == ["python", "code", "java"] + + # Push multiple new tags with a positional argument. + blog = BlogPost.objects(id=blog.id).modify( + push__tags__1=["go", "rust"], new=True + ) + assert blog.tags == ["python", "go", "rust", "code", "java"] + +if __name__ == "__main__": + unittest.main() diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 922ec7f3c..aa8727940 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -1,14 +1,13 @@ import sys - sys.path[0:0] = [""] +import pytest import unittest import uuid from datetime import datetime, timedelta import pymongo -from bson import ObjectId -from nose.plugins.skip import SkipTest +from bson import DBRef, ObjectId from pymongo.errors import ConfigurationError from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -21,6 +20,12 @@ from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) +from tests.utils import ( + requires_mongodb_gte_42, + requires_mongodb_gte_44, + requires_mongodb_lt_42, +) + __all__ = ("QuerySetTest",) @@ -140,6 +145,165 @@ def test_find(self): self.assertEqual("[, ]", "%s" % self.Person.objects[1:3]) self.assertEqual("[, ]", "%s" % self.Person.objects[51:53]) + def test_slicing_sets_empty_limit_skip(self): + self.Person.objects.insert( + [self.Person(name=f"User {i}", age=i) for i in range(5)], + load_bulk=False, + ) + + self.Person.objects.create(name="User B", age=30) + self.Person.objects.create(name="User C", age=40) + + qs = self.Person.objects()[1:2] + assert (qs._empty, qs._skip, qs._limit) == (False, 1, 1) + assert len(list(qs)) == 1 + + # Test edge case of [1:1] which should return nothing + # and require a hack so that it doesn't clash with limit(0) + qs = self.Person.objects()[1:1] + assert (qs._empty, qs._skip, qs._limit) == (True, 1, 0) + assert len(list(qs)) == 0 + + qs2 = qs[1:5] # Make sure that further slicing resets _empty + assert (qs2._empty, qs2._skip, qs2._limit) == (False, 1, 4) + assert len(list(qs2)) == 4 + + def test_limit_0_returns_all_documents(self): + self.Person.objects.create(name="User A", age=20) + self.Person.objects.create(name="User B", age=30) + + n_docs = self.Person.objects().count() + + persons = list(self.Person.objects().limit(0)) + assert len(persons) == 1 + assert n_docs == 2 + + def test_limit(self): + """Ensure that QuerySet.limit works as expected.""" + user_a = self.Person.objects.create(name="User A", age=20) + _ = self.Person.objects.create(name="User B", age=30) + + # Test limit on a new queryset + people = list(self.Person.objects.limit(1)) + assert len(people) == 1 + assert people[0] == user_a + + # Test limit on an existing queryset + people = self.Person.objects + assert len(people) == 2 + people2 = people.limit(1) + assert len(people) == 2 + assert len(people2) == 1 + assert people2[0] == user_a + + # Test limit with 0 as parameter + people = self.Person.objects.limit(0) + assert people.count(with_limit_and_skip=True) == 1 + assert len(people) == 1 + + # Test chaining of only after limit + person = self.Person.objects().limit(1).only("name").first() + assert person == user_a + assert person.name == "User A" + assert person.age is None + + def test_skip(self): + """Ensure that QuerySet.skip works as expected.""" + user_a = self.Person.objects.create(name="User A", age=20) + user_b = self.Person.objects.create(name="User B", age=30) + + # Test skip on a new queryset + people = list(self.Person.objects.skip(0)) + assert len(people) == 2 + assert people[0] == user_a + assert people[1] == user_b + + people = list(self.Person.objects.skip(1)) + assert len(people) == 1 + assert people[0] == user_b + + # Test skip on an existing queryset + people = self.Person.objects + assert len(people) == 2 + people2 = people.skip(1) + assert len(people) == 2 + assert len(people2) == 1 + assert people2[0] == user_b + + # Test chaining of only after skip + person = self.Person.objects().skip(1).only("name").first() + assert person == user_b + assert person.name == "User B" + assert person.age is None + + def test___getitem___invalid_index(self): + """Ensure slicing a queryset works as expected.""" + with pytest.raises(AttributeError): + self.Person.objects()["a"] + + def test_slice(self): + """Ensure slicing a queryset works as expected.""" + user_a = self.Person.objects.create(name="User A", age=20) + user_b = self.Person.objects.create(name="User B", age=30) + user_c = self.Person.objects.create(name="User C", age=40) + + # Test slice limit + people = list(self.Person.objects[:2]) + assert len(people) == 2 + assert people[0] == user_a + assert people[1] == user_b + + # Test slice skip + people = list(self.Person.objects[1:]) + assert len(people) == 2 + assert people[0] == user_b + assert people[1] == user_c + + # Test slice limit and skip + people = list(self.Person.objects[1:2]) + assert len(people) == 1 + assert people[0] == user_b + + # Test slice limit and skip on an existing queryset + people = self.Person.objects + assert len(people) == 3 + people2 = people[1:2] + assert len(people2) == 1 + assert people2[0] == user_b + + # Test slice limit and skip cursor reset + qs = self.Person.objects[1:2] + # fetch then delete the cursor + qs._cursor + qs._cursor_obj = None + people = list(qs) + assert len(people) == 1 + assert people[0].name == "User B" + + # Test empty slice + people = list(self.Person.objects[1:1]) + assert len(people) == 0 + + # Test slice out of range + people = list(self.Person.objects[80000:80001]) + assert len(people) == 0 + + # Test larger slice __repr__ + self.Person.objects.delete() + for i in range(55): + self.Person(name="A%s" % i, age=i).save() + + assert self.Person.objects.count() == 55 + assert "Person object" == "%s" % self.Person.objects[0] + assert ( + "[, ]" + == "%s" % self.Person.objects[1:3] + ) + assert ( + "[, ]" + == "%s" % self.Person.objects[51:53] + ) + def test_find_one(self): """Ensure that a query using find_one returns a valid result. """ @@ -255,6 +419,25 @@ class A(Document): self.assertEqual(list(A.objects.none().all()), []) self.assertEqual(A.objects.none().count(), 0) + # validate collection not empty + assert A.objects.count() == 1 + + # update operations + assert A.objects.none().update(s="1") == 0 + assert A.objects.none().update_one(s="1") == 0 + assert A.objects.none().modify(s="1") is None + + # validate noting change by update operations + assert A.objects(s="1").count() == 0 + + # fetch queries + assert A.objects.none().first() is None + assert list(A.objects.none()) == [] + assert list(A.objects.none().all()) == [] + assert list(A.objects.none().limit(1)) == [] + assert list(A.objects.none().skip(1)) == [] + assert list(A.objects.none()[:5]) == [] + def test_chaining(self): class A(Document): s = StringField() @@ -1049,6 +1232,7 @@ class BlogPost(Document): BlogPost.drop_collection() + @requires_mongodb_lt_42 def test_exec_js_query(self): """Ensure that queries are properly formed for use in exec_js. """ @@ -1086,6 +1270,7 @@ class BlogPost(Document): BlogPost.drop_collection() + @requires_mongodb_lt_42 def test_exec_js_field_sub(self): """Ensure that field substitutions occur properly in exec_js functions. """ @@ -1806,9 +1991,8 @@ class BlogPost(Document): results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults") results = list(results) - self.assertEqual(results[0].object, post1) - self.assertEqual(results[1].object, post2) - self.assertEqual(results[2].object, post3) + self.assertEqual({ result.object for result in results }, + { post1, post2, post3 }) BlogPost.drop_collection() @@ -2069,6 +2253,7 @@ class Person(Document): freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True) self.assertEqual(freq, {'CRB': 0.5, None: 0.5}) + @requires_mongodb_lt_42 def test_item_frequencies_with_null_embedded(self): class Data(EmbeddedDocument): name = StringField() @@ -2097,6 +2282,7 @@ class Person(Document): ot = Person.objects.item_frequencies('extra.tag', map_reduce=True) self.assertEqual(ot, {None: 1.0, 'friend': 1.0}) + @requires_mongodb_lt_42 def test_item_frequencies_with_0_values(self): class Test(Document): val = IntField() @@ -2111,6 +2297,7 @@ class Test(Document): ot = Test.objects.item_frequencies('val', map_reduce=False) self.assertEqual(ot, {0: 1}) + @requires_mongodb_lt_42 def test_item_frequencies_with_False_values(self): class Test(Document): val = BooleanField() @@ -2125,6 +2312,7 @@ class Test(Document): ot = Test.objects.item_frequencies('val', map_reduce=False) self.assertEqual(ot, {False: 1}) + @requires_mongodb_lt_42 def test_item_frequencies_normalize(self): class Test(Document): val = IntField() @@ -3247,8 +3435,6 @@ class Doc(Document): self.assertEqual(doc_objects, Doc.objects.from_json(json_data)) def test_json_complex(self): - if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3: - raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs") class EmbeddedDoc(EmbeddedDocument): pass diff --git a/tests/test_connection.py b/tests/test_connection.py index 11bf02084..056c506db 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,7 +4,8 @@ import datetime import unittest -import pymongo +import pymongo.mongo_client +import pymongo.database from bson.tz_util import utc import mongoengine.connection @@ -42,9 +43,8 @@ def test_connect_uri(self): c.admin.system.users.delete_many({}) c.mongoenginetest.system.users.delete_many({}) - c.admin.add_user("admin", "password") - c.admin.authenticate("admin", "password") - c.mongoenginetest.add_user("username", "password") + c.admin.command('createUser', 'admin', pwd='password', roles=['root']) + c.mongoenginetest.command('createUser', 'username', pwd='password', roles=['read']) self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') @@ -54,8 +54,11 @@ def test_connect_uri(self): self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) db = get_db() - self.assertTrue(isinstance(db, pymongo.database.Database)) - self.assertEqual(db.name, 'mongoenginetest') + assert isinstance(db, pymongo.database.Database) + assert db.name == "mongoenginetest" + + c.admin.system.users.delete_many({}) + c.mongoenginetest.system.users.delete_many({}) def test_register_connection(self): """Ensure that connections with different aliases may be registered. diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 8852e5f61..12bd7a04c 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -196,7 +196,7 @@ def test_query_counter(self): self.assertEqual(0, q) for i in range(1, 51): - db.test.find({}).count() + db.test.count_documents({}) self.assertEqual(50, q) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..c76b9f980 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,88 @@ +import functools +import operator + +import pymongo +import pytest + +from mongoengine.mongodb_support import get_mongodb_version + +PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) + + +def get_as_pymongo(doc): + """Fetch the pymongo version of a certain Document""" + return doc.__class__.objects.as_pymongo().get(id=doc.id) + + +def requires_mongodb_lt_42(func): + return _decorated_with_ver_requirement(func, (4, 2), oper=operator.lt) + + +def requires_mongodb_gte_40(func): + return _decorated_with_ver_requirement(func, (4, 0), oper=operator.ge) + + +def requires_mongodb_gte_42(func): + return _decorated_with_ver_requirement(func, (4, 2), oper=operator.ge) + + +def requires_mongodb_gte_44(func): + return _decorated_with_ver_requirement(func, (4, 4), oper=operator.ge) + + +def requires_mongodb_gte_50(func): + return _decorated_with_ver_requirement(func, (5, 0), oper=operator.ge) + + +def requires_mongodb_gte_60(func): + return _decorated_with_ver_requirement(func, (6, 0), oper=operator.ge) + + +def requires_mongodb_gte_70(func): + return _decorated_with_ver_requirement(func, (7, 0), oper=operator.ge) + +try: + from PIL import Image as _ + HAS_PIL = True +except ImportError: + HAS_PIL = False + +def requires_pil(func): + @functools.wraps(func) + def _inner(*args, **kwargs): + if HAS_PIL: + return func(*args, **kwargs) + else: + pytest.skip("PIL not installed") + +def _decorated_with_ver_requirement(func, mongo_version_req, oper): + """Return a MongoDB version requirement decorator. + + The resulting decorator will skip the test if the current + MongoDB version doesn't match the provided version/operator. + + For example, if you define a decorator like so: + + def requires_mongodb_gte_36(func): + return _decorated_with_ver_requirement( + func, (3.6), oper=operator.ge + ) + + Then tests decorated with @requires_mongodb_gte_36 will be skipped if + ran against MongoDB < v3.6. + + :param mongo_version_req: The mongodb version requirement (tuple(int, int)) + :param oper: The operator to apply (e.g. operator.ge) + """ + + @functools.wraps(func) + def _inner(*args, **kwargs): + mongodb_v = get_mongodb_version() + if oper(mongodb_v, mongo_version_req): + return func(*args, **kwargs) + else: + pretty_version = ".".join(str(n) for n in mongo_version_req) + pytest.skip(f"Needs MongoDB {oper.__name__} v{pretty_version}") + + return _inner +