diff --git a/app/explain.py b/app/explain.py index bcead24..8630adb 100644 --- a/app/explain.py +++ b/app/explain.py @@ -12,11 +12,6 @@ LOGGER = logging.getLogger("explain") -# Constants -MAX_CODE_LENGTH = 10000 # 10K chars should be enough for most source files -MAX_ASM_LENGTH = 20000 # 20K chars for assembly output - - async def process_request( body: ExplainRequest, client: AsyncAnthropic, diff --git a/app/prompt.py b/app/prompt.py index 87b5428..f45a206 100644 --- a/app/prompt.py +++ b/app/prompt.py @@ -15,6 +15,14 @@ # Constants from explain.py that are needed for data preparation MAX_ASSEMBLY_LINES = 300 # Maximum number of assembly lines to process +# Character budgets for the input we hand to Claude. The line-based selection +# above bounds the *number* of assembly lines, but not their length, so a few +# pathological long lines (or large source files) can still push the prompt to +# 100k+ input tokens — directly inflating prefill/TTFT and cost. These caps put +# a hard ceiling on input size before the API call. +MAX_CODE_LENGTH = 10000 # 10K chars should be enough for most source files +MAX_ASM_LENGTH = 20000 # 20K chars for assembly output (after line selection) + # Minimum max_tokens that's safe to pair with extended thinking. Below this, # adaptive thinking can consume the whole budget on complex inputs and leave # nothing for the visible response. @@ -204,13 +212,45 @@ def select_important_assembly( return selected_assembly + @staticmethod + def _truncate_chars(text: str, max_chars: int) -> str: + """Hard-cap a string to max_chars, leaving a visible marker if cut.""" + if len(text) <= max_chars: + return text + omitted = len(text) - max_chars + return f"{text[:max_chars]}\n... ({omitted} characters truncated) ..." + + @staticmethod + def cap_assembly_chars(asm_items: list[dict], max_chars: int) -> tuple[list[dict], bool]: + """Trim an assembly item list so the total `text` length stays under max_chars. + + Runs *after* line-based selection: that bounds line count, this bounds + total characters so a few very long lines can't blow up the prompt. + """ + total = 0 + capped: list[dict] = [] + for item in asm_items: + text = item.get("text", "") + if total + len(text) > max_chars: + capped.append( + { + "text": f"... (assembly truncated at {max_chars} characters) ...", + "isOmissionMarker": True, + } + ) + return capped, True + total += len(text) + capped.append(item) + return capped, False + def prepare_structured_data(self, request: ExplainRequest) -> dict[str, Any]: """Prepare a structured JSON object for Claude's consumption.""" - # Extract and validate basic fields + # Extract and validate basic fields. Source is hard-capped so a huge + # source file can't dominate the prompt (and inflate TTFT/cost). structured_data = { "language": request.language, "compiler": request.compiler, - "sourceCode": request.code, + "sourceCode": self._truncate_chars(request.code, MAX_CODE_LENGTH), "instructionSet": request.instruction_set_with_default, } @@ -222,14 +262,21 @@ def prepare_structured_data(self, request: ExplainRequest) -> dict[str, Any]: if len(asm_dicts) > MAX_ASSEMBLY_LINES: # If assembly is too large, we need smart truncation - structured_data["assembly"] = self.select_important_assembly(asm_dicts, request.labelDefinitions or {}) + selected = self.select_important_assembly(asm_dicts, request.labelDefinitions or {}) structured_data["truncated"] = True structured_data["originalLength"] = len(asm_dicts) else: # Use the full assembly if it's within limits - structured_data["assembly"] = asm_dicts + selected = asm_dicts structured_data["truncated"] = False + # Hard-cap total assembly characters regardless of line count, so a few + # very long lines can't push input to 100k+ tokens. + capped_asm, char_truncated = self.cap_assembly_chars(selected, MAX_ASM_LENGTH) + structured_data["assembly"] = capped_asm + if char_truncated: + structured_data["truncated"] = True + # Include label definitions structured_data["labelDefinitions"] = request.labelDefinitions or {} diff --git a/app/test_explain.py b/app/test_explain.py index b0d6150..2c53047 100644 --- a/app/test_explain.py +++ b/app/test_explain.py @@ -11,7 +11,21 @@ SourceMapping, ) from app.metrics import NoopMetricsProvider -from app.prompt import MAX_ASSEMBLY_LINES, MIN_MAX_TOKENS_WITH_THINKING, Prompt +from app.prompt import MAX_ASM_LENGTH, MAX_ASSEMBLY_LINES, MAX_CODE_LENGTH, MIN_MAX_TOKENS_WITH_THINKING, Prompt + + +def _minimal_prompt() -> Prompt: + """A bare Prompt instance for data-preparation tests.""" + return Prompt( + { + "model": {"name": "test", "max_tokens": 100}, + "system_prompt": "", + "user_prompt": "", + "assistant_prefill": "", + "audience_levels": {}, + "explanation_types": {}, + } + ) @pytest.fixture @@ -495,6 +509,51 @@ def test_prepare_structured_data_assembly_dict_conversion(self, sample_request): assert result["assembly"][1]["source"]["line"] == 1 assert result["assembly"][1]["source"]["column"] == 21 + def test_source_code_is_char_capped(self): + """Oversized source is hard-capped to MAX_CODE_LENGTH with a marker.""" + big_request = ExplainRequest( + language="c++", + compiler="g++", + code="x" * (MAX_CODE_LENGTH + 5000), + asm=[AssemblyItem(text="ret", source=None)], + ) + result = _minimal_prompt().prepare_structured_data(big_request) + + assert len(result["sourceCode"]) < MAX_CODE_LENGTH + 200 # cap + short marker + assert result["sourceCode"].startswith("x" * MAX_CODE_LENGTH) + assert "characters truncated" in result["sourceCode"] + + def test_assembly_char_capped_for_few_long_lines(self): + """A handful of very long lines (under the line limit) must still be + capped by total characters so input can't balloon to 100k+ tokens.""" + long_lines = [AssemblyItem(text="a" * 8000, source=None) for _ in range(5)] + request = ExplainRequest( + language="c++", + compiler="g++", + code="int main() { return 0; }", + asm=long_lines, + ) + result = _minimal_prompt().prepare_structured_data(request) + + total_chars = sum(len(item["text"]) for item in result["assembly"]) + assert total_chars <= MAX_ASM_LENGTH + 100 # budget + final marker text + assert result["truncated"] + assert any(item.get("isOmissionMarker") for item in result["assembly"]) + + def test_small_inputs_not_truncated(self): + """Normal-sized inputs pass through untouched.""" + request = ExplainRequest( + language="c++", + compiler="g++", + code="int square(int x) { return x * x; }", + asm=[AssemblyItem(text="imul eax, edi", source=None)], + ) + result = _minimal_prompt().prepare_structured_data(request) + + assert result["sourceCode"] == "int square(int x) { return x * x; }" + assert not result["truncated"] + assert "truncated" not in result["sourceCode"] + class TestValidation: """Test Pydantic validation behavior."""