Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

from pymongo import MongoClient

from burr.integrations.persisters.b_pymongo import (
_DRIVER_INFO,
)
from burr.integrations.persisters.b_pymongo import MongoDBBasePersister as PymongoPersister

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,6 +53,7 @@ def from_values(

if mongo_client_kwargs is None:
mongo_client_kwargs = {}
mongo_client_kwargs.setdefault("driver", _DRIVER_INFO)
client = MongoClient(uri, **mongo_client_kwargs)
return PymongoPersister(
client=client,
Expand All @@ -76,6 +80,7 @@ def __init__(
"""Initializes the MongoDBPersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
mongo_client_kwargs.setdefault("driver", _DRIVER_INFO)
client = MongoClient(uri, **mongo_client_kwargs)
super(MongoDBPersister, self).__init__(
client=client,
Expand Down
16 changes: 15 additions & 1 deletion burr/integrations/persisters/b_pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@
import json
import logging
from datetime import datetime, timezone
from importlib.metadata import version as get_version
from typing import Literal, Optional

from pymongo import MongoClient
from pymongo.driver_info import DriverInfo

from burr.core import persistence, state

try:
_VERSION = get_version("apache-burr")
except Exception:
_VERSION = None

_DRIVER_INFO = DriverInfo(name="Burr", version=_VERSION)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -69,6 +78,7 @@ def from_values(
"""Initializes the MongoDBBasePersister class."""
if mongo_client_kwargs is None:
mongo_client_kwargs = {}
mongo_client_kwargs.setdefault("driver", _DRIVER_INFO)
client = MongoClient(uri, **mongo_client_kwargs)
return cls(
client=client,
Expand All @@ -92,6 +102,8 @@ def __init__(
:param serde_kwargs: serializer/deserializer keyword arguments to pass to the state object
"""
self.client = client
if hasattr(client, "append_metadata"):
client.append_metadata(_DRIVER_INFO)
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.serde_kwargs = serde_kwargs or {}
Expand Down Expand Up @@ -215,7 +227,9 @@ def __getstate__(self) -> dict:
def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume MongoClient.
self.client = MongoClient(connection_params["uri"], connection_params["port"])
self.client = MongoClient(
connection_params["uri"], connection_params["port"], driver=_DRIVER_INFO
)
self.db = self.client[connection_params["db_name"]]
self.collection = self.db[connection_params["collection_name"]]
self.__dict__.update(state)
83 changes: 83 additions & 0 deletions tests/integrations/persisters/test_b_pymongo_driver_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Unit tests for MongoDB driver handshake metadata (no live DB required)."""

from unittest.mock import MagicMock, patch

from pymongo.driver_info import DriverInfo

from burr.integrations.persisters.b_pymongo import (
_DRIVER_INFO,
_VERSION,
MongoDBBasePersister,
)


def test_driver_info_name():
assert isinstance(_DRIVER_INFO, DriverInfo)
assert _DRIVER_INFO.name == "Burr"


def test_driver_info_version_matches_package():
assert _DRIVER_INFO.version == _VERSION


def test_from_values_passes_driver_info():
"""from_values() injects driver=_DRIVER_INFO into MongoClient."""
with patch("burr.integrations.persisters.b_pymongo.MongoClient") as mock_ctor:
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=MagicMock())
mock_ctor.return_value = mock_client
MongoDBBasePersister.from_values(uri="mongodb://localhost:27017")
_, kwargs = mock_ctor.call_args
assert "driver" in kwargs
assert isinstance(kwargs["driver"], DriverInfo)
assert kwargs["driver"].name == "Burr"


def test_from_values_does_not_override_caller_driver():
"""from_values() preserves a driver value supplied via mongo_client_kwargs."""
custom = DriverInfo(name="MyApp", version="9.9")
with patch("burr.integrations.persisters.b_pymongo.MongoClient") as mock_ctor:
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=MagicMock())
mock_ctor.return_value = mock_client
MongoDBBasePersister.from_values(
uri="mongodb://localhost:27017",
mongo_client_kwargs={"driver": custom},
)
_, kwargs = mock_ctor.call_args
assert kwargs["driver"] is custom


def test_init_calls_append_metadata_when_available():
"""__init__() calls append_metadata on a client that supports it."""
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=MagicMock())
MongoDBBasePersister(client=mock_client)
mock_client.append_metadata.assert_called_once_with(_DRIVER_INFO)


def test_init_skips_append_metadata_when_absent():
"""__init__() does not raise when client lacks append_metadata (older driver)."""
mock_client = MagicMock()
mock_client.__getitem__ = MagicMock(return_value=MagicMock())
# Simulate a client without append_metadata by deleting the attribute
del mock_client.append_metadata
# Should not raise
MongoDBBasePersister(client=mock_client)
Loading