From 45fd8a8f6a87f96968e897514400d8d61f981a7d Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 20 May 2026 09:49:58 -0700 Subject: [PATCH] Producer: Use new-style request construction in Txn Manager --- kafka/producer/transaction_manager.py | 108 +++++++----------- test/producer/test_transaction_manager.py | 16 --- .../test_transaction_manager_mock_broker.py | 91 ++++++++++++++- 3 files changed, 130 insertions(+), 85 deletions(-) diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index 2625d2057..d2cbe7956 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -877,39 +877,29 @@ def priority(self): class InitProducerIdHandler(TxnRequestHandler): def __init__(self, transaction_manager, transaction_timeout_ms, is_epoch_bump=False): super().__init__(transaction_manager) - self._is_epoch_bump = is_epoch_bump - api_version = transaction_manager._api_version - # KIP-360 / InitProducerIdRequest v3+ (Kafka 2.5+) lets us resume - # an existing producer_id by bumping its epoch rather than allocating - # a fresh one. v3+ takes producer_id + epoch fields; on broker match, - # the broker returns (same producer_id, epoch+1). - if api_version >= (2, 5): - version = 3 - elif api_version >= (2, 4): - version = 2 - elif api_version >= (2, 0): - version = 1 - else: - version = 0 + max_version = 4 + min_version = 0 if is_epoch_bump: - assert version >= 3, "KIP-360 epoch bump requires Kafka 2.5+ broker" + # KIP-360 / InitProducerIdRequest v3+ (Kafka 2.5+) lets us resume + # an existing producer_id by bumping its epoch rather than allocating + # a fresh one. v3+ takes producer_id + epoch fields; on broker match, + # the broker returns (same producer_id, epoch+1). + min_version = 3 producer_id = transaction_manager.producer_id_and_epoch.producer_id producer_epoch = transaction_manager.producer_id_and_epoch.epoch else: producer_id = NO_PRODUCER_ID producer_epoch = NO_PRODUCER_EPOCH - kwargs = { - 'version': version, - 'transactional_id': self.transactional_id, - 'transaction_timeout_ms': transaction_timeout_ms, - } - if version >= 3: - kwargs['producer_id'] = producer_id - kwargs['producer_epoch'] = producer_epoch - self.request = InitProducerIdRequest(**kwargs) + self.request = InitProducerIdRequest( + transactional_id=self.transactional_id, + transaction_timeout_ms=transaction_timeout_ms, + producer_id=producer_id, + producer_epoch=producer_epoch, + max_version=max_version, + min_version=min_version) @property def priority(self): @@ -961,20 +951,18 @@ class AddPartitionsToTxnHandler(TxnRequestHandler): def __init__(self, transaction_manager, topic_partitions): super().__init__(transaction_manager) - if transaction_manager._api_version >= (2, 7): - version = 2 - elif transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 topic_data = collections.defaultdict(list) for tp in topic_partitions: topic_data[tp.topic].append(tp.partition) - self.request = AddPartitionsToTxnRequest[version]( + + Topic = AddPartitionsToTxnRequest.AddPartitionsToTxnTopic + self.request = AddPartitionsToTxnRequest( v3_and_below_transactional_id=self.transactional_id, v3_and_below_producer_id=self.producer_id, v3_and_below_producer_epoch=self.producer_epoch, - v3_and_below_topics=list(topic_data.items())) + v3_and_below_topics=[Topic(name=topic, partitions=partitions) + for topic, partitions in topic_data.items()], + max_version=3) @property def priority(self): @@ -1056,19 +1044,16 @@ def __init__(self, transaction_manager, coord_type, coord_key): self._coord_type = coord_type self._coord_key = coord_key - if transaction_manager._api_version >= (2, 0): - version = 2 - else: - version = 1 if coord_type == 'group': coord_type_int8 = 0 elif coord_type == 'transaction': coord_type_int8 = 1 else: raise ValueError("Unrecognized coordinator type: %s" % (coord_type,)) - self.request = FindCoordinatorRequest[version]( + self.request = FindCoordinatorRequest( key=coord_key, key_type=coord_type_int8, + max_version=3, ) @property @@ -1109,18 +1094,12 @@ def handle_response(self, response): class EndTxnHandler(TxnRequestHandler): def __init__(self, transaction_manager, committed): super().__init__(transaction_manager) - - if self.transaction_manager._api_version >= (2, 7): - version = 2 - elif self.transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - self.request = EndTxnRequest[version]( + self.request = EndTxnRequest( transactional_id=self.transactional_id, producer_id=self.producer_id, producer_epoch=self.producer_epoch, - committed=committed) + committed=committed, + max_version=3) @property def priority(self): @@ -1154,17 +1133,13 @@ def __init__(self, transaction_manager, consumer_group_id, offsets): self.consumer_group_id = consumer_group_id self.offsets = offsets - if self.transaction_manager._api_version >= (2, 7): - version = 2 - elif self.transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - self.request = AddOffsetsToTxnRequest[version]( + self.request = AddOffsetsToTxnRequest( transactional_id=self.transactional_id, producer_id=self.producer_id, producer_epoch=self.producer_epoch, - group_id=consumer_group_id) + group_id=self.consumer_group_id, + max_version=3, + ) @property def priority(self): @@ -1210,27 +1185,26 @@ def __init__(self, transaction_manager, consumer_group_id, offsets, result): self.request = self._build_request() def _build_request(self): - if self.transaction_manager._api_version >= (2, 1): - version = 2 - elif self.transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 + Topic = TxnOffsetCommitRequest.TxnOffsetCommitRequestTopic + Partition = Topic.TxnOffsetCommitRequestPartition topic_data = collections.defaultdict(list) for tp, offset in self.offsets.items(): - if version >= 2: - partition_data = (tp.partition, offset.offset, offset.leader_epoch, offset.metadata) - else: - partition_data = (tp.partition, offset.offset, offset.metadata) - topic_data[tp.topic].append(partition_data) + topic_data[tp.topic].append(Partition( + partition_index=tp.partition, + committed_offset=offset.offset, + committed_leader_epoch=offset.leader_epoch, + committed_metadata=offset.metadata)) - return TxnOffsetCommitRequest[version]( + return TxnOffsetCommitRequest( transactional_id=self.transactional_id, group_id=self.consumer_group_id, producer_id=self.producer_id, producer_epoch=self.producer_epoch, - topics=list(topic_data.items())) + topics=[Topic(name=topic, partitions=partitions) + for topic, partitions in topic_data.items()], + max_version=2, + ) @property def priority(self): diff --git a/test/producer/test_transaction_manager.py b/test/producer/test_transaction_manager.py index 98f0f2253..25ddc9bcc 100644 --- a/test/producer/test_transaction_manager.py +++ b/test/producer/test_transaction_manager.py @@ -330,22 +330,6 @@ def _fake_init_producer_id_response(self, error_code=0, producer_id=42, producer producer_epoch=producer_epoch, ) - def test_v3_version_on_modern_broker(self): - _, handler = self._make_handler(api_version=(2, 5)) - assert handler.request.version == 3 - - def test_v2_version_on_2_4_broker(self): - _, handler = self._make_handler(api_version=(2, 4)) - assert handler.request.version == 2 - - def test_v1_version_on_2_0_broker(self): - _, handler = self._make_handler(api_version=(2, 0)) - assert handler.request.version == 1 - - def test_v0_version_on_old_broker(self): - _, handler = self._make_handler(api_version=(0, 11)) - assert handler.request.version == 0 - def test_non_bump_request_has_no_producer_id(self): _, handler = self._make_handler(api_version=(2, 5), is_epoch_bump=False) # v3 request fields default to NO_PRODUCER_ID / NO_PRODUCER_EPOCH diff --git a/test/producer/test_transaction_manager_mock_broker.py b/test/producer/test_transaction_manager_mock_broker.py index b01391dbb..df95f334b 100644 --- a/test/producer/test_transaction_manager_mock_broker.py +++ b/test/producer/test_transaction_manager_mock_broker.py @@ -32,6 +32,7 @@ AddOffsetsToTxnResponse, AddPartitionsToTxnResponse, EndTxnResponse, + InitProducerIdRequest, InitProducerIdResponse, ProduceRequest, ProduceResponse, @@ -123,8 +124,12 @@ def _pending_handlers(tm): @pytest.fixture -def broker(): - return MockBroker() +def broker(request): + """Parametrizable broker version: ``@pytest.mark.parametrize("broker", + [(2, 4)], indirect=True)`` simulates an older broker so a handler's + request negotiates a lower wire version.""" + broker_version = getattr(request, 'param', (4, 2)) + return MockBroker(broker_version=broker_version) @pytest.fixture @@ -351,6 +356,88 @@ def test_idempotent_producer_init_then_bump(self, broker, client): assert tm.producer_id_and_epoch.epoch == _PRODUCER_EPOCH + 1 +# --------------------------------------------------------------------------- +# Wire-version negotiation for InitProducerIdRequest +# --------------------------------------------------------------------------- + + +class TestInitProducerIdHandlerWireVersion: + """InitProducerIdHandler uses the modern construction style with + ``min_version`` / ``max_version``; the wire version is picked by + the per-connection ``broker_version_data`` at send time, not by an + explicit version selector in the handler. These tests drive the request + over MockBroker and assert the captured ``api_version`` matches what we + expect for each broker generation.""" + + def _capture(self, captured, response): + """Build a respond_fn callable that records the negotiated version.""" + def fn(api_key, api_version, correlation_id, request_bytes): + captured['api_version'] = api_version + return response + return fn + + @pytest.mark.parametrize("broker, expected_version", [ + ((0, 11), 0), + ((2, 0), 1), + ((2, 4), 2), + ((2, 5), 3), + ((4, 2), 4), + ], indirect=['broker']) + def test_non_bump_negotiates_to_broker_max(self, broker, client, + expected_version): + from kafka.producer.transaction_manager import InitProducerIdHandler + tm = _make_manager(client) + tm._current_state = TransactionState.INITIALIZING + tm.set_producer_id_and_epoch(ProducerIdAndEpoch(-1, -1)) + handler = InitProducerIdHandler(tm, transaction_timeout_ms=1000) + tm._enqueue_request(handler) + + captured = {} + broker.respond_fn(InitProducerIdResponse, self._capture( + captured, + InitProducerIdResponse(throttle_time_ms=0, error_code=0, + producer_id=42, producer_epoch=0), + )) + + _, future = _dispatch_next(client, tm) + _poll_for_future(client, future) + + assert captured['api_version'] == expected_version + + @pytest.mark.parametrize("broker, expected_version", [ + ((2, 5), 3), + ((2, 7), 4), + ((4, 2), 4), + ], indirect=['broker']) + def test_epoch_bump_negotiates_to_v3_or_higher(self, broker, client, + expected_version): + """KIP-360 epoch-bump path sets ``min_version=3``; older brokers + (<2.5) lack v3 of InitProducerId and the request fails before reach- + ing the wire — covered separately in + ``TestBumpProducerIdAndEpoch._supports_epoch_bump`` in + test_transaction_manager.py.""" + from kafka.producer.transaction_manager import InitProducerIdHandler + tm = _make_manager(client) + # Seed the manager with a producer_id/epoch for the bump to capture. + tm._current_state = TransactionState.READY + tm.set_producer_id_and_epoch(ProducerIdAndEpoch(42, 7)) + handler = InitProducerIdHandler( + tm, transaction_timeout_ms=1000, is_epoch_bump=True) + tm._enqueue_request(handler) + + captured = {} + broker.respond_fn(InitProducerIdResponse, self._capture( + captured, + InitProducerIdResponse(throttle_time_ms=0, error_code=0, + producer_id=42, producer_epoch=8), + )) + + _, future = _dispatch_next(client, tm) + _poll_for_future(client, future) + + assert captured['api_version'] == expected_version + + # --------------------------------------------------------------------------- # AddPartitionsToTxnHandler # ---------------------------------------------------------------------------