diff --git a/.circleci/config.yml b/.circleci/config.yml
deleted file mode 100644
index 87fc76bd5..000000000
--- a/.circleci/config.yml
+++ /dev/null
@@ -1,37 +0,0 @@
-version: 2
-
-workflows:
- version: 2
- workflow:
- jobs:
- - test-3.8
- - test-3.9
- - test-3.10
-
-defaults: &defaults
- working_directory: ~/code
- steps:
- - checkout
- - run:
- name: Install dependencies
- command: pip install --user -r test-requirements.txt
- - run:
- name: Test
- command: pytest tests/
-
-jobs:
- test-3.8:
- <<: *defaults
- docker:
- - image: circleci/python:3.8
- - image: mongo:3.2.19
- test-3.9:
- <<: *defaults
- docker:
- - image: circleci/python:3.9
- - image: mongo:3.2.19
- test-3.10:
- <<: *defaults
- docker:
- - image: circleci/python:3.10
- - image: mongo:3.2.19
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 000000000..314a6dbd8
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,46 @@
+name: Test
+'on':
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+ workflow_dispatch: {}
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version:
+ - '3.11'
+ - '3.12'
+ - '3.13'
+ mongodb-version:
+ - 5.0.30
+ - 6.0.19
+ - 7.0.16
+ pymongo-version:
+ - 3.13.0
+ - 4.2.0
+ - 4.6.3
+ - 4.10.1
+ services:
+ mongodb:
+ image: mongo:${{ matrix.mongodb-version }}
+ ports:
+ - 27017:27017
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: 'python -m pip install --upgrade pip
+
+ pip install "pymongo==${{ matrix.pymongo-version }}"
+
+ pip install -r test-requirements.txt'
+ - name: Run tests
+ run: pytest tests/
diff --git a/README.md b/README.md
index f18a89433..809db4d59 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@ MongoMallard
MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine.
-* Repository: https://github.com/elasticsales/mongoengine
+* Repository: https://github.com/closeio/mongoengine
* See [README_MONGOENGINE](https://github.com/elasticsales/mongoengine/blob/master/README_MONGOENGINE.rst) for MongoEngine's README.
* See [DIFFERENCES](https://github.com/elasticsales/mongoengine/blob/master/DIFFERENCES.md) for differences between MongoEngine and MongoMallard.
@@ -11,75 +11,19 @@ MongoMallard is a fast ORM-like layer on top of PyMongo, based on MongoEngine.
Benchmarks
----------
-Sample run on a 2.7 GHz Intel Core i5 running OS X 10.8.3
+Sample run on a Apple M3 Max running Sonoma 14.6.1
-
-
- |
- MongoEngine 0.8.2 (ede9fcf) |
- MongoMallard (478062c) |
- Speedup |
-
-
- | Doc initialization |
- 52.494us |
- 25.195us |
- 2.08x |
-
-
- | Doc getattr |
- 1.339us |
- 0.584us |
- 2.29x |
-
-
- | Doc setattr |
- 3.064us |
- 2.550us |
- 1.20x |
-
-
- | Doc to mongo |
- 49.415us |
- 26.497us |
- 1.86x |
-
-
- | Load from SON |
- 61.475us |
- 4.510us |
- 13.63x |
-
-
- | Save to database |
- 434.389us |
- 289.972us |
- 2.29x |
-
-
- | Load from database |
- 558.178us |
- 480.690us |
- 1.16x |
-
-
- | Save/delete big object to database |
- 98.838ms |
- 65.789ms |
- 1.50x |
-
-
- | Serialize big object from database |
- 31.390ms |
- 20.265ms |
- 1.55x |
-
-
- | Load big object from database |
- 41.159ms |
- 1.400ms |
- 29.40x |
-
-
+| | MongoEngine | MongoMallard | Speedup |
+|---|---|---|---|
+| Doc initialization | 10.113us | 3.219us | 3.14x |
+| Doc getattr | 0.086us | 0.086us | 1.00x |
+| Doc setattr | 0.549us | 0.211us | 2.60x |
+| Doc to mongo | 5.991us | 3.181us | 1.88x |
+| Load from SON | 12.094us | 0.685us | 17.66x |
+| Save to database | 259.094us | 218.945us | 1.18x |
+| Load from database | 260.192us | 246.576us | 1.06x |
+| Save/delete big object to database | 18.510ms | 8.925ms | 2.07x |
+| Serialize big object from database | 4.058ms | 2.346ms | 1.73x |
+| Load big object from database | 11.205ms | 0.655ms | 17.11x |
See [tests/benchmark.py](https://github.com/elasticsales/mongoengine/blob/master/tests/benchmark.py) for source code.
diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py
index 56845d0c1..8d322de0d 100644
--- a/mongoengine/__init__.py
+++ b/mongoengine/__init__.py
@@ -14,7 +14,7 @@
__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ +
list(queryset.__all__) + signals.__all__ + list(errors.__all__))
-VERSION = (0, 8, 2)
+VERSION = (0, 8, 3)
MALLARD = True
diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py
index 9a5d5b1ba..e1499b50a 100644
--- a/mongoengine/base/document.py
+++ b/mongoengine/base/document.py
@@ -1,4 +1,5 @@
import operator
+import warnings
from functools import partial
import pymongo
@@ -13,6 +14,7 @@
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList
from mongoengine.base.fields import ComplexBaseField
+from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
@@ -193,14 +195,39 @@ def validate(self, clean=True):
message = "ValidationError (%s:%s) " % (self._class_name, pk)
raise ValidationError(message, errors=errors)
- def to_json(self):
- """Converts a document to JSON"""
- return json_util.dumps(self.to_mongo())
+ def to_json(self, json_options=None):
+ """Convert this document to JSON."""
+ if json_options is None:
+ warnings.warn(
+ "No 'json_options' are specified! Falling back to "
+ "LEGACY_JSON_OPTIONS with uuid_representation=PYTHON_LEGACY. "
+ "For use with other MongoDB drivers specify the UUID "
+ "representation to use. This will be changed to "
+ "uuid_representation=UNSPECIFIED in a future release.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ json_options = LEGACY_JSON_OPTIONS
+ return json_util.dumps(self.to_mongo(), json_options=json_options)
@classmethod
- def from_json(cls, json_data):
- """Converts json data to an unsaved document instance"""
- return cls._from_son(json_util.loads(json_data))
+ def from_json(cls, json_data, json_options=None):
+ """Converts json data to a Document instance.
+
+ :param str json_data: The json data to load into the Document.
+ """
+ if json_options is None:
+ warnings.warn(
+ "No 'json_options' are specified! Falling back to "
+ "LEGACY_JSON_OPTIONS with uuid_representation=PYTHON_LEGACY. "
+ "For use with other MongoDB drivers specify the UUID "
+ "representation to use. This will be changed to "
+ "uuid_representation=UNSPECIFIED in a future release.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ json_options = LEGACY_JSON_OPTIONS
+ return cls._from_son(json_util.loads(json_data, json_options=json_options))
def __expand_dynamic_values(self, name, value):
"""expand any dynamic values to their correct types / values"""
diff --git a/mongoengine/connection.py b/mongoengine/connection.py
index 13c4719ab..696f6ecb8 100644
--- a/mongoengine/connection.py
+++ b/mongoengine/connection.py
@@ -1,6 +1,7 @@
import pymongo
-from pymongo import (MongoClient, MongoReplicaSetClient, ReadPreference,
- uri_parser)
+import warnings
+from pymongo import MongoClient, ReadPreference, uri_parser
+from pymongo.common import _UUID_REPRESENTATIONS
__all__ = [
'DEFAULT_CONNECTION_NAME',
@@ -36,6 +37,7 @@ def register_connection(
slaves=None,
username=None,
password=None,
+ uuidrepresentation=None,
**kwargs
):
"""Add a connection.
@@ -84,7 +86,22 @@ def register_connection(
})
if "replicaSet" in host:
conn_settings['replicaSet'] = True
-
+ if "uuidrepresentation" in uri_dict:
+ uuidrepresentation = uri_dict.get('uuidrepresentation')
+
+ if uuidrepresentation is None:
+ warnings.warn(
+ "No uuidrepresentation is specified! Falling back to "
+ "'pythonLegacy' which is the default for pymongo 3.x. "
+ "For compatibility with other MongoDB drivers this should be "
+ "specified as 'standard' or '{java,csharp}Legacy' to work with "
+ "older drivers in those languages. This will be changed to "
+ "'unspecified' in a future release.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+
+ conn_settings['uuidrepresentation'] = uuidrepresentation or 'pythonLegacy'
conn_settings.update(kwargs)
_connection_settings[alias] = conn_settings
@@ -129,7 +146,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings['slaves'] = slaves
conn_settings.pop('read_preference', None)
- connection_class = MongoClient
if 'replicaSet' in conn_settings:
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# Discard port since it can't be used on MongoReplicaSetClient
@@ -137,10 +153,9 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Discard replicaSet if not base string
if not isinstance(conn_settings['replicaSet'], str):
conn_settings.pop('replicaSet', None)
- connection_class = MongoReplicaSetClient
try:
- _connections[alias] = connection_class(**conn_settings)
+ _connections[alias] = MongoClient(**conn_settings)
except Exception as e:
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
return _connections[alias]
@@ -155,10 +170,6 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn = get_connection(alias)
conn_settings = _connection_settings[alias]
db = conn[conn_settings['name']]
- # Authenticate if necessary
- if conn_settings['username'] and conn_settings['password']:
- db.authenticate(conn_settings['username'],
- conn_settings['password'])
_dbs[alias] = db
return _dbs[alias]
diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py
index 2d7a8e69c..a7bbbe6d8 100644
--- a/mongoengine/context_managers.py
+++ b/mongoengine/context_managers.py
@@ -175,14 +175,14 @@ def __init__(self):
def __enter__(self):
""" On every with block we need to drop the profile collection. """
- self.db.set_profiling_level(0)
+ self.db.command({"profile": 0})
self.db.system.profile.drop()
- self.db.set_profiling_level(2)
+ self.db.command({"profile": 2})
return self
def __exit__(self, t, value, traceback):
""" Reset the profiling level. """
- self.db.set_profiling_level(0)
+ self.db.command({"profile": 0})
def __eq__(self, value):
""" == Compare querycounter. """
@@ -220,7 +220,7 @@ def __repr__(self):
def _get_count(self):
""" Get the number of queries. """
ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}}
- count = self.db.system.profile.find(ignore_query).count() - self.counter
+ count = self.db.system.profile.count_documents(filter=ignore_query) - self.counter
self.counter += 1
return count
diff --git a/mongoengine/document.py b/mongoengine/document.py
index 407409795..ce326ab2d 100644
--- a/mongoengine/document.py
+++ b/mongoengine/document.py
@@ -13,6 +13,7 @@
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import (set_write_concern, switch_db,
switch_collection)
+from mongoengine.pymongo_support import list_collection_names
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'OperationError',
@@ -147,7 +148,9 @@ def _get_collection(cls):
max_size = cls._meta['max_size'] or 10000000 # 10MB default
max_documents = cls._meta['max_documents']
- if collection_name in db.collection_names():
+ if collection_name in list_collection_names(
+ db, include_system_collections=True
+ ):
cls._collection = db[collection_name]
# The collection already exists, check if its capped
# options match the specified capped options
diff --git a/mongoengine/fields.py b/mongoengine/fields.py
index c9f7bc1f1..4d138edc6 100644
--- a/mongoengine/fields.py
+++ b/mongoengine/fields.py
@@ -28,6 +28,7 @@
from .queryset import DO_NOTHING, QuerySet
from .document import Document, EmbeddedDocument
from .connection import get_db, DEFAULT_CONNECTION_NAME
+from pymongo import ReturnDocument
try:
from PIL import Image, ImageOps
@@ -1464,8 +1465,9 @@ def __init__(self, collection_name=None, db_alias=None, sequence_name=None,
self.collection_name = collection_name or self.COLLECTION_NAME
self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
self.sequence_name = sequence_name
- self.value_decorator = (callable(value_decorator) and
- value_decorator or self.VALUE_DECORATOR)
+ self.value_decorator = (
+ value_decorator if callable(value_decorator) else self.VALUE_DECORATOR
+ )
return super(SequenceField, self).__init__(*args, **kwargs)
def generate(self):
@@ -1475,22 +1477,28 @@ def generate(self):
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
- counter = collection.find_and_modify(query={"_id": sequence_id},
- update={"$inc": {"next": 1}},
- new=True,
- upsert=True)
- return self.value_decorator(counter['next'])
+
+ counter = collection.find_one_and_update(
+ filter={"_id": sequence_id},
+ update={"$inc": {"next": 1}},
+ return_document=ReturnDocument.AFTER,
+ upsert=True
+ )
+ return self.value_decorator(counter["next"])
def set_next_value(self, value):
"""Helper method to set the next sequence value"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
- counter = collection.find_and_modify(query={"_id": sequence_id},
- update={"$set": {"next": value}},
- new=True,
- upsert=True)
- return self.value_decorator(counter['next'])
+
+ counter = collection.find_one_and_update(
+ filter={"_id": sequence_id},
+ update={"$set": {"next": value}},
+ return_document=ReturnDocument.AFTER,
+ upsert=True
+ )
+ return self.value_decorator(counter["next"])
def get_next_value(self):
"""Helper method to get the next value for previewing.
diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py
new file mode 100644
index 000000000..f15a72c96
--- /dev/null
+++ b/mongoengine/mongodb_support.py
@@ -0,0 +1,24 @@
+"""
+Helper functions, constants, and types to aid with MongoDB version support
+"""
+
+from mongoengine.connection import get_connection
+
+# Constant that can be used to compare the version retrieved with
+# get_mongodb_version()
+MONGODB_34 = (3, 4)
+MONGODB_36 = (3, 6)
+MONGODB_42 = (4, 2)
+MONGODB_44 = (4, 4)
+MONGODB_50 = (5, 0)
+MONGODB_60 = (6, 0)
+MONGODB_70 = (7, 0)
+
+
+def get_mongodb_version():
+ """Return the version of the default connected mongoDB (first 2 digits)
+
+ :return: tuple(int, int)
+ """
+ version_list = get_connection().server_info()["versionArray"][:2] # e.g: (3, 2)
+ return tuple(version_list)
diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py
new file mode 100644
index 000000000..565f09a81
--- /dev/null
+++ b/mongoengine/pymongo_support.py
@@ -0,0 +1,30 @@
+"""
+Helper functions, constants, and types to aid with PyMongo support.
+"""
+
+import pymongo
+from bson import binary, json_util
+
+PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
+
+# This will be changed to UuidRepresentation.UNSPECIFIED in a future
+# (breaking) release.
+if PYMONGO_VERSION >= (4,):
+ LEGACY_JSON_OPTIONS = json_util.LEGACY_JSON_OPTIONS.with_options(
+ uuid_representation=binary.UuidRepresentation.PYTHON_LEGACY,
+ )
+else:
+ LEGACY_JSON_OPTIONS = json_util.DEFAULT_JSON_OPTIONS
+
+
+def list_collection_names(db, include_system_collections=False):
+ """Pymongo>3.7 deprecates collection_names in favour of list_collection_names"""
+ if PYMONGO_VERSION >= (3, 7):
+ collections = db.list_collection_names()
+ else:
+ collections = db.collection_names()
+
+ if not include_system_collections:
+ collections = [c for c in collections if not c.startswith("system.")]
+
+ return collections
diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py
index 4f92b0a1e..19ee97cb6 100644
--- a/mongoengine/queryset/queryset.py
+++ b/mongoengine/queryset/queryset.py
@@ -6,7 +6,7 @@
import warnings
import pymongo
-from bson import json_util
+from bson import SON, json_util
from bson.code import Code
from pymongo.collection import ReturnDocument
from pymongo.common import validate_read_preference
@@ -14,8 +14,10 @@
from mongoengine import signals
from mongoengine.common import _import_class
+from mongoengine.connection import get_db
from mongoengine.context_managers import set_read_write_concern, set_write_concern
from mongoengine.errors import InvalidQueryError, NotUniqueError, OperationError
+from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
@@ -35,6 +37,7 @@
RE_TYPE = type(re.compile(''))
+PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
class QuerySet(object):
"""A set of results returned from a query. Wraps a MongoDB cursor,
@@ -79,6 +82,13 @@ def __init__(self, document, collection):
self._hint = -1 # Using -1 as None is a valid value for hint
self._batch_size = None
+ # Hack - As people expect cursor[5:5] to return
+ # an empty result set. It's hard to do that right, though, because the
+ # server uses limit(0) to mean 'no limit'. So we set _empty
+ # in that case and check for it when iterating. We also unset
+ # it anytime we change _limit. Inspired by how it is done in pymongo.Cursor
+ self._empty = False
+
def __call__(self, q_obj=None, class_check=True, slave_okay=False,
read_preference=None, **query):
"""Filter the selected documents by calling the
@@ -177,6 +187,7 @@ def __getitem__(self, key):
"""Support skip and limit using getitem and slicing syntax.
"""
queryset = self.clone()
+ queryset._empty = False
# Slice provided
if isinstance(key, slice):
@@ -185,6 +196,8 @@ def __getitem__(self, key):
queryset._skip, queryset._limit = key.start, key.stop
if key.start and key.stop:
queryset._limit = key.stop - key.start
+ if queryset._limit == 0:
+ queryset._empty = True
except IndexError as err:
# PyMongo raises an error if key.start == key.stop, catch it,
# bin it, kill it.
@@ -194,6 +207,7 @@ def __getitem__(self, key):
queryset.limit(0)
queryset._skip = key.start
queryset._limit = key.stop - start
+ queryset._empty = True
return queryset
raise err
# Allow further QuerySet modifications to be performed
@@ -319,6 +333,9 @@ def first(self):
"""Retrieve the first object matching the query.
"""
queryset = self.clone()
+ if self._none or self._empty:
+ return None
+
try:
result = queryset[0]
except IndexError:
@@ -415,17 +432,39 @@ def count(self, with_limit_and_skip=True):
:meth:`skip` that has been applied to this cursor into account when
getting the count
"""
- if self._limit == 0:
- return 0
-
- if self._none:
+ # mimic the fact that setting .limit(0) in pymongo sets no limit
+ # https://www.mongodb.com/docs/manual/reference/method/cursor.limit/#zero-value
+ if (
+ self._limit == 0
+ and with_limit_and_skip is False
+ or self._none
+ or self._empty
+ ):
return 0
if with_limit_and_skip and self._len is not None:
return self._len
- count = self._cursor.count(with_limit_and_skip=with_limit_and_skip)
+
+ # TODO: evaluate utility of `estimated_document_count`
+
+ if PYMONGO_VERSION >= (3, 7):
+ options = {}
+
+ if with_limit_and_skip:
+ if self._limit is not None:
+ options["limit"] = self._limit
+ if self._skip is not None:
+ options["skip"] = self._skip
+ if self._hint not in (-1, None):
+ options["hint"] = self._hint
+
+ count = self._cursor.collection.count_documents(filter=self._query, **options)
+ else:
+ count = self._cursor.count(with_limit_and_skip=with_limit_and_skip)
+
if with_limit_and_skip:
self._len = count
+
return count
def delete(self, write_concern=None, _from_doc_delete=False):
@@ -505,6 +544,8 @@ def update(
if write_concern is None:
write_concern = {}
+ if self._none or self._empty:
+ return 0
queryset = self.clone()
query = queryset._query
@@ -587,6 +628,9 @@ def modify(self, upsert=False, full_response=False, remove=False, new=False, **u
raise OperationError(
"No update parameters, must either update or remove")
+ if self._none or self._empty:
+ return None
+
queryset = self.clone()
query = queryset._query
if not remove:
@@ -749,13 +793,16 @@ def limit(self, n):
:param n: the maximum number of objects to return
"""
+ # chesterton's fence:
+ if n == 0: n = 1
+
queryset = self.clone()
- if n == 0:
- queryset._cursor.limit(1)
- else:
- queryset._cursor.limit(n)
queryset._limit = n
- # Return self to allow chaining
+ queryset._empty = False
+
+ if queryset._cursor_obj:
+ queryset._cursor_obj.limit(n)
+
return queryset
def skip(self, n):
@@ -765,8 +812,11 @@ def skip(self, n):
:param n: the number of objects to skip before returning results
"""
queryset = self.clone()
- queryset._cursor.skip(n)
queryset._skip = n
+
+ if queryset._cursor_obj:
+ queryset._cursor_obj.skip(queryset._skip)
+
return queryset
def hint(self, index=None):
@@ -783,14 +833,22 @@ def hint(self, index=None):
.. versionadded:: 0.5
"""
queryset = self.clone()
- queryset._cursor.hint(index)
queryset._hint = index
+
+ # If a cursor object has already been created, apply the hint to it.
+ if queryset._cursor_obj:
+ queryset._cursor_obj.hint(queryset._hint)
+
return queryset
def batch_size(self, size):
queryset = self.clone()
- queryset._cursor.batch_size(size)
queryset._batch_size = size
+
+ # If a cursor object has already been created, apply the batch size to it.
+ if queryset._cursor_obj:
+ queryset._cursor_obj.batch_size(queryset._batch_size)
+
return queryset
def distinct(self, field, dereference=True):
@@ -1019,9 +1077,20 @@ def as_pymongo(self, coerce_types=False):
# JSON Helpers
- def to_json(self):
+ def to_json(self, *args, **kwargs):
"""Converts a queryset to JSON"""
- return json_util.dumps(self.as_pymongo())
+ if "json_options" not in kwargs:
+ warnings.warn(
+ "No 'json_options' are specified! Falling back to "
+ "LEGACY_JSON_OPTIONS with uuid_representation=PYTHON_LEGACY. "
+ "For use with other MongoDB drivers specify the UUID "
+ "representation to use. This will be changed to "
+ "uuid_representation=UNSPECIFIED in a future release.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ kwargs["json_options"] = LEGACY_JSON_OPTIONS
+ return json_util.dumps(self.as_pymongo(), *args, **kwargs)
def from_json(self, json_data):
"""Converts json data to unsaved objects"""
@@ -1080,14 +1149,14 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
if isinstance(map_f, Code):
map_f_scope = map_f.scope
map_f = str(map_f)
- map_f = Code(queryset._sub_js_fields(map_f), map_f_scope)
+ map_f = Code(queryset._sub_js_fields(map_f), map_f_scope or None)
reduce_f_scope = {}
if isinstance(reduce_f, Code):
reduce_f_scope = reduce_f.scope
reduce_f = str(reduce_f)
reduce_f_code = queryset._sub_js_fields(reduce_f)
- reduce_f = Code(reduce_f_code, reduce_f_scope)
+ reduce_f = Code(reduce_f_code, reduce_f_scope or None)
mr_args = {'query': queryset._query}
@@ -1097,7 +1166,7 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
finalize_f_scope = finalize_f.scope
finalize_f = str(finalize_f)
finalize_f_code = queryset._sub_js_fields(finalize_f)
- finalize_f = Code(finalize_f_code, finalize_f_scope)
+ finalize_f = Code(finalize_f_code, finalize_f_scope or None)
mr_args['finalize'] = finalize_f
if scope:
@@ -1107,16 +1176,57 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None,
mr_args['limit'] = limit
if output == 'inline' and not queryset._ordering:
- map_reduce_function = 'inline_map_reduce'
+ inline = True
+ mr_args['out'] = {'inline': 1}
else:
- map_reduce_function = 'map_reduce'
- mr_args['out'] = output
+ inline = False
+ if isinstance(output, str):
+ mr_args['out'] = output
+
+ elif isinstance(output, dict):
+ ordered_output = []
+
+ for part in ('replace', 'merge', 'reduce'):
+ value = output.get(part)
+ if value:
+ ordered_output.append((part, value))
+ break
+
+ else:
+ raise OperationError('actionData not specified for output')
- results = getattr(queryset._collection, map_reduce_function)(
- map_f, reduce_f, **mr_args)
+ db_alias = output.get('db_alias')
+ remaing_args = ['db', 'sharded', 'nonAtomic']
- if map_reduce_function == 'map_reduce':
- results = results.find()
+ if db_alias:
+ ordered_output.append(('db', get_db(db_alias).name))
+ del remaing_args[0]
+
+ for part in remaing_args:
+ value = output.get(part)
+ if value:
+ ordered_output.append((part, value))
+
+ mr_args['out'] = SON(ordered_output)
+
+ db = queryset._document._get_db()
+ result = db.command(
+ {
+ 'mapReduce': queryset._document._get_collection_name(),
+ 'map': map_f,
+ 'reduce': reduce_f,
+ **mr_args,
+ }
+ )
+
+ if inline:
+ results = result['results']
+ else:
+ if isinstance(result['result'], str):
+ results = db[result['result']].find()
+ else:
+ info = result['result']
+ results = db.client[info['db']][info['collection']].find()
if queryset._ordering:
results = results.sort(queryset._ordering)
@@ -1167,7 +1277,7 @@ def exec_js(self, code, *fields, **options):
code = Code(code, scope=scope)
db = queryset._document._get_db()
- return db.eval(code, *fields)
+ return db.command("eval", code, args=fields).get("retval")
def where(self, where_clause):
"""Filter ``QuerySet`` results with a ``$where`` clause (a Javascript
diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py
index 6e8b2a12a..e9951bf46 100644
--- a/mongoengine/queryset/transform.py
+++ b/mongoengine/queryset/transform.py
@@ -230,8 +230,14 @@ def update(_doc_cls=None, **update):
value = {key: value}
elif op == 'addToSet' and isinstance(value, list):
value = {key: {"$each": value}}
+ elif op == "pushAll":
+ op = 'push' # convert to non-deprecated keyword
+ if not isinstance(value, (set, tuple, list)):
+ value = [value]
+ value = {key: {'$each': value}}
else:
value = {key: value}
+
key = '$' + op
if key not in mongo_update:
diff --git a/requirements.txt b/requirements.txt
index ddff5fb23..71bf46d3b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-pymongo==3.12.3
+pymongo>=3.13
diff --git a/setup.cfg b/setup.cfg
index c3df01801..5cc2b4752 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,10 +1,2 @@
-[nosetests]
-verbosity = 3
-detailed-errors = 1
-#with-coverage = 1
-#cover-erase = 1
-#cover-html = 1
-#cover-html-dir = ../htmlcov
-#cover-package = mongoengine
-where = tests
-#tests = document/__init__.py
+[tool:pytest]
+testpaths=tests
diff --git a/setup.py b/setup.py
index 0391ac00b..4a7086011 100644
--- a/setup.py
+++ b/setup.py
@@ -38,10 +38,12 @@ def get_version(version_tuple):
'Operating System :: OS Independent',
'Programming Language :: Python',
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.1",
- "Programming Language :: Python :: 3.2",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
'Topic :: Database',
'Topic :: Software Development :: Libraries :: Python Modules',
@@ -65,7 +67,7 @@ def get_version(version_tuple):
long_description=LONG_DESCRIPTION,
platforms=['any'],
classifiers=CLASSIFIERS,
- install_requires=['pymongo>=3.0,<3.14'],
- test_suite='nose.collector',
+ install_requires=['pymongo>=3.13.0'],
+ tests_require=['pytest'],
**extra_opts
)
diff --git a/test-requirements.txt b/test-requirements.txt
index 92ef961a4..aa732d848 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1,5 +1,4 @@
-r requirements.txt
pytest
-nose
coverage
blinker
diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py
index c78703fce..ddba1d9ae 100644
--- a/tests/document/test_class_methods.py
+++ b/tests/document/test_class_methods.py
@@ -6,7 +6,7 @@
from mongoengine import *
from mongoengine.queryset import NULLIFY, PULL
-from mongoengine.connection import get_db
+from mongoengine.pymongo_support import list_collection_names
__all__ = ("ClassMethodsTest", )
@@ -28,9 +28,7 @@ class Person(Document):
self.Person = Person
def tearDown(self):
- for collection in self.db.collection_names():
- if 'system.' in collection:
- continue
+ for collection in list_collection_names(self.db):
self.db.drop_collection(collection)
def test_definition(self):
@@ -67,10 +65,10 @@ def test_drop_collection(self):
"""
collection_name = 'person'
self.Person(name='Test').save()
- self.assertTrue(collection_name in self.db.collection_names())
+ self.assertTrue(collection_name in list_collection_names(self.db))
self.Person.drop_collection()
- self.assertFalse(collection_name in self.db.collection_names())
+ self.assertFalse(collection_name in list_collection_names(self.db))
def test_register_delete_rule(self):
"""Ensure that register delete rule adds a delete rule to the document
@@ -219,7 +217,7 @@ class Person(Document):
meta = {'collection': collection_name}
Person(name="Test User").save()
- self.assertTrue(collection_name in self.db.collection_names())
+ self.assertTrue(collection_name in list_collection_names(self.db))
user_obj = self.db[collection_name].find_one()
self.assertEqual(user_obj['name'], "Test User")
@@ -228,7 +226,7 @@ class Person(Document):
self.assertEqual(user_obj.name, "Test User")
Person.drop_collection()
- self.assertFalse(collection_name in self.db.collection_names())
+ self.assertFalse(collection_name in list_collection_names(self.db))
def test_collection_name_and_primary(self):
"""Ensure that a collection with a specified name may be used.
diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py
index 355717fb7..73d24c9ed 100644
--- a/tests/document/test_delta.py
+++ b/tests/document/test_delta.py
@@ -5,6 +5,7 @@
from mongoengine import *
from mongoengine.connection import get_db
+from mongoengine.pymongo_support import list_collection_names
__all__ = ("DeltaTest",)
@@ -26,9 +27,7 @@ class Person(Document):
self.Person = Person
def tearDown(self):
- for collection in self.db.collection_names():
- if 'system.' in collection:
- continue
+ for collection in list_collection_names(self.db):
self.db.drop_collection(collection)
def test_delta(self):
diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py
index 23e29078e..05a96d607 100644
--- a/tests/document/test_indexes.py
+++ b/tests/document/test_indexes.py
@@ -5,13 +5,15 @@
import os
import pymongo
+import pytest
-from nose.plugins.skip import SkipTest
from datetime import datetime
from mongoengine import *
from mongoengine.connection import get_db, get_connection
from pymongo.errors import OperationFailure
+from mongoengine.pymongo_support import list_collection_names
+from mongoengine.mongodb_support import get_mongodb_version, MONGODB_42
__all__ = ("IndexesTest", )
@@ -33,9 +35,7 @@ class Person(Document):
self.Person = Person
def tearDown(self):
- for collection in self.db.collection_names():
- if 'system.' in collection:
- continue
+ for collection in list_collection_names(self.db):
self.db.drop_collection(collection)
def test_indexes_document(self):
@@ -579,15 +579,6 @@ class Log(Document):
}
Log.drop_collection()
-
- if pymongo.version_tuple[0] < 2 and pymongo.version_tuple[1] < 3:
- raise SkipTest('pymongo needs to be 2.3 or higher for this test')
-
- connection = get_connection()
- version_array = connection.server_info()['versionArray']
- if version_array[0] < 2 and version_array[1] < 2:
- raise SkipTest('MongoDB needs to be 2.2 or higher for this test')
-
Log.ensure_indexes()
info = Log.objects._collection.index_information()
self.assertEqual(3600,
@@ -749,6 +740,47 @@ class TestChildDoc(TestDoc):
}
})
+ def test_covered_index(self):
+ """Ensure that covered indexes can be used"""
+
+ class Test(Document):
+ a = IntField()
+ b = IntField()
+
+ meta = {"indexes": ["a"], "allow_inheritance": False, "auto_create_index": True}
+
+ Test.drop_collection()
+
+ obj = Test(a=1)
+ obj.save()
+
+ # Need to be explicit about covered indexes as mongoDB doesn't know if
+ # the documents returned might have more keys in that here.
+ query_plan = Test.objects(id=obj.id).exclude("a").explain()
+ assert (
+ query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK"
+ )
+
+ query_plan = Test.objects(id=obj.id).only("id").explain()
+ assert (
+ query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IDHACK"
+ )
+
+ mongo_db = get_mongodb_version()
+ query_plan = Test.objects(a=1).only("a").exclude("id").explain()
+ assert (
+ query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN"
+ )
+
+ PROJECTION_STR = "PROJECTION" if mongo_db < MONGODB_42 else "PROJECTION_COVERED"
+ assert query_plan["queryPlanner"]["winningPlan"]["stage"] == PROJECTION_STR
+
+ query_plan = Test.objects(a=1).explain()
+ assert (
+ query_plan["queryPlanner"]["winningPlan"]["inputStage"]["stage"] == "IXSCAN"
+ )
+
+ assert query_plan.get("queryPlanner").get("winningPlan").get("stage") == "FETCH"
if __name__ == '__main__':
unittest.main()
diff --git a/tests/document/test_inheritance.py b/tests/document/test_inheritance.py
index 8e8508c6c..f7dfb2851 100644
--- a/tests/document/test_inheritance.py
+++ b/tests/document/test_inheritance.py
@@ -12,6 +12,7 @@
from mongoengine.connection import get_db
from mongoengine.fields import (BooleanField, GenericReferenceField,
IntField, StringField)
+from mongoengine.pymongo_support import list_collection_names
__all__ = ('InheritanceTest', )
@@ -23,9 +24,7 @@ def setUp(self):
self.db = get_db()
def tearDown(self):
- for collection in self.db.collection_names():
- if 'system.' in collection:
- continue
+ for collection in list_collection_names(self.db):
self.db.drop_collection(collection)
def test_superclasses(self):
diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py
index f47c01262..111c2ed10 100644
--- a/tests/document/test_instance.py
+++ b/tests/document/test_instance.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
import sys
+
+from mongoengine.mongodb_support import MONGODB_34, get_mongodb_version
sys.path[0:0] = [""]
import bson
@@ -20,6 +22,7 @@
from mongoengine.base import get_document
from mongoengine.context_managers import switch_db, query_counter
from mongoengine import signals
+from mongoengine.pymongo_support import list_collection_names
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
'../fields/mongoengine.png')
@@ -44,9 +47,7 @@ class Person(Document):
self.Person = Person
def tearDown(self):
- for collection in self.db.collection_names():
- if 'system.' in collection:
- continue
+ for collection in list_collection_names(self.db):
self.db.drop_collection(collection)
def test_capped_collection(self):
@@ -382,10 +383,9 @@ class Animal(Document):
query_op = q.db.system.profile.find_one({
'ns': 'mongoenginetest.animal'
})
- self.assertEqual(
- set(query_op['query']['filter'].keys()),
- set(['_id', 'superphylum'])
- )
+ # MongoDB 5.0+ uses 'command' instead of 'query'
+ cmd_query = query_op.get('command') or query_op.get('query')
+ assert set(cmd_query['filter'].keys()) == {'_id', 'superphylum'}
Animal.drop_collection()
@@ -402,13 +402,18 @@ class Animal(Document):
doc = Animal(is_mammal=True, name='Dog')
doc.save()
+ mongo_db = get_mongodb_version()
+
with query_counter() as q:
doc.name = 'Cat'
doc.save()
query_op = q.db.system.profile.find({ 'ns': 'mongoenginetest.animal' })[0]
- self.assertEqual(query_op['op'], 'update')
- self.assertEqual(set(query_op['query'].keys()), set(['_id', 'is_mammal']))
+ # MongoDB 5.0+ uses 'command' instead of 'query'
+ if mongo_db <= MONGODB_34:
+ assert set(query_op['query'].keys()) == {'_id', 'is_mammal'}, str(cmd_query)
+ else:
+ assert set(query_op['command']['q'].keys()) == {'_id', 'is_mammal'}, str(cmd_query)
Animal.drop_collection()
@@ -1975,16 +1980,16 @@ class Person(Document):
person = Person(name="name", age=10, job=job)
from pymongo.collection import Collection
- orig_update = Collection.update
+ orig_update_one = Collection.update_one
try:
- def fake_update(*args, **kwargs):
+ def fake_update_one(*args, **kwargs):
self.fail("Unexpected update for %s" % args[0].name)
- return orig_update(*args, **kwargs)
+ return orig_update_one(*args, **kwargs)
- Collection.update = fake_update
+ Collection.update_one = fake_update_one
person.save()
finally:
- Collection.update = orig_update
+ Collection.update_one = orig_update_one
def test_db_alias_tests(self):
""" DB Alias tests """
diff --git a/tests/document/test_json_serialisation.py b/tests/document/test_json_serialisation.py
index 1f2d5c888..1e8a4a733 100644
--- a/tests/document/test_json_serialisation.py
+++ b/tests/document/test_json_serialisation.py
@@ -4,7 +4,6 @@
import unittest
import uuid
-from nose.plugins.skip import SkipTest
from datetime import datetime
from bson import ObjectId
@@ -35,9 +34,6 @@ class Doc(Document):
def test_json_complex(self):
- if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3:
- raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs")
-
class EmbeddedDoc(EmbeddedDocument):
pass
diff --git a/tests/fields/test_file.py b/tests/fields/test_file.py
index 30dbf15b3..460392eda 100644
--- a/tests/fields/test_file.py
+++ b/tests/fields/test_file.py
@@ -9,16 +9,11 @@
import gridfs
-from nose.plugins.skip import SkipTest
from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.python_support import PY3, b, StringIO
+from tests.utils import requires_pil
-try:
- from PIL import Image
- HAS_PIL = True
-except ImportError:
- HAS_PIL = False
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')
@@ -261,10 +256,8 @@ class TestFile(Document):
test_file = TestFile()
self.assertFalse(test_file.the_file in [{"test": 1}])
+ @requires_pil
def test_image_field(self):
- if not HAS_PIL:
- raise SkipTest('PIL not installed')
-
class TestImage(Document):
image = ImageField()
@@ -295,10 +288,8 @@ class TestImage(Document):
t.image.delete()
+ @requires_pil
def test_image_field_reassigning(self):
- if not HAS_PIL:
- raise SkipTest('PIL not installed')
-
class TestFile(Document):
the_file = ImageField()
TestFile.drop_collection()
@@ -311,10 +302,8 @@ class TestFile(Document):
test_file.save()
self.assertEqual(test_file.the_file.size, (45, 101))
+ @requires_pil
def test_image_field_resize(self):
- if not HAS_PIL:
- raise SkipTest('PIL not installed')
-
class TestImage(Document):
image = ImageField(size=(185, 37))
@@ -334,10 +323,8 @@ class TestImage(Document):
t.image.delete()
+ @requires_pil
def test_image_field_resize_force(self):
- if not HAS_PIL:
- raise SkipTest('PIL not installed')
-
class TestImage(Document):
image = ImageField(size=(185, 37, True))
@@ -357,10 +344,8 @@ class TestImage(Document):
t.image.delete()
+ @requires_pil
def test_image_field_thumbnail(self):
- if not HAS_PIL:
- raise SkipTest('PIL not installed')
-
class TestImage(Document):
image = ImageField(thumbnail_size=(92, 18))
@@ -433,11 +418,9 @@ class TestFile(Document):
self.assertEqual(putfile, copy.copy(putfile))
self.assertEqual(putfile, copy.deepcopy(putfile))
+ @requires_pil
def test_get_image_by_grid_id(self):
- if not HAS_PIL:
- raise SkipTest('PIL not installed')
-
class TestImage(Document):
image1 = ImageField()
diff --git a/tests/queryset/test_modify.py b/tests/queryset/test_modify.py
new file mode 100644
index 000000000..b726c7c8a
--- /dev/null
+++ b/tests/queryset/test_modify.py
@@ -0,0 +1,146 @@
+import unittest
+import pytest
+
+from mongoengine import (
+ Document,
+ IntField,
+ ListField,
+ StringField,
+ connect,
+)
+
+
+class Doc(Document):
+ id = IntField(primary_key=True)
+ value = IntField()
+
+
+class TestFindAndModify(unittest.TestCase):
+ def setUp(self):
+ connect(db="mongoenginetest")
+ Doc.drop_collection()
+
+ def _assert_db_equal(self, docs):
+ assert list(Doc._collection.find().sort("id")) == docs
+
+ def test_modify(self):
+ Doc(id=0, value=0).save()
+ doc = Doc(id=1, value=1).save()
+
+ old_doc = Doc.objects(id=1).modify(set__value=-1)
+ assert old_doc.to_json() == doc.to_json()
+ self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
+
+ def test_modify_with_new(self):
+ Doc(id=0, value=0).save()
+ doc = Doc(id=1, value=1).save()
+
+ new_doc = Doc.objects(id=1).modify(set__value=-1, new=True)
+ doc.value = -1
+ assert new_doc.to_json() == doc.to_json()
+ self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
+
+ def test_modify_not_existing(self):
+ Doc(id=0, value=0).save()
+ assert Doc.objects(id=1).modify(set__value=-1) is None
+ self._assert_db_equal([{"_id": 0, "value": 0}])
+
+ def test_modify_with_upsert(self):
+ Doc(id=0, value=0).save()
+ old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True)
+ assert old_doc is None
+ self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])
+
+ def test_modify_with_upsert_existing(self):
+ Doc(id=0, value=0).save()
+ doc = Doc(id=1, value=1).save()
+
+ old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True)
+ assert old_doc.to_json() == doc.to_json()
+ self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
+
+ def test_modify_with_upsert_with_new(self):
+ Doc(id=0, value=0).save()
+ new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1)
+ assert new_doc.to_mongo() == {"_id": 1, "value": 1}
+ self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])
+
+ def test_modify_with_remove(self):
+ Doc(id=0, value=0).save()
+ doc = Doc(id=1, value=1).save()
+
+ old_doc = Doc.objects(id=1).modify(remove=True)
+ assert old_doc.to_json() == doc.to_json()
+ self._assert_db_equal([{"_id": 0, "value": 0}])
+
+ def test_find_and_modify_with_remove_not_existing(self):
+ Doc(id=0, value=0).save()
+ assert Doc.objects(id=1).modify(remove=True) is None
+ self._assert_db_equal([{"_id": 0, "value": 0}])
+
+ def test_modify_with_order_by(self):
+ Doc(id=0, value=3).save()
+ Doc(id=1, value=2).save()
+ Doc(id=2, value=1).save()
+ doc = Doc(id=3, value=0).save()
+
+ old_doc = Doc.objects().order_by("-id").modify(set__value=-1)
+ assert old_doc.to_json() == doc.to_json()
+ self._assert_db_equal(
+ [
+ {"_id": 0, "value": 3},
+ {"_id": 1, "value": 2},
+ {"_id": 2, "value": 1},
+ {"_id": 3, "value": -1},
+ ]
+ )
+
+ def test_modify_with_fields(self):
+ Doc(id=0, value=0).save()
+ Doc(id=1, value=1).save()
+
+ old_doc = Doc.objects(id=1).only("id").modify(set__value=-1)
+ assert old_doc.to_mongo() == {"_id": 1}
+ self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
+
+ def test_modify_with_push(self):
+ class BlogPost(Document):
+ tags = ListField(StringField())
+
+ BlogPost.drop_collection()
+
+ blog = BlogPost.objects.create()
+
+ # Push a new tag via modify with new=False (default).
+ BlogPost(id=blog.id).modify(push__tags="code")
+ assert blog.tags == []
+ blog.reload()
+ assert blog.tags == ["code"]
+
+ # Push a new tag via modify with new=True.
+ blog = BlogPost.objects(id=blog.id).modify(push__tags="java", new=True)
+ assert blog.tags == ["code", "java"]
+
+
+ @pytest.mark.skip("op__x__n not supported")
+ def test_modify_with_push_and_index(self):
+ # This continues to be unsupported with mongomallard.
+ class BlogPost(Document):
+ tags = ListField(StringField())
+
+ BlogPost.drop_collection()
+
+ blog = BlogPost.objects.create()
+
+ # Push a new tag with a positional argument.
+ blog = BlogPost.objects(id=blog.id).modify(push__tags__0="python", new=True)
+ assert blog.tags == ["python", "code", "java"]
+
+ # Push multiple new tags with a positional argument.
+ blog = BlogPost.objects(id=blog.id).modify(
+ push__tags__1=["go", "rust"], new=True
+ )
+ assert blog.tags == ["python", "go", "rust", "code", "java"]
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py
index 922ec7f3c..aa8727940 100644
--- a/tests/queryset/test_queryset.py
+++ b/tests/queryset/test_queryset.py
@@ -1,14 +1,13 @@
import sys
-
sys.path[0:0] = [""]
+import pytest
import unittest
import uuid
from datetime import datetime, timedelta
import pymongo
-from bson import ObjectId
-from nose.plugins.skip import SkipTest
+from bson import DBRef, ObjectId
from pymongo.errors import ConfigurationError
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
@@ -21,6 +20,12 @@
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
QuerySet, QuerySetManager, queryset_manager)
+from tests.utils import (
+ requires_mongodb_gte_42,
+ requires_mongodb_gte_44,
+ requires_mongodb_lt_42,
+)
+
__all__ = ("QuerySetTest",)
@@ -140,6 +145,165 @@ def test_find(self):
self.assertEqual("[, ]", "%s" % self.Person.objects[1:3])
self.assertEqual("[, ]", "%s" % self.Person.objects[51:53])
+ def test_slicing_sets_empty_limit_skip(self):
+ self.Person.objects.insert(
+ [self.Person(name=f"User {i}", age=i) for i in range(5)],
+ load_bulk=False,
+ )
+
+ self.Person.objects.create(name="User B", age=30)
+ self.Person.objects.create(name="User C", age=40)
+
+ qs = self.Person.objects()[1:2]
+ assert (qs._empty, qs._skip, qs._limit) == (False, 1, 1)
+ assert len(list(qs)) == 1
+
+ # Test edge case of [1:1] which should return nothing
+ # and require a hack so that it doesn't clash with limit(0)
+ qs = self.Person.objects()[1:1]
+ assert (qs._empty, qs._skip, qs._limit) == (True, 1, 0)
+ assert len(list(qs)) == 0
+
+ qs2 = qs[1:5] # Make sure that further slicing resets _empty
+ assert (qs2._empty, qs2._skip, qs2._limit) == (False, 1, 4)
+ assert len(list(qs2)) == 4
+
+ def test_limit_0_returns_all_documents(self):
+ self.Person.objects.create(name="User A", age=20)
+ self.Person.objects.create(name="User B", age=30)
+
+ n_docs = self.Person.objects().count()
+
+ persons = list(self.Person.objects().limit(0))
+ assert len(persons) == 1
+ assert n_docs == 2
+
+ def test_limit(self):
+ """Ensure that QuerySet.limit works as expected."""
+ user_a = self.Person.objects.create(name="User A", age=20)
+ _ = self.Person.objects.create(name="User B", age=30)
+
+ # Test limit on a new queryset
+ people = list(self.Person.objects.limit(1))
+ assert len(people) == 1
+ assert people[0] == user_a
+
+ # Test limit on an existing queryset
+ people = self.Person.objects
+ assert len(people) == 2
+ people2 = people.limit(1)
+ assert len(people) == 2
+ assert len(people2) == 1
+ assert people2[0] == user_a
+
+ # Test limit with 0 as parameter
+ people = self.Person.objects.limit(0)
+ assert people.count(with_limit_and_skip=True) == 1
+ assert len(people) == 1
+
+ # Test chaining of only after limit
+ person = self.Person.objects().limit(1).only("name").first()
+ assert person == user_a
+ assert person.name == "User A"
+ assert person.age is None
+
+ def test_skip(self):
+ """Ensure that QuerySet.skip works as expected."""
+ user_a = self.Person.objects.create(name="User A", age=20)
+ user_b = self.Person.objects.create(name="User B", age=30)
+
+ # Test skip on a new queryset
+ people = list(self.Person.objects.skip(0))
+ assert len(people) == 2
+ assert people[0] == user_a
+ assert people[1] == user_b
+
+ people = list(self.Person.objects.skip(1))
+ assert len(people) == 1
+ assert people[0] == user_b
+
+ # Test skip on an existing queryset
+ people = self.Person.objects
+ assert len(people) == 2
+ people2 = people.skip(1)
+ assert len(people) == 2
+ assert len(people2) == 1
+ assert people2[0] == user_b
+
+ # Test chaining of only after skip
+ person = self.Person.objects().skip(1).only("name").first()
+ assert person == user_b
+ assert person.name == "User B"
+ assert person.age is None
+
+ def test___getitem___invalid_index(self):
+ """Ensure slicing a queryset works as expected."""
+ with pytest.raises(AttributeError):
+ self.Person.objects()["a"]
+
+ def test_slice(self):
+ """Ensure slicing a queryset works as expected."""
+ user_a = self.Person.objects.create(name="User A", age=20)
+ user_b = self.Person.objects.create(name="User B", age=30)
+ user_c = self.Person.objects.create(name="User C", age=40)
+
+ # Test slice limit
+ people = list(self.Person.objects[:2])
+ assert len(people) == 2
+ assert people[0] == user_a
+ assert people[1] == user_b
+
+ # Test slice skip
+ people = list(self.Person.objects[1:])
+ assert len(people) == 2
+ assert people[0] == user_b
+ assert people[1] == user_c
+
+ # Test slice limit and skip
+ people = list(self.Person.objects[1:2])
+ assert len(people) == 1
+ assert people[0] == user_b
+
+ # Test slice limit and skip on an existing queryset
+ people = self.Person.objects
+ assert len(people) == 3
+ people2 = people[1:2]
+ assert len(people2) == 1
+ assert people2[0] == user_b
+
+ # Test slice limit and skip cursor reset
+ qs = self.Person.objects[1:2]
+ # fetch then delete the cursor
+ qs._cursor
+ qs._cursor_obj = None
+ people = list(qs)
+ assert len(people) == 1
+ assert people[0].name == "User B"
+
+ # Test empty slice
+ people = list(self.Person.objects[1:1])
+ assert len(people) == 0
+
+ # Test slice out of range
+ people = list(self.Person.objects[80000:80001])
+ assert len(people) == 0
+
+ # Test larger slice __repr__
+ self.Person.objects.delete()
+ for i in range(55):
+ self.Person(name="A%s" % i, age=i).save()
+
+ assert self.Person.objects.count() == 55
+ assert "Person object" == "%s" % self.Person.objects[0]
+ assert (
+ "[, ]"
+ == "%s" % self.Person.objects[1:3]
+ )
+ assert (
+ "[, ]"
+ == "%s" % self.Person.objects[51:53]
+ )
+
def test_find_one(self):
"""Ensure that a query using find_one returns a valid result.
"""
@@ -255,6 +419,25 @@ class A(Document):
self.assertEqual(list(A.objects.none().all()), [])
self.assertEqual(A.objects.none().count(), 0)
+ # validate collection not empty
+ assert A.objects.count() == 1
+
+ # update operations
+ assert A.objects.none().update(s="1") == 0
+ assert A.objects.none().update_one(s="1") == 0
+ assert A.objects.none().modify(s="1") is None
+
+ # validate noting change by update operations
+ assert A.objects(s="1").count() == 0
+
+ # fetch queries
+ assert A.objects.none().first() is None
+ assert list(A.objects.none()) == []
+ assert list(A.objects.none().all()) == []
+ assert list(A.objects.none().limit(1)) == []
+ assert list(A.objects.none().skip(1)) == []
+ assert list(A.objects.none()[:5]) == []
+
def test_chaining(self):
class A(Document):
s = StringField()
@@ -1049,6 +1232,7 @@ class BlogPost(Document):
BlogPost.drop_collection()
+ @requires_mongodb_lt_42
def test_exec_js_query(self):
"""Ensure that queries are properly formed for use in exec_js.
"""
@@ -1086,6 +1270,7 @@ class BlogPost(Document):
BlogPost.drop_collection()
+ @requires_mongodb_lt_42
def test_exec_js_field_sub(self):
"""Ensure that field substitutions occur properly in exec_js functions.
"""
@@ -1806,9 +1991,8 @@ class BlogPost(Document):
results = BlogPost.objects.map_reduce(map_f, reduce_f, "myresults")
results = list(results)
- self.assertEqual(results[0].object, post1)
- self.assertEqual(results[1].object, post2)
- self.assertEqual(results[2].object, post3)
+ self.assertEqual({ result.object for result in results },
+ { post1, post2, post3 })
BlogPost.drop_collection()
@@ -2069,6 +2253,7 @@ class Person(Document):
freq = Person.objects.item_frequencies('city', normalize=True, map_reduce=True)
self.assertEqual(freq, {'CRB': 0.5, None: 0.5})
+ @requires_mongodb_lt_42
def test_item_frequencies_with_null_embedded(self):
class Data(EmbeddedDocument):
name = StringField()
@@ -2097,6 +2282,7 @@ class Person(Document):
ot = Person.objects.item_frequencies('extra.tag', map_reduce=True)
self.assertEqual(ot, {None: 1.0, 'friend': 1.0})
+ @requires_mongodb_lt_42
def test_item_frequencies_with_0_values(self):
class Test(Document):
val = IntField()
@@ -2111,6 +2297,7 @@ class Test(Document):
ot = Test.objects.item_frequencies('val', map_reduce=False)
self.assertEqual(ot, {0: 1})
+ @requires_mongodb_lt_42
def test_item_frequencies_with_False_values(self):
class Test(Document):
val = BooleanField()
@@ -2125,6 +2312,7 @@ class Test(Document):
ot = Test.objects.item_frequencies('val', map_reduce=False)
self.assertEqual(ot, {False: 1})
+ @requires_mongodb_lt_42
def test_item_frequencies_normalize(self):
class Test(Document):
val = IntField()
@@ -3247,8 +3435,6 @@ class Doc(Document):
self.assertEqual(doc_objects, Doc.objects.from_json(json_data))
def test_json_complex(self):
- if pymongo.version_tuple[0] <= 2 and pymongo.version_tuple[1] <= 3:
- raise SkipTest("Need pymongo 2.4 as has a fix for DBRefs")
class EmbeddedDoc(EmbeddedDocument):
pass
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 11bf02084..056c506db 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -4,7 +4,8 @@
import datetime
import unittest
-import pymongo
+import pymongo.mongo_client
+import pymongo.database
from bson.tz_util import utc
import mongoengine.connection
@@ -42,9 +43,8 @@ def test_connect_uri(self):
c.admin.system.users.delete_many({})
c.mongoenginetest.system.users.delete_many({})
- c.admin.add_user("admin", "password")
- c.admin.authenticate("admin", "password")
- c.mongoenginetest.add_user("username", "password")
+ c.admin.command('createUser', 'admin', pwd='password', roles=['root'])
+ c.mongoenginetest.command('createUser', 'username', pwd='password', roles=['read'])
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
@@ -54,8 +54,11 @@ def test_connect_uri(self):
self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient))
db = get_db()
- self.assertTrue(isinstance(db, pymongo.database.Database))
- self.assertEqual(db.name, 'mongoenginetest')
+ assert isinstance(db, pymongo.database.Database)
+ assert db.name == "mongoenginetest"
+
+ c.admin.system.users.delete_many({})
+ c.mongoenginetest.system.users.delete_many({})
def test_register_connection(self):
"""Ensure that connections with different aliases may be registered.
diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py
index 8852e5f61..12bd7a04c 100644
--- a/tests/test_context_managers.py
+++ b/tests/test_context_managers.py
@@ -196,7 +196,7 @@ def test_query_counter(self):
self.assertEqual(0, q)
for i in range(1, 51):
- db.test.find({}).count()
+ db.test.count_documents({})
self.assertEqual(50, q)
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 000000000..c76b9f980
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,88 @@
+import functools
+import operator
+
+import pymongo
+import pytest
+
+from mongoengine.mongodb_support import get_mongodb_version
+
+PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
+
+
+def get_as_pymongo(doc):
+ """Fetch the pymongo version of a certain Document"""
+ return doc.__class__.objects.as_pymongo().get(id=doc.id)
+
+
+def requires_mongodb_lt_42(func):
+ return _decorated_with_ver_requirement(func, (4, 2), oper=operator.lt)
+
+
+def requires_mongodb_gte_40(func):
+ return _decorated_with_ver_requirement(func, (4, 0), oper=operator.ge)
+
+
+def requires_mongodb_gte_42(func):
+ return _decorated_with_ver_requirement(func, (4, 2), oper=operator.ge)
+
+
+def requires_mongodb_gte_44(func):
+ return _decorated_with_ver_requirement(func, (4, 4), oper=operator.ge)
+
+
+def requires_mongodb_gte_50(func):
+ return _decorated_with_ver_requirement(func, (5, 0), oper=operator.ge)
+
+
+def requires_mongodb_gte_60(func):
+ return _decorated_with_ver_requirement(func, (6, 0), oper=operator.ge)
+
+
+def requires_mongodb_gte_70(func):
+ return _decorated_with_ver_requirement(func, (7, 0), oper=operator.ge)
+
+try:
+ from PIL import Image as _
+ HAS_PIL = True
+except ImportError:
+ HAS_PIL = False
+
+def requires_pil(func):
+ @functools.wraps(func)
+ def _inner(*args, **kwargs):
+ if HAS_PIL:
+ return func(*args, **kwargs)
+ else:
+ pytest.skip("PIL not installed")
+
+def _decorated_with_ver_requirement(func, mongo_version_req, oper):
+ """Return a MongoDB version requirement decorator.
+
+ The resulting decorator will skip the test if the current
+ MongoDB version doesn't match the provided version/operator.
+
+ For example, if you define a decorator like so:
+
+ def requires_mongodb_gte_36(func):
+ return _decorated_with_ver_requirement(
+ func, (3.6), oper=operator.ge
+ )
+
+ Then tests decorated with @requires_mongodb_gte_36 will be skipped if
+ ran against MongoDB < v3.6.
+
+ :param mongo_version_req: The mongodb version requirement (tuple(int, int))
+ :param oper: The operator to apply (e.g. operator.ge)
+ """
+
+ @functools.wraps(func)
+ def _inner(*args, **kwargs):
+ mongodb_v = get_mongodb_version()
+ if oper(mongodb_v, mongo_version_req):
+ return func(*args, **kwargs)
+ else:
+ pretty_version = ".".join(str(n) for n in mongo_version_req)
+ pytest.skip(f"Needs MongoDB {oper.__name__} v{pretty_version}")
+
+ return _inner
+