Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions effectful/handlers/llm/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import ast
import builtins
import collections.abc
import copy
import inspect
import keyword
import linecache
import random
import string
import sys
import types
import typing
Expand Down Expand Up @@ -493,6 +497,54 @@ class definitions with proper inheritance, typed attributes, and method stubs.
return nodes


def _generate_unique_name(existing_names: set[str]) -> str:
    """Return a fresh ``_synth_``-prefixed identifier absent from ``existing_names``.

    The result is ``_synth_`` followed by eight random lowercase ASCII
    letters or digits (for example ``_synth_a3f7b2c9``). Candidates are
    drawn until one is not in ``existing_names``, is a valid identifier,
    and is not a Python keyword.
    """
    alphabet = string.ascii_lowercase + string.digits
    while True:
        name = "_synth_" + "".join(random.choices(alphabet, k=8))
        # Already taken: draw again.
        if name in existing_names:
            continue
        # Defensive checks kept from the original contract.
        if not name.isidentifier() or keyword.iskeyword(name):
            continue
        return name


class _RenameTransformer(ast.NodeTransformer):
    """Rename function definitions and their references in an AST.

    Given a mapping ``{old_name: new_name}``, renames:
    - ``FunctionDef.name`` / ``AsyncFunctionDef.name`` for matching definitions
    - ``ast.Name.id`` references
    - matching entries of ``global`` / ``nonlocal`` statements

    Ordinary local assignments are ``ast.Name`` nodes, so their stores and
    loads are renamed together and stay consistent. Parameters are stored
    as ``ast.arg`` nodes and are never renamed; a parameter of a ``def`` or
    ``lambda`` whose name appears in the map therefore shadows the renamed
    function inside that body, and ``ast.Name`` nodes there are left
    untouched. Decorators, default values and annotations are evaluated in
    the enclosing scope and are still renamed.

    Nodes are mutated in place; pass a copy if the original tree must be
    preserved.
    """

    def __init__(self, rename_map: dict[str, str]):
        self.rename_map = rename_map
        # One entry per enclosing def/lambda body currently being visited:
        # the parameter names of that scope that are also rename targets.
        self._shadowed: list[set[str]] = []

    def _is_shadowed(self, name: str) -> bool:
        """Return True if an enclosing parameter currently hides ``name``."""
        return any(name in names for names in self._shadowed)

    def _shadowing_params(self, args: ast.arguments) -> set[str]:
        """Return the parameter names in ``args`` that are rename targets."""
        params = [*args.posonlyargs, *args.args, *args.kwonlyargs]
        params.extend(p for p in (args.vararg, args.kwarg) if p is not None)
        return {p.arg for p in params if p.arg in self.rename_map}

    def _visit_body_shadowed(self, args: ast.arguments, body) -> None:
        """Visit ``body`` nodes with the parameters of ``args`` shadowing."""
        self._shadowed.append(self._shadowing_params(args))
        try:
            for child in body:
                # Every visitor here mutates in place and returns the same
                # node, so the return value can be ignored.
                self.visit(child)
        finally:
            self._shadowed.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        if node.name in self.rename_map and not self._is_shadowed(node.name):
            node.name = self.rename_map[node.name]
        # Decorators, type parameters, defaults and annotations belong to
        # the enclosing scope, so visit them before the parameters start
        # shadowing anything.
        for decorator in node.decorator_list:
            self.visit(decorator)
        for type_param in getattr(node, "type_params", ()):
            self.visit(type_param)
        self.visit(node.args)
        if node.returns is not None:
            self.visit(node.returns)
        self._visit_body_shadowed(node.args, node.body)
        return node

    visit_AsyncFunctionDef = visit_FunctionDef  # type: ignore[assignment]

    def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda:
        # Defaults are evaluated outside the lambda; its body sees the params.
        self.visit(node.args)
        self._visit_body_shadowed(node.args, [node.body])
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name:
        if node.id in self.rename_map and not self._is_shadowed(node.id):
            node.id = self.rename_map[node.id]
        return node

    def visit_Global(self, node):
        # ``global`` / ``nonlocal`` hold plain strings rather than Name
        # nodes, so they must be renamed explicitly to keep matching the
        # renamed references in the same body.
        node.names = [
            name if self._is_shadowed(name) else self.rename_map.get(name, name)
            for name in node.names
        ]
        return node

    visit_Nonlocal = visit_Global


def mypy_type_check(
module: ast.Module,
ctx: typing.Mapping[str, Any],
Expand All @@ -505,6 +557,9 @@ def mypy_type_check(
appends the module body, then a postlude that assigns the last function to a
variable annotated with Callable[expected_params, expected_return]. Runs mypy
on the combined source; raises TypeError with the mypy report on failure.

If the synthesized function name clashes with a name already in the context,
the function is renamed to a unique random identifier for type-checking only.
"""
if not module.body:
raise TypeError("mypy_type_check: module.body is empty")
Expand All @@ -527,6 +582,43 @@ def mypy_type_check(
stubs = collect_runtime_type_stubs(ctx)
variables = collect_variable_declarations(ctx)

# Collect names already declared in the type-checking preamble
# (variable declarations and class stubs) that could collide with
# function definitions in the synthesized module.
declared_names = {
stmt.target.id
for stmt in variables
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)
} | {stmt.name for stmt in stubs if isinstance(stmt, ast.ClassDef)}

# Find all function names in the synthesized module that collide
synthesized_func_names = {
stmt.name
for stmt in module.body
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
}
colliding_names = synthesized_func_names & declared_names

if colliding_names:
# Build a rename map for every colliding function name
all_reserved = declared_names | synthesized_func_names
rename_map: dict[str, str] = {}
for name in colliding_names:
unique = _generate_unique_name(all_reserved)
rename_map[name] = unique
all_reserved.add(unique)

# Deep-copy the module body so we don't mutate the caller's AST,
# then rename definitions and all references to them.
module_body = copy.deepcopy(list(module.body))
stub_module_body = ast.Module(body=module_body, type_ignores=[])
_RenameTransformer(rename_map).visit(stub_module_body)
module_body = stub_module_body.body
tc_func_name = rename_map.get(func_name, func_name)
else:
module_body = list(module.body)
tc_func_name = func_name

param_types = expected_params
expected_callable_type: type = typing.cast(
type,
Expand All @@ -539,15 +631,15 @@ def mypy_type_check(
postlude = ast.AnnAssign(
target=ast.Name(id="_synthesized_check", ctx=ast.Store()),
annotation=expected_callable_ast,
value=ast.Name(id=func_name, ctx=ast.Load()),
value=ast.Name(id=tc_func_name, ctx=ast.Load()),
simple=1,
)
full_body = (
baseline_imports
+ list(imports)
+ list(stubs)
+ list(variables)
+ list(module.body)
+ module_body
+ [postlude]
)
stub_module = ast.Module(body=full_body, type_ignores=[])
Expand Down
120 changes: 120 additions & 0 deletions tests/test_handlers_llm_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import inspect
import textwrap
import types
import typing
from collections import ChainMap
Expand Down Expand Up @@ -1267,3 +1268,122 @@ class MyErr(Exception):
ctx = ChainMap({"MyErr": MyErr}, get_context())
with pytest.raises(TypeError):
mypy_type_check(module, ctx, [], MyErr)


class TestMypyTypeCheckNameCollision:
"""Tests that mypy_type_check renames synthesized functions whose names
collide with variable declarations or class stubs from the context."""

# NOTE(review): each test defines an intentionally unused local before
# calling get_context(); presumably get_context() reads the calling frame
# so that local becomes a context entry -- confirm against get_context.

def test_single_function_collides_with_variable(self):
"""Function name matches a variable in context; should still pass type-check."""
# Context entry sharing its name with the synthesized function below.
count_char = lambda s: s.count("a") # noqa: E731, F841

source = textwrap.dedent("""\
def count_char(s: str) -> int:
return s.count('a')
""")
module = ast.parse(source)
ctx = get_context()
# Should NOT raise — the collision is handled by renaming
mypy_type_check(module, ctx, [str], int)

def test_colliding_function_still_detects_type_errors(self):
"""Even after renaming, real type errors are still caught."""
# Context entry sharing its name with the synthesized function below.
count_char = lambda s: s.count("a") # noqa: E731, F841

source = textwrap.dedent("""\
def count_char(s: str) -> int:
return s # wrong return type
""")
module = ast.parse(source)
ctx = get_context()
# The body returns a str where int is declared, so a TypeError is expected.
with pytest.raises(TypeError):
mypy_type_check(module, ctx, [str], int)

def test_no_collision_passes_normally(self):
"""No name collision — normal type-check should work as before."""
# Unrelated context entry; its name differs from the synthesized function.
x = 42 # noqa: F841

source = textwrap.dedent("""\
def some_unique_func(s: str) -> int:
return len(s)
""")
module = ast.parse(source)
ctx = get_context()
mypy_type_check(module, ctx, [str], int)

def test_multiple_functions_one_collides(self):
"""Module has helper + main function; only main collides with context."""
# Only ``process`` (the last function in the module) has a same-named local.
process = "some_value" # noqa: F841

source = textwrap.dedent("""\
def helper(x: int) -> str:
return str(x)
def process(items: list[int]) -> list[str]:
return [helper(i) for i in items]
""")
module = ast.parse(source)
ctx = get_context()
mypy_type_check(module, ctx, [list[int]], list[str])

def test_multiple_functions_both_collide(self):
"""Both helper and main function names collide with context variables."""
# Two locals, one per synthesized function name.
helper = lambda: None # noqa: E731, F841
compute = 123 # noqa: F841

source = textwrap.dedent("""\
def helper(x: int) -> str:
return str(x)
def compute(n: int) -> str:
return helper(n)
""")
module = ast.parse(source)
ctx = get_context()
mypy_type_check(module, ctx, [int], str)

def test_collision_with_class_stub(self):
"""Function name collides with a runtime class stub in context."""

class MyModel:
value: int

# Also define a function named MyModel in synthesized code
source = textwrap.dedent("""\
def MyModel(x: int) -> int:
return x * 2
""")
module = ast.parse(source)
# The class is placed in the context explicitly, layered over get_context().
ctx = ChainMap({"MyModel": MyModel}, get_context())
mypy_type_check(module, ctx, [int], int)

def test_collision_does_not_mutate_original_ast(self):
"""Renaming should not modify the original module AST."""
# Context entry sharing its name with the synthesized function below.
count_char = lambda s: s.count("a") # noqa: E731, F841

source = textwrap.dedent("""\
def count_char(s: str) -> int:
return s.count('a')
""")
module = ast.parse(source)
# Record the name of the last definition before type-checking.
original_name = module.body[-1].name

ctx = get_context()
mypy_type_check(module, ctx, [str], int)

# Original AST must be untouched
assert module.body[-1].name == original_name

def test_helper_reference_updated_after_rename(self):
"""When a helper function is renamed, calls to it inside other
functions are also updated so mypy still sees valid code."""
validate = True # noqa: F841 — collides with helper name

source = textwrap.dedent("""\
def validate(x: int) -> bool:
return x > 0
def run(x: int) -> bool:
return validate(x)
""")
module = ast.parse(source)
ctx = get_context()
mypy_type_check(module, ctx, [int], bool)
Loading