Skip to content

Commit 69d60bb

Browse files
Merge pull request #1485 from phenobarbital/codex/add-async-driver-for-redpanda
feat(drivers): add async Redpanda driver using aiokafka
2 parents 228dbba + 8ab795e commit 69d60bb

3 files changed

Lines changed: 301 additions & 0 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Currently AsyncDB supports the following databases:
5656
* MongoDB (using motor and pymongo)
5757
* SQLAlchemy (requires sqlalchemy async (+3.14))
5858
* Oracle (requires oracledb)
59+
* Redpanda (Kafka-compatible, requires aiokafka)
5960

6061
### Quick Tutorial ###
6162

@@ -89,6 +90,7 @@ Every Driver has a simple name to call it:
8990
* mcache: aiomcache (Memcache)
9091
* odbc: aiodbc (ODBC)
9192
* oracle: oracle (oracledb)
93+
* redpanda: Redpanda/Kafka (aiokafka)
9294

9395
### Output Support ###
9496

asyncdb/drivers/redpanda.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""Redpanda async driver built on top of aiokafka."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from typing import Any, Optional, Union
7+
8+
from ..exceptions import DriverError
9+
from .base import BaseDriver
10+
11+
12+
class redpanda(BaseDriver):
13+
"""Async Redpanda driver.
14+
15+
Redpanda is Kafka API compatible, so this driver uses aiokafka under the hood.
16+
"""
17+
18+
_provider = "redpanda"
19+
_syntax = "kafka"
20+
_dsn_template: str = "{host}:{port}"
21+
22+
def __init__(self, dsn: str = None, loop=None, params: Optional[dict] = None, **kwargs):
23+
params = params or {}
24+
self._topic = kwargs.pop("topic", params.get("topic"))
25+
self._group_id = kwargs.pop("group_id", params.get("group_id", "asyncdb-redpanda"))
26+
self._client_id = kwargs.pop("client_id", params.get("client_id", "asyncdb"))
27+
self._producer = None
28+
self._admin = None
29+
super(redpanda, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs)
30+
31+
@staticmethod
32+
def _load_aiokafka():
33+
try:
34+
from aiokafka import AIOKafkaAdminClient, AIOKafkaConsumer, AIOKafkaProducer
35+
except ImportError as err:
36+
raise DriverError(
37+
"aiokafka is required for Redpanda support. Install with: pip install asyncdb[redpanda]"
38+
) from err
39+
return AIOKafkaProducer, AIOKafkaConsumer, AIOKafkaAdminClient
40+
41+
def _bootstrap_servers(self) -> str:
42+
if self._dsn:
43+
return self._dsn
44+
host = self._params.get("host", "127.0.0.1")
45+
port = self._params.get("port", 9092)
46+
return f"{host}:{port}"
47+
48+
async def connection(self, **kwargs):
49+
AIOKafkaProducer, _, _ = self._load_aiokafka()
50+
servers = kwargs.pop("bootstrap_servers", self._bootstrap_servers())
51+
try:
52+
self._producer = AIOKafkaProducer(
53+
loop=self._loop,
54+
bootstrap_servers=servers,
55+
client_id=self._client_id,
56+
**kwargs,
57+
)
58+
await self._producer.start()
59+
self._connection = self._producer
60+
self._connected = True
61+
return self
62+
except Exception as err:
63+
self._connected = False
64+
raise DriverError(f"Redpanda connection error: {err}") from err
65+
66+
async def close(self):
67+
try:
68+
if self._producer is not None:
69+
await self._producer.stop()
70+
if self._admin is not None:
71+
await self._admin.close()
72+
self._connected = False
73+
self._connection = None
74+
return True
75+
except Exception as err:
76+
raise DriverError(f"Error closing Redpanda connection: {err}") from err
77+
78+
async def use(self, database: str):
79+
self._topic = database
80+
81+
async def prepare(self, sentence: Union[str, list]) -> Any:
82+
self._prepared = sentence
83+
return sentence
84+
85+
def _encode_value(self, value: Any) -> bytes:
86+
if value is None:
87+
raise DriverError("Cannot publish an empty message")
88+
if isinstance(value, bytes):
89+
return value
90+
if isinstance(value, str):
91+
return value.encode("utf-8")
92+
return json.dumps(value).encode("utf-8")
93+
94+
async def execute(self, sentence: Any, *args, **kwargs) -> Optional[Any]:
95+
if not self._producer:
96+
await self.connection()
97+
topic = kwargs.pop("topic", self._topic)
98+
if not topic:
99+
raise DriverError("No topic selected. Use use(<topic>) or pass topic=<topic>.")
100+
101+
key = kwargs.pop("key", None)
102+
partition = kwargs.pop("partition", None)
103+
headers = kwargs.pop("headers", None)
104+
timestamp_ms = kwargs.pop("timestamp_ms", None)
105+
106+
payload = self._encode_value(sentence)
107+
encoded_key = self._encode_value(key) if key is not None else None
108+
try:
109+
metadata = await self._producer.send_and_wait(
110+
topic,
111+
payload,
112+
key=encoded_key,
113+
partition=partition,
114+
headers=headers,
115+
timestamp_ms=timestamp_ms,
116+
)
117+
result = {
118+
"topic": metadata.topic,
119+
"partition": metadata.partition,
120+
"offset": metadata.offset,
121+
"timestamp": metadata.timestamp,
122+
}
123+
return await self._serializer(result, None)
124+
except Exception as err:
125+
raise DriverError(f"Error publishing message to Redpanda: {err}") from err
126+
127+
async def execute_many(self, sentence: list, *args) -> Optional[Any]:
128+
results = []
129+
for message in sentence:
130+
response, error = await self.execute(message, *args)
131+
if error:
132+
return await self._serializer(results, error)
133+
results.append(response)
134+
return await self._serializer(results, None)
135+
136+
async def query(self, sentence: Union[str, list], **kwargs):
137+
_, AIOKafkaConsumer, _ = self._load_aiokafka()
138+
topic = sentence if isinstance(sentence, str) and sentence else kwargs.pop("topic", self._topic)
139+
if not topic:
140+
raise DriverError("No topic selected. Use use(<topic>) or pass topic=<topic>.")
141+
142+
timeout_ms = kwargs.pop("timeout_ms", 1000)
143+
max_records = kwargs.pop("max_records", 100)
144+
consumer = AIOKafkaConsumer(
145+
topic,
146+
loop=self._loop,
147+
bootstrap_servers=kwargs.pop("bootstrap_servers", self._bootstrap_servers()),
148+
group_id=kwargs.pop("group_id", self._group_id),
149+
client_id=self._client_id,
150+
auto_offset_reset=kwargs.pop("auto_offset_reset", "earliest"),
151+
enable_auto_commit=kwargs.pop("enable_auto_commit", False),
152+
**kwargs,
153+
)
154+
155+
try:
156+
await consumer.start()
157+
records = await consumer.getmany(timeout_ms=timeout_ms, max_records=max_records)
158+
messages = []
159+
for _, batch in records.items():
160+
for msg in batch:
161+
messages.append(
162+
{
163+
"topic": msg.topic,
164+
"partition": msg.partition,
165+
"offset": msg.offset,
166+
"timestamp": msg.timestamp,
167+
"key": msg.key.decode("utf-8") if msg.key else None,
168+
"value": msg.value.decode("utf-8") if msg.value else None,
169+
}
170+
)
171+
return await self._serializer(messages, None)
172+
except Exception as err:
173+
raise DriverError(f"Error consuming messages from Redpanda: {err}") from err
174+
finally:
175+
await consumer.stop()
176+
177+
fetch_all = query
178+
179+
async def queryrow(self, sentence: Union[str, list]):
180+
messages, error = await self.query(sentence, max_records=1)
181+
if error:
182+
return await self._serializer(None, error)
183+
row = messages[0] if messages else None
184+
return await self._serializer(row, None)
185+
186+
fetch_one = queryrow

tests/test_redpanda.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import pytest
2+
3+
from asyncdb import AsyncDB
4+
5+
6+
pytestmark = pytest.mark.asyncio
7+
8+
9+
class DummyMetadata:
10+
topic = "events"
11+
partition = 0
12+
offset = 7
13+
timestamp = 123456
14+
15+
16+
class DummyMessage:
17+
def __init__(self):
18+
self.topic = "events"
19+
self.partition = 0
20+
self.offset = 1
21+
self.timestamp = 777
22+
self.key = b"user-1"
23+
self.value = b'{"event":"created"}'
24+
25+
26+
class DummyProducer:
27+
def __init__(self, *args, **kwargs):
28+
self.started = False
29+
self.stopped = False
30+
self.kwargs = kwargs
31+
32+
async def start(self):
33+
self.started = True
34+
35+
async def stop(self):
36+
self.stopped = True
37+
38+
async def send_and_wait(self, *args, **kwargs):
39+
return DummyMetadata()
40+
41+
42+
class DummyConsumer:
43+
def __init__(self, *args, **kwargs):
44+
self.started = False
45+
self.stopped = False
46+
47+
async def start(self):
48+
self.started = True
49+
50+
async def stop(self):
51+
self.stopped = True
52+
53+
async def getmany(self, timeout_ms=0, max_records=1):
54+
return {"tp": [DummyMessage()]}
55+
56+
57+
class DummyAdmin:
58+
async def close(self):
59+
return True
60+
61+
62+
@pytest.fixture
63+
def patch_aiokafka(monkeypatch):
64+
from asyncdb.drivers.redpanda import redpanda
65+
66+
monkeypatch.setattr(
67+
redpanda,
68+
"_load_aiokafka",
69+
staticmethod(lambda: (DummyProducer, DummyConsumer, DummyAdmin)),
70+
)
71+
72+
73+
async def test_redpanda_connection_and_publish(patch_aiokafka):
74+
db = AsyncDB("redpanda", params={"host": "127.0.0.1", "port": 9092, "topic": "events"})
75+
await db.connection()
76+
assert db.is_connected() is True
77+
78+
result, error = await db.execute({"event": "created"})
79+
assert error is None
80+
assert result["topic"] == "events"
81+
assert result["offset"] == 7
82+
83+
await db.close()
84+
85+
86+
async def test_redpanda_query_and_queryrow(patch_aiokafka):
87+
db = AsyncDB("redpanda", params={"host": "127.0.0.1", "port": 9092, "topic": "events"})
88+
89+
records, error = await db.query("events")
90+
assert error is None
91+
assert len(records) == 1
92+
assert records[0]["key"] == "user-1"
93+
94+
row, error = await db.queryrow("events")
95+
assert error is None
96+
assert row["value"] == '{"event":"created"}'
97+
98+
99+
async def test_redpanda_execute_many(patch_aiokafka):
100+
db = AsyncDB("redpanda", params={"host": "127.0.0.1", "port": 9092, "topic": "events"})
101+
await db.connection()
102+
103+
messages = [
104+
{"event": "created", "id": 1},
105+
{"event": "updated", "id": 2},
106+
]
107+
108+
results, error = await db.execute_many(messages)
109+
assert error is None
110+
assert isinstance(results, list)
111+
assert len(results) == len(messages)
112+
113+
await db.close()

0 commit comments

Comments
 (0)