Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 41 additions & 67 deletions kafka/producer/transaction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,39 +877,29 @@ def priority(self):
class InitProducerIdHandler(TxnRequestHandler):
def __init__(self, transaction_manager, transaction_timeout_ms, is_epoch_bump=False):
super().__init__(transaction_manager)

self._is_epoch_bump = is_epoch_bump
api_version = transaction_manager._api_version
# KIP-360 / InitProducerIdRequest v3+ (Kafka 2.5+) lets us resume
# an existing producer_id by bumping its epoch rather than allocating
# a fresh one. v3+ takes producer_id + epoch fields; on broker match,
# the broker returns (same producer_id, epoch+1).
if api_version >= (2, 5):
version = 3
elif api_version >= (2, 4):
version = 2
elif api_version >= (2, 0):
version = 1
else:
version = 0
max_version = 4
min_version = 0

if is_epoch_bump:
assert version >= 3, "KIP-360 epoch bump requires Kafka 2.5+ broker"
# KIP-360 / InitProducerIdRequest v3+ (Kafka 2.5+) lets us resume
# an existing producer_id by bumping its epoch rather than allocating
# a fresh one. v3+ takes producer_id + epoch fields; on broker match,
# the broker returns (same producer_id, epoch+1).
min_version = 3
producer_id = transaction_manager.producer_id_and_epoch.producer_id
producer_epoch = transaction_manager.producer_id_and_epoch.epoch
else:
producer_id = NO_PRODUCER_ID
producer_epoch = NO_PRODUCER_EPOCH

kwargs = {
'version': version,
'transactional_id': self.transactional_id,
'transaction_timeout_ms': transaction_timeout_ms,
}
if version >= 3:
kwargs['producer_id'] = producer_id
kwargs['producer_epoch'] = producer_epoch
self.request = InitProducerIdRequest(**kwargs)
self.request = InitProducerIdRequest(
transactional_id=self.transactional_id,
transaction_timeout_ms=transaction_timeout_ms,
producer_id=producer_id,
producer_epoch=producer_epoch,
max_version=max_version,
min_version=min_version)

@property
def priority(self):
Expand Down Expand Up @@ -961,20 +951,18 @@ class AddPartitionsToTxnHandler(TxnRequestHandler):
def __init__(self, transaction_manager, topic_partitions):
super().__init__(transaction_manager)

if transaction_manager._api_version >= (2, 7):
version = 2
elif transaction_manager._api_version >= (2, 0):
version = 1
else:
version = 0
topic_data = collections.defaultdict(list)
for tp in topic_partitions:
topic_data[tp.topic].append(tp.partition)
self.request = AddPartitionsToTxnRequest[version](

Topic = AddPartitionsToTxnRequest.AddPartitionsToTxnTopic
self.request = AddPartitionsToTxnRequest(
v3_and_below_transactional_id=self.transactional_id,
v3_and_below_producer_id=self.producer_id,
v3_and_below_producer_epoch=self.producer_epoch,
v3_and_below_topics=list(topic_data.items()))
v3_and_below_topics=[Topic(name=topic, partitions=partitions)
for topic, partitions in topic_data.items()],
max_version=3)

@property
def priority(self):
Expand Down Expand Up @@ -1056,19 +1044,16 @@ def __init__(self, transaction_manager, coord_type, coord_key):

self._coord_type = coord_type
self._coord_key = coord_key
if transaction_manager._api_version >= (2, 0):
version = 2
else:
version = 1
if coord_type == 'group':
coord_type_int8 = 0
elif coord_type == 'transaction':
coord_type_int8 = 1
else:
raise ValueError("Unrecognized coordinator type: %s" % (coord_type,))
self.request = FindCoordinatorRequest[version](
self.request = FindCoordinatorRequest(
key=coord_key,
key_type=coord_type_int8,
max_version=3,
)

@property
Expand Down Expand Up @@ -1109,18 +1094,12 @@ def handle_response(self, response):
class EndTxnHandler(TxnRequestHandler):
def __init__(self, transaction_manager, committed):
super().__init__(transaction_manager)

if self.transaction_manager._api_version >= (2, 7):
version = 2
elif self.transaction_manager._api_version >= (2, 0):
version = 1
else:
version = 0
self.request = EndTxnRequest[version](
self.request = EndTxnRequest(
transactional_id=self.transactional_id,
producer_id=self.producer_id,
producer_epoch=self.producer_epoch,
committed=committed)
committed=committed,
max_version=3)

@property
def priority(self):
Expand Down Expand Up @@ -1154,17 +1133,13 @@ def __init__(self, transaction_manager, consumer_group_id, offsets):

self.consumer_group_id = consumer_group_id
self.offsets = offsets
if self.transaction_manager._api_version >= (2, 7):
version = 2
elif self.transaction_manager._api_version >= (2, 0):
version = 1
else:
version = 0
self.request = AddOffsetsToTxnRequest[version](
self.request = AddOffsetsToTxnRequest(
transactional_id=self.transactional_id,
producer_id=self.producer_id,
producer_epoch=self.producer_epoch,
group_id=consumer_group_id)
group_id=self.consumer_group_id,
max_version=3,
)

@property
def priority(self):
Expand Down Expand Up @@ -1210,27 +1185,26 @@ def __init__(self, transaction_manager, consumer_group_id, offsets, result):
self.request = self._build_request()

def _build_request(self):
if self.transaction_manager._api_version >= (2, 1):
version = 2
elif self.transaction_manager._api_version >= (2, 0):
version = 1
else:
version = 0
Topic = TxnOffsetCommitRequest.TxnOffsetCommitRequestTopic
Partition = Topic.TxnOffsetCommitRequestPartition

topic_data = collections.defaultdict(list)
for tp, offset in self.offsets.items():
if version >= 2:
partition_data = (tp.partition, offset.offset, offset.leader_epoch, offset.metadata)
else:
partition_data = (tp.partition, offset.offset, offset.metadata)
topic_data[tp.topic].append(partition_data)
topic_data[tp.topic].append(Partition(
partition_index=tp.partition,
committed_offset=offset.offset,
committed_leader_epoch=offset.leader_epoch,
committed_metadata=offset.metadata))

return TxnOffsetCommitRequest[version](
return TxnOffsetCommitRequest(
transactional_id=self.transactional_id,
group_id=self.consumer_group_id,
producer_id=self.producer_id,
producer_epoch=self.producer_epoch,
topics=list(topic_data.items()))
topics=[Topic(name=topic, partitions=partitions)
for topic, partitions in topic_data.items()],
max_version=2,
)

@property
def priority(self):
Expand Down
16 changes: 0 additions & 16 deletions test/producer/test_transaction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,6 @@ def _fake_init_producer_id_response(self, error_code=0, producer_id=42, producer
producer_epoch=producer_epoch,
)

def test_v3_version_on_modern_broker(self):
_, handler = self._make_handler(api_version=(2, 5))
assert handler.request.version == 3

def test_v2_version_on_2_4_broker(self):
_, handler = self._make_handler(api_version=(2, 4))
assert handler.request.version == 2

def test_v1_version_on_2_0_broker(self):
_, handler = self._make_handler(api_version=(2, 0))
assert handler.request.version == 1

def test_v0_version_on_old_broker(self):
_, handler = self._make_handler(api_version=(0, 11))
assert handler.request.version == 0

def test_non_bump_request_has_no_producer_id(self):
_, handler = self._make_handler(api_version=(2, 5), is_epoch_bump=False)
# v3 request fields default to NO_PRODUCER_ID / NO_PRODUCER_EPOCH
Expand Down
91 changes: 89 additions & 2 deletions test/producer/test_transaction_manager_mock_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
AddOffsetsToTxnResponse,
AddPartitionsToTxnResponse,
EndTxnResponse,
InitProducerIdRequest,
InitProducerIdResponse,
ProduceRequest,
ProduceResponse,
Expand Down Expand Up @@ -123,8 +124,12 @@ def _pending_handlers(tm):


@pytest.fixture
def broker():
return MockBroker()
def broker(request):
"""Parametrizable broker version: ``@pytest.mark.parametrize("broker",
[(2, 4)], indirect=True)`` simulates an older broker so a handler's
request negotiates a lower wire version."""
broker_version = getattr(request, 'param', (4, 2))
return MockBroker(broker_version=broker_version)


@pytest.fixture
Expand Down Expand Up @@ -351,6 +356,88 @@ def test_idempotent_producer_init_then_bump(self, broker, client):
assert tm.producer_id_and_epoch.epoch == _PRODUCER_EPOCH + 1


# ---------------------------------------------------------------------------
# Wire-version negotiation for InitProducerIdRequest
# ---------------------------------------------------------------------------


class TestInitProducerIdHandlerWireVersion:
"""InitProducerIdHandler uses the modern construction style with
``min_version`` / ``max_version``; the wire version is picked by
the per-connection ``broker_version_data`` at send time, not by an
explicit version selector in the handler. These tests drive the request
over MockBroker and assert the captured ``api_version`` matches what we
expect for each broker generation."""

def _capture(self, captured, response):
"""Build a respond_fn callable that records the negotiated version."""
def fn(api_key, api_version, correlation_id, request_bytes):
captured['api_version'] = api_version
return response
return fn

@pytest.mark.parametrize("broker, expected_version", [
((0, 11), 0),
((2, 0), 1),
((2, 4), 2),
((2, 5), 3),
((4, 2), 4),
], indirect=['broker'])
def test_non_bump_negotiates_to_broker_max(self, broker, client,
expected_version):
from kafka.producer.transaction_manager import InitProducerIdHandler
tm = _make_manager(client)
tm._current_state = TransactionState.INITIALIZING
tm.set_producer_id_and_epoch(ProducerIdAndEpoch(-1, -1))
handler = InitProducerIdHandler(tm, transaction_timeout_ms=1000)
tm._enqueue_request(handler)

captured = {}
broker.respond_fn(InitProducerIdResponse, self._capture(
captured,
InitProducerIdResponse(throttle_time_ms=0, error_code=0,
producer_id=42, producer_epoch=0),
))

_, future = _dispatch_next(client, tm)
_poll_for_future(client, future)

assert captured['api_version'] == expected_version

@pytest.mark.parametrize("broker, expected_version", [
((2, 5), 3),
((2, 7), 4),
((4, 2), 4),
], indirect=['broker'])
def test_epoch_bump_negotiates_to_v3_or_higher(self, broker, client,
expected_version):
"""KIP-360 epoch-bump path sets ``min_version=3``; older brokers
(<2.5) lack v3 of InitProducerId and the request fails before reach-
ing the wire — covered separately in
``TestBumpProducerIdAndEpoch._supports_epoch_bump`` in
test_transaction_manager.py."""
from kafka.producer.transaction_manager import InitProducerIdHandler
tm = _make_manager(client)
# Seed the manager with a producer_id/epoch for the bump to capture.
tm._current_state = TransactionState.READY
tm.set_producer_id_and_epoch(ProducerIdAndEpoch(42, 7))
handler = InitProducerIdHandler(
tm, transaction_timeout_ms=1000, is_epoch_bump=True)
tm._enqueue_request(handler)

captured = {}
broker.respond_fn(InitProducerIdResponse, self._capture(
captured,
InitProducerIdResponse(throttle_time_ms=0, error_code=0,
producer_id=42, producer_epoch=8),
))

_, future = _dispatch_next(client, tm)
_poll_for_future(client, future)

assert captured['api_version'] == expected_version


# ---------------------------------------------------------------------------
# AddPartitionsToTxnHandler
# ---------------------------------------------------------------------------
Expand Down
Loading