[Optimization] Improve MTP module test coverage #5729
base: develop
Conversation
…arrival_time, add missing keys
Thanks for your contribution!
Pull request overview
This PR adds comprehensive unit test coverage for the MTPProposer class in the speculative decoding module. The tests use mocking to isolate the MTP (Multi-Token-Prediction) proposer's behavior and validate various code paths without requiring GPU hardware or compiled operators.
- Adds 701 lines of test code covering 11 distinct test methods for the MTPProposer class
- Introduces tests for initialization, configuration, cache management, task insertion, forward meta, input preparation, and CUDA execution paths
- Adds a logprobs_mode configuration field to the FakeModelConfig test utility
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tests/spec_decode/test_mtp_proposer.py | New comprehensive test suite for MTPProposer covering initialization, configuration methods, cache operations, task handling, and execution paths with extensive mocking |
| tests/utils.py | Adds logprobs_mode field to FakeModelConfig for test support |
| """ | ||
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ |
Copilot AI (Dec 23, 2025):
The copyright header should use regular comments (# prefix) rather than a module docstring (triple quotes). This is inconsistent with the convention used elsewhere in the codebase, such as in tests/conftest.py. Module docstrings should be reserved for actual module documentation, not licensing information.
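For comparison, the plain-comment convention referred to here would keep only the #-prefixed lines and drop the surrounding triple quotes, e.g. (abridged):

```python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# ...
```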
```
@@ -0,0 +1,701 @@
"""
```
Copilot AI (Dec 23, 2025):
The PR description lacks sufficient detail. According to the custom coding guidelines, the description should explain at a minimum: (1) why these modifications are being made and (2) what problem is being solved. Currently, the Motivation, Modifications, and Usage sections all simply say "No need," which does not provide reviewers with adequate context about the purpose of these test additions or what specific coverage gaps they address.
```python
        self.target_model_inputs = {
            "block_tables": paddle.zeros([2, 10], dtype="int32"),
            "input_ids": paddle.zeros([2, 2048], dtype="int64"),
            "seq_lens_this_time": paddle.zeros([2, 1], dtype="int32"),
            "seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
            "seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
            "prompt_lens": paddle.zeros([2, 1], dtype="int64"),
            "step_idx": paddle.zeros([2, 1], dtype="int64"),
            "stop_flags": paddle.zeros([2, 1], dtype="bool"),
            "stop_nums": paddle.zeros([2, 1], dtype="int32"),
            "pre_ids": paddle.zeros([2, 2048], dtype="int64"),
            "output_cum_offsets": paddle.zeros([2], dtype="int32"),
            "output_padding_offset": paddle.zeros([2], dtype="int32"),
            "ids_remove_padding": paddle.zeros([2], dtype="int64"),
            "batch_id_per_token": paddle.zeros([2], dtype="int32"),
            "cu_seqlens_q": paddle.zeros([3], dtype="int32"),
            "cu_seqlens_k": paddle.zeros([3], dtype="int32"),
            "decoder_batch_ids": paddle.zeros([2], dtype="int32"),
            "decoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "decoder_num_blocks_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "decoder_num_blocks_device": paddle.zeros([2], dtype="int32"),
            "decoder_chunk_size_device": paddle.zeros([2], dtype="int32"),
            "max_len_tensor_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "encoder_batch_ids": paddle.zeros([2], dtype="int32"),
            "encoder_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "encoder_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "kv_batch_ids": paddle.zeros([2], dtype="int32"),
            "kv_tile_ids_per_batch": paddle.zeros([2], dtype="int32"),
            "kv_num_blocks_x_cpu": paddle.zeros([2], dtype="int32").cpu(),
            "top_p": paddle.ones([2, 1], dtype="float32") * 0.9,
            "top_k": paddle.zeros([2, 1], dtype="int32"),
            "temperature": paddle.ones([2, 1], dtype="float32"),
            "eos_token_id": paddle.ones([2], dtype="int64") * 2,
            "penalty_score": paddle.ones([2, 1], dtype="float32"),
            "frequency_score": paddle.zeros([2, 1], dtype="float32"),
            "presence_score": paddle.zeros([2, 1], dtype="float32"),
            "infer_seed": paddle.zeros([2, 1], dtype="int64"),
            "max_dec_len": paddle.ones([2, 1], dtype="int64") * 512,
            "min_dec_len": paddle.zeros([2, 1], dtype="int64"),
            "bad_tokens": paddle.zeros([2], dtype="int64"),
            "draft_tokens": paddle.zeros([2, 2], dtype="int64"),
            "accept_tokens": paddle.zeros([2, 2], dtype="int64"),
            "accept_num": paddle.ones([2], dtype="int32"),
            "draft_logits": paddle.zeros([4, 32000], dtype="float32"),
            "temp_scaled_logprobs": paddle.zeros([2], dtype="float32"),
            "top_p_normalized_logprobs": paddle.zeros([2], dtype="float32"),
            "encoder_block_lens": paddle.zeros([2, 1], dtype="int32"),
            "cu_batch_token_offset": paddle.zeros([3], dtype="int32"),
            "is_block_step": paddle.zeros([2], dtype="bool"),
            "actual_draft_token_num": paddle.zeros([2], dtype="int32"),
        }
```
Copilot AI (Dec 23, 2025):
The target_model_inputs dictionary in setUp contains a large amount of duplicated data (63 keys with hardcoded tensor initializations). Consider extracting this into a helper method or factory function to improve maintainability. If these inputs need to be modified in individual tests, you could create a base dictionary and allow tests to override specific values as needed.
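A minimal sketch of the suggested factory pattern (the helper name make_target_model_inputs is hypothetical; only a few of the keys from setUp are repeated here, and individual tests override values via keyword arguments):

```python
import paddle


def make_target_model_inputs(batch_size=2, max_model_len=2048, **overrides):
    """Build the default target_model_inputs dict used by the MTPProposer tests."""
    inputs = {
        "block_tables": paddle.zeros([batch_size, 10], dtype="int32"),
        "input_ids": paddle.zeros([batch_size, max_model_len], dtype="int64"),
        "seq_lens_this_time": paddle.zeros([batch_size, 1], dtype="int32"),
        # ... remaining keys exactly as in setUp ...
    }
    inputs.update(overrides)  # a test replaces only the tensors it cares about
    return inputs


# Usage in a test that needs a non-default accept_num:
# self.target_model_inputs = make_target_model_inputs(
#     accept_num=paddle.ones([2], dtype="int32") * 2
# )
```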
```python
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn = Mock()
        mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

        proposer = MTPProposer(
            self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
        )

        # Test _update_mtp_config
        self.assertEqual(proposer.model_config.architectures[0], "ErnieMTPForCausalLM")
        self.assertEqual(proposer.model_config.num_hidden_layers, 1)
        self.assertEqual(proposer.speculative_config.model_type, "mtp")

        # Test _get_cache_type
        cache_type = proposer._get_cache_type()
        self.assertIn(cache_type, ["uint8", "int8"])

        # Test is_chunk_prefill_enabled
        self.assertTrue(proposer.is_chunk_prefill_enabled())

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_dummy_prefill_inputs_and_kv_cache(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test dummy_prefill_inputs and initialize_kv_cache with different branches"""
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn = Mock()
        mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

        proposer = MTPProposer(
            self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
        )

        # Test dummy_prefill_inputs with expert parallel
        self.fd_config.parallel_config.enable_expert_parallel = True
        proposer.dummy_prefill_inputs(num_tokens=100, batch_size=2, expected_decode_len=10)
        self.assertGreater(proposer.model_inputs["seq_lens_encoder"][0].item(), 0)

        # Test initialize_kv_cache with prefix caching
        self.fd_config.cache_config.enable_prefix_caching = True
        proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)
        self.assertIn("caches", proposer.model_inputs)

        # Test initialize_kv_cache with block_wise_fp8
        self.fd_config.quant_config = Mock()
        self.fd_config.quant_config.kv_cache_quant_type = "block_wise_fp8"
        proposer.initialize_kv_cache(main_model_num_blocks=10, profile=False)

        # Test initialize_kv_cache with profile=True
        proposer.initialize_kv_cache(main_model_num_blocks=10, profile=True)

        # Test clear_mtp_cache
        proposer.clear_mtp_cache()
        self.assertNotIn("caches", proposer.model_inputs)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_update_mtp_block_num(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test update_mtp_block_num"""
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn = Mock()
        mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])

        proposer = MTPProposer(
            self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
        )
        proposer.update_mtp_block_num(num_gpu_blocks=20)
        self.assertEqual(proposer.main_model_num_gpu_blocks, 20)
        self.assertIn("free_list", proposer.model_inputs)

    @patch("fastdeploy.spec_decode.mtp.get_model_loader")
    @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
    @patch("fastdeploy.spec_decode.mtp.get_rope")
    def test_insert_tasks_v1(self, mock_rope, mock_attn_backend, mock_model_loader):
        """Test insert_tasks_v1 with different request types"""
        mock_model = Mock()
        mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
        mock_model_loader.return_value.load_model.return_value = mock_model
        mock_attn = Mock()
        mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])
        mock_attn_backend.return_value = lambda *args, **kwargs: mock_attn
        mock_rope.return_value = paddle.zeros([1, 2048, 64])
```
Copilot AI (Dec 23, 2025):
There's significant duplication in the mock setup code across all test methods. Each test method has nearly identical mock configurations for model_loader, attn_backend, and rope. Consider refactoring this repetitive setup into a setUp method helper or using a test fixture to reduce code duplication and improve maintainability. This pattern appears in all 11 test methods.
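One possible shape for that shared setup, sketched under the assumption that a plain helper on the test case is acceptable (the name _install_common_mocks is illustrative; each test would call it once instead of stacking the three @patch decorators):

```python
from unittest.mock import Mock, patch

import paddle


def _install_common_mocks(test_case):
    """Patch the model loader, attention backend, and rope exactly as the tests do,
    keeping the patches active until the test finishes via addCleanup."""
    mock_model = Mock()
    mock_model.compute_logits = Mock(return_value=paddle.zeros([2, 32000]))
    mock_attn = Mock()
    mock_attn.get_kv_cache_shape.return_value = ([2, 12, 16, 64], [2, 12, 16, 64])

    replacements = {
        "fastdeploy.spec_decode.mtp.get_model_loader": Mock(
            return_value=Mock(load_model=Mock(return_value=mock_model))
        ),
        "fastdeploy.spec_decode.mtp.get_attention_backend": Mock(
            return_value=lambda *args, **kwargs: mock_attn
        ),
        "fastdeploy.spec_decode.mtp.get_rope": Mock(return_value=paddle.zeros([1, 2048, 64])),
    }
    for target, replacement in replacements.items():
        patcher = patch(target, replacement)
        patcher.start()
        test_case.addCleanup(patcher.stop)
```

A test would then start with _install_common_mocks(self) and construct the MTPProposer directly.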
```python
        request1 = Request(
            request_id="test1",
            prompt="test",
            prompt_token_ids=[1, 2, 3, 4, 5],
            prompt_token_ids_len=5,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
        )
```
Copilot AI (Dec 23, 2025):
Call to Request.__init__ with too few arguments; should be no fewer than 10.
```python
        request2 = Request(
            request_id="test2",
            prompt="test",
            prompt_token_ids=[1, 2],
            prompt_token_ids_len=2,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
        )
```
Copilot AI (Dec 23, 2025):
Call to Request.__init__ with too few arguments; should be no fewer than 10.
```python
        request2 = Request(
            request_id="test2",
            prompt="test",
            prompt_token_ids=[1, 2],
            prompt_token_ids_len=2,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
        )
        request2.idx = 1
        request2.task_type = RequestType.DECODE
        request2.block_tables = [2, 3]

        # Test with PREEMPTED request
        request3 = Request(
            request_id="test3",
            prompt="test",
            prompt_token_ids=[1],
            prompt_token_ids_len=1,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
        )
```
Copilot AI (Dec 23, 2025):
Call to Request.__init__ with too few arguments; should be no fewer than 10.
Suggested change (replace the two Request constructions above with Mock objects):

```python
        request2 = Mock(spec=Request)
        request2.request_id = "test2"
        request2.prompt = "test"
        request2.prompt_token_ids = [1, 2]
        request2.prompt_token_ids_len = 2
        request2.messages = None
        request2.history = None
        request2.tools = None
        request2.system = None
        request2.eos_token_ids = [2]
        request2.idx = 1
        request2.task_type = RequestType.DECODE
        request2.block_tables = [2, 3]
        # Test with PREEMPTED request
        request3 = Mock(spec=Request)
        request3.request_id = "test3"
        request3.prompt = "test"
        request3.prompt_token_ids = [1]
        request3.prompt_token_ids_len = 1
        request3.messages = None
        request3.history = None
        request3.tools = None
        request3.system = None
        request3.eos_token_ids = [2]
```
```python
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
```
Copilot AI (Dec 23, 2025):
Call to Request.__init__ with too few arguments; should be no fewer than 10.
Suggested change:

```python
            eos_token_ids=[2],
            request_type=RequestType.GENERATION,
```
```python
        request = Request(
            request_id="preempt",
            prompt="t",
            prompt_token_ids=[1],
            prompt_token_ids_len=1,
            messages=None,
            history=None,
            tools=None,
            system=None,
            eos_token_ids=[2],
        )
```
Copilot AI (Dec 23, 2025):
Call to Request.__init__ with too few arguments; should be no fewer than 10.
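If the Mock(spec=Request) workaround suggested above were applied here as well, the preempted request could be built roughly like this (a sketch; the import path for Request is an assumption matching the test module's own imports):

```python
from unittest.mock import Mock

from fastdeploy.engine.request import Request  # assumed import path

request = Mock(spec=Request)
request.request_id = "preempt"
request.prompt = "t"
request.prompt_token_ids = [1]
request.prompt_token_ids_len = 1
request.messages = None
request.history = None
request.tools = None
request.system = None
request.eos_token_ids = [2]
```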
This reverts commit f73b6e9.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff            @@
##           develop    #5729   +/-   ##
==========================================
  Coverage         ?   64.81%
==========================================
  Files            ?      337
  Lines            ?    43074
  Branches         ?     6637
==========================================
  Hits             ?    27917
  Misses           ?    13085
  Partials         ?     2072
```

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
Motivation
No need
Modifications
No need
Usage or Command
No need
Accuracy Tests
No need
Checklist
- PR title tag (one of): [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.