diff --git a/kafka/net/connection.py b/kafka/net/connection.py index 84aa37468..838a6ffb2 100644 --- a/kafka/net/connection.py +++ b/kafka/net/connection.py @@ -394,8 +394,9 @@ def sasl_enabled(self): async def _sasl_authenticate(self): # Step 1: SaslHandshake to negotiate mechanism - version = self.broker_version_data.api_version(SaslHandshakeRequest, max_version=1) - request = SaslHandshakeRequest[version](self.config['sasl_mechanism']) + request = SaslHandshakeRequest( + mechanism=self.config['sasl_mechanism'], + max_version=1) try: response = await self._send_request(request) except Exception as exc: @@ -415,6 +416,7 @@ async def _sasl_authenticate(self): return # Step 2: SASL authentication exchange + version = response.API_VERSION try: mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])( host=self.transport.getPeer()[0], **self.config) @@ -425,7 +427,7 @@ async def _sasl_authenticate(self): while not mechanism.is_done(): token = mechanism.auth_bytes() if version == 1: - auth_request = SaslAuthenticateRequest[0](token) + auth_request = SaslAuthenticateRequest(token, version=0) else: auth_request = SaslBytesRequest(token) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 1eed97cf9..e13682dfa 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -668,34 +668,32 @@ def _produce_request(self, node_id, acks, timeout, batches): Returns: ProduceRequest (version depends on client api_versions) """ - produce_records_by_partition = collections.defaultdict(dict) + max_version = 9 + min_version = 0 + Topic = ProduceRequest.TopicProduceData + Partition = Topic.PartitionProduceData + topic_data = collections.defaultdict(list) for batch in batches: topic = batch.topic_partition.topic - partition = batch.topic_partition.partition - - buf = batch.records.buffer() - produce_records_by_partition[topic][partition] = buf + partition = Partition( + index=batch.topic_partition.partition, + records=batch.records.buffer(), + ) + topic_data[topic].append(partition) - version = self._client.api_version(ProduceRequest, max_version=8) - topic_partition_data = [ - (topic, list(partition_info.items())) - for topic, partition_info in produce_records_by_partition.items()] transactional_id = self._transaction_manager.transactional_id if self._transaction_manager else None - if version >= 3: - return ProduceRequest[version]( - transactional_id=transactional_id, - acks=acks, - timeout_ms=timeout, - topic_data=topic_partition_data, - ) - else: - if transactional_id is not None: - log.warning('%s: Broker does not support ProduceRequest v3+, required for transactional_id', str(self)) - return ProduceRequest[version]( - acks=acks, - timeout_ms=timeout, - topic_data=topic_partition_data, - ) + if transactional_id is not None: + min_version = 3 + + return ProduceRequest( + transactional_id=transactional_id, + acks=acks, + timeout_ms=timeout, + topic_data=[Topic(name=topic, partition_data=partitions) + for topic, partitions in topic_data.items()], + min_version=min_version, + max_version=max_version, + ) def wakeup(self): """Wake up the selector associated with this send thread.""" diff --git a/test/producer/test_sender.py b/test/producer/test_sender.py index 13e01af07..137e1da46 100644 --- a/test/producer/test_sender.py +++ b/test/producer/test_sender.py @@ -10,7 +10,6 @@ from kafka.cluster import ClusterMetadata import kafka.errors as Errors -from kafka.protocol.broker_version_data import BrokerVersionData from kafka.producer.kafka import KafkaProducer from kafka.protocol.producer import ProduceRequest from kafka.producer.future import FutureRecordMetadata @@ -95,37 +94,89 @@ def transaction_manager(cluster): metadata=cluster) -@pytest.mark.parametrize(("api_version", "produce_version"), [ +def _capture(captured): + """Build a respond_fn that records the negotiated wire api_version.""" + def fn(api_key, api_version, correlation_id, request_bytes): + captured['api_version'] = api_version + # acks=1 expects a response; an empty topic list is valid at every version. + return ProduceResponse(throttle_time_ms=0, responses=[]) + return fn + + +@pytest.mark.parametrize("broker, produce_version", [ ((2, 1), 7), ((0, 10, 0), 2), ((0, 9), 1), - ((0, 8, 0), 0) -]) -def test_produce_request(sender, api_version, produce_version): - sender._client._manager.broker_version_data = BrokerVersionData(api_version) - magic = KafkaProducer.max_usable_produce_magic(api_version) + ((0, 8, 0), 0), +], indirect=['broker']) +def test_produce_request_negotiates_wire_version(sender, broker, manager, produce_version): + """``Sender._produce_request`` returns a ProduceRequest with no fixed + version; the connection negotiates the wire version against the broker's + api_versions table at send time. We verify by capturing the api_version + that arrives at the broker.""" + # Bootstrap so cluster metadata knows about the MockBroker node. + manager.bootstrap(timeout_ms=5000) + + magic = KafkaProducer.max_usable_produce_magic(broker.broker_version) batch = producer_batch(magic=magic) - produce_request = sender._produce_request(0, 0, 0, [batch]) + produce_request = sender._produce_request(0, 1, 0, [batch]) # acks=1 assert isinstance(produce_request, ProduceRequest) - assert produce_request.version == produce_version + # Version is not pinned at construction — that's the whole point. + assert produce_request.API_VERSION is None + + captured = {} + broker.respond_fn(ProduceResponse, _capture(captured)) + future = manager.send(produce_request, node_id=0) + manager.run(manager.wait_for, future, 5000) -@pytest.mark.parametrize(("api_version", "produce_version"), [ + assert captured['api_version'] == produce_version + + +@pytest.mark.parametrize("broker, produce_version", [ ((2, 1), 7), -]) -def test_create_produce_requests(sender, api_version, produce_version): - sender._client._manager.broker_version_data = BrokerVersionData(api_version) - tp = TopicPartition('foo', 0) - magic = KafkaProducer.max_usable_produce_magic(api_version) + ((0, 10, 0), 2), + ((0, 9), 1), + ((0, 8, 0), 0), +], indirect=['broker']) +def test_create_produce_requests_negotiates_wire_version( + sender, broker, manager, produce_version): + """``_create_produce_requests`` builds one ProduceRequest per node; + each one negotiates independently against its broker's api_versions + table. We send each through the MockBroker (all routed to the single + MockBroker node via shared metadata) and assert each arrived at the + expected wire version.""" + # Advertise three broker entries (all pointing at this single MockBroker) + # so ``manager.send(..., node_id=n)`` resolves for nodes 1 and 2 as well. + # Must happen *before* bootstrap so the metadata response carries them. + from kafka.protocol.metadata import MetadataResponse + Broker = MetadataResponse.MetadataResponseBroker + broker.set_metadata(brokers=[ + Broker(node_id=n, host=broker.host, port=broker.port, rack=None) + for n in range(3) + ]) + manager.bootstrap(timeout_ms=5000) + + magic = KafkaProducer.max_usable_produce_magic(broker.broker_version) batches_by_node = collections.defaultdict(list) for node in range(3): for _ in range(5): batches_by_node[node].append(producer_batch(magic=magic)) produce_requests_by_node = sender._create_produce_requests(batches_by_node) assert len(produce_requests_by_node) == 3 + for node in range(3): - assert isinstance(produce_requests_by_node[node], ProduceRequest) - assert produce_requests_by_node[node].version == produce_version + request = produce_requests_by_node[node] + assert isinstance(request, ProduceRequest) + assert request.API_VERSION is None + + captured = {} + broker.respond_fn(ProduceResponse, _capture(captured)) + future = manager.send(request, node_id=node) + manager.run(manager.wait_for, future, 5000) + assert captured['api_version'] == produce_version, ( + 'node %d: expected v%d got v%s' + % (node, produce_version, captured.get('api_version'))) def test_complete_batch_success(sender): diff --git a/test/test_mock_broker.py b/test/test_mock_broker.py index 8c38de302..3e8a4469c 100644 --- a/test/test_mock_broker.py +++ b/test/test_mock_broker.py @@ -226,8 +226,7 @@ def test_send_and_receive(self): client.await_ready(node_id, timeout_ms=5000) # Send a MetadataRequest directly via client.send - version = client.api_version(MetadataRequest, max_version=8) - request = MetadataRequest[version]() + request = MetadataRequest(max_version=9) future = client.send(node_id, request) _poll_for_future(client, future) @@ -251,8 +250,7 @@ def test_fail_next_aborts_request(self): error = Errors.KafkaConnectionError('simulated') broker.fail_next(MetadataRequest, error=error) - version = client.api_version(MetadataRequest, max_version=8) - future = client.send(node_id, MetadataRequest[version]()) + future = client.send(node_id, MetadataRequest(max_version=9)) _poll_for_future(client, future) assert future.failed()