From a9e26d8411f1db56598876247de1f62195b90d4b Mon Sep 17 00:00:00 2001 From: Edgar Costa Date: Tue, 10 Mar 2026 19:16:03 -0400 Subject: [PATCH 1/2] Add parameter name consistency checker for headers, docs, and sources Python tool (dev/check_param_names.py) that compares function parameter names across header declarations, RST documentation, and .c source definitions. Supports auto-fixing mismatches with --fix and --dry-run. Includes 86 unit tests (dev/test_check_param_names.py). --- dev/check_param_names.py | 1086 +++++++++++++++++++++++++++++++++ dev/test_check_param_names.py | 1072 ++++++++++++++++++++++++++++++++ 2 files changed, 2158 insertions(+) create mode 100755 dev/check_param_names.py create mode 100644 dev/test_check_param_names.py diff --git a/dev/check_param_names.py b/dev/check_param_names.py new file mode 100755 index 0000000000..95c6e33f09 --- /dev/null +++ b/dev/check_param_names.py @@ -0,0 +1,1086 @@ +#!/usr/bin/env python3 +""" +Compare parameter names between FLINT header files, RST documentation, +and .c source files. + +Reports mismatches where the same function has different parameter names +across these sources. + +Usage: + python3 dev/check_param_names.py [--module MODULE] [--check-src] + python3 dev/check_param_names.py --check-src -m fmpz + +Run from the FLINT root directory. +""" + +import os +import re +import sys +import argparse +from collections import OrderedDict + + +# Decorators/macros that appear before return types in headers +DECL_PREFIXES = re.compile( + r"^(?:" + r"WARN_UNUSED_RESULT|" + r"FLINT_FORCE_INLINE|" + r"\w+_INLINE|" # FMPZ_INLINE, GR_POLY_INLINE, etc. + r"FLINT_DLL" + r")\s*", +) + + +def find_modules(src_dir, doc_dir): + """Find modules that have both a header and a doc file.""" + modules = [] + for f in sorted(os.listdir(src_dir)): + if f.endswith(".h"): + mod = f[:-2] # strip .h + rst = os.path.join(doc_dir, mod + ".rst") + header = os.path.join(src_dir, f) + if os.path.isfile(rst): + modules.append((mod, header, rst)) + return modules + + +def extract_param_name(param_str): + """ + Extract the parameter name from a C parameter declaration string. + + Examples: + "const fmpz_t x" -> "x" + "slong n" -> "n" + "ulong * out" -> "out" + "FILE * file" -> "file" + "nn_srcptr xp" -> "xp" + "..." -> None + "void" -> None + "const padic_ctx_t FLINT_UNUSED(ctx)" -> "ctx" + """ + param_str = param_str.strip() + if not param_str or param_str == "void" or param_str == "...": + return None + + # Handle FLINT_UNUSED(name) + m = re.search(r"FLINT_UNUSED\((\w+)\)", param_str) + if m: + return m.group(1) + + # Remove array brackets: e.g. "ulong out[3]" -> name is "out" + param_str = re.sub(r"\[.*?\]", "", param_str).strip() + + # Split on spaces and asterisks, take the last word token + tokens = re.findall(r"\w+", param_str) + if not tokens: + return None + + # The last token is the name, unless the entire thing is just type(s) + # with no name (e.g., "void" - already handled above) + name = tokens[-1] + + # Sanity check: if the "name" looks like a type keyword, skip + type_keywords = { + "void", "int", "long", "short", "char", "unsigned", "signed", + "float", "double", "size_t", "ssize_t", "FILE", + } + if name in type_keywords: + return None + + # If the "name" looks like a type (ends with _t, _struct, _ptr, + # _srcptr), it's an unnamed parameter (type only). + type_suffixes = ("_t", "_struct", "_ptr", "_srcptr") + if name.endswith(type_suffixes): + return None + + return name + + +def split_params(params_str): + """ + Split a parameter list string by commas, respecting nested parentheses + and function pointer syntax. + """ + params = [] + depth = 0 + current = [] + for ch in params_str: + if ch in ("(", "["): + depth += 1 + current.append(ch) + elif ch in (")", "]"): + depth -= 1 + current.append(ch) + elif ch == "," and depth == 0: + params.append("".join(current).strip()) + current = [] + else: + current.append(ch) + last = "".join(current).strip() + if last: + params.append(last) + return params + + +def parse_func_signature(sig): + """ + Parse a C function signature string, returning (func_name, [param_names]) + or None if it can't be parsed. + + The signature should be like: + "void fmpz_add(fmpz_t f, const fmpz_t g, const fmpz_t h)" + """ + # Normalize whitespace + sig = re.sub(r"\s+", " ", sig).strip() + + # Remove trailing semicolons + sig = sig.rstrip(";").strip() + + # Must have parentheses + if "(" not in sig or ")" not in sig: + return None + + # Find the first '(' at depth 0 (the parameter list opening) + depth = 0 + paren_start = -1 + for i, ch in enumerate(sig): + if ch == "(": + if depth == 0: + paren_start = i + break + depth += 1 + elif ch == ")": + depth -= 1 + + if paren_start < 0: + return None + + # Find matching close paren + depth = 1 + paren_end = -1 + for i in range(paren_start + 1, len(sig)): + if sig[i] == "(": + depth += 1 + elif sig[i] == ")": + depth -= 1 + if depth == 0: + paren_end = i + break + + if paren_end < 0: + return None + + before_paren = sig[:paren_start].strip() + params_str = sig[paren_start + 1 : paren_end].strip() + + # Skip function pointer declarations: "typedef void (*name)(...)" + if "(*" in before_paren: + return None + + # Extract function name: last word before the paren + name_match = re.search(r"(\w+)\s*$", before_paren) + if not name_match: + return None + func_name = name_match.group(1) + + # Parse parameter names + if not params_str or params_str == "void": + return (func_name, []) + + param_parts = split_params(params_str) + param_names = [] + for p in param_parts: + # Function pointer parameters: "int (*cmp)(void *, const void *)" + fptr_match = re.match(r".*\(\s*\*\s*(\w+)\s*\)", p) + if fptr_match: + param_names.append(fptr_match.group(1)) + continue + + name = extract_param_name(p) + if name is not None: + param_names.append(name) + + return (func_name, param_names) + + +def strip_c_comments_and_preprocessor(content): + """ + Remove block comments, line comments, and preprocessor directive lines + from C source code. Also remove extern "C" { wrappers. + Returns cleaned text. + """ + # Remove block comments + content = re.sub(r"/\*.*?\*/", " ", content, flags=re.DOTALL) + + # Remove line comments (shouldn't appear in FLINT style, but just in case) + content = re.sub(r"//[^\n]*", "", content) + + # Remove preprocessor directive lines (keep code inside #ifdef blocks) + lines = content.split("\n") + clean_lines = [] + in_macro_def = False + for line in lines: + stripped = line.strip() + if in_macro_def: + if not stripped.endswith("\\"): + in_macro_def = False + clean_lines.append("") + continue + if stripped.startswith("#"): + if stripped.endswith("\\"): + in_macro_def = True + clean_lines.append("") + continue + clean_lines.append(line) + + text = "\n".join(clean_lines) + + # Remove extern "C" { wrappers + text = re.sub(r'extern\s+"C"\s*\{', " ", text) + + return text + + +def extract_declarations(text): + """ + Extract top-level declarations (ending with ';') from cleaned C text. + Skips anything inside brace-delimited blocks (function bodies, structs). + Returns list of (declaration_text, line_number). + """ + declarations = [] + current_decl = [] + brace_depth = 0 + + for i, ch in enumerate(text): + if ch == "{": + brace_depth += 1 + current_decl = [] + elif ch == "}": + brace_depth -= 1 + if brace_depth < 0: + brace_depth = 0 + current_decl = [] + elif brace_depth == 0: + if ch == ";": + decl_text = "".join(current_decl).strip() + if decl_text and "(" in decl_text and ")" in decl_text: + pos_in_text = i - len(decl_text) + line_num = text[:pos_in_text].count("\n") + 1 + declarations.append((decl_text, line_num)) + current_decl = [] + else: + if not current_decl and ch == "\n": + pass # skip leading newlines + else: + current_decl.append(ch) + + return declarations + + +def extract_definitions(text): + """ + Extract top-level function definitions from cleaned C text. + A definition is a function signature followed by a '{' body. + Returns list of (signature_text, line_number). + """ + definitions = [] + brace_depth = 0 + current_sig = [] + sig_line = 1 + i = 0 + + while i < len(text): + ch = text[i] + + if ch == "{": + if brace_depth == 0 and current_sig: + sig_text = "".join(current_sig).strip() + if sig_text and "(" in sig_text and ")" in sig_text: + definitions.append((sig_text, sig_line)) + current_sig = [] + brace_depth += 1 + elif ch == "}": + brace_depth -= 1 + if brace_depth < 0: + brace_depth = 0 + if brace_depth == 0: + current_sig = [] + elif brace_depth == 0: + if ch == ";": + # This is a declaration, not a definition — skip it + current_sig = [] + else: + if not current_sig and ch == "\n": + pass + else: + if not current_sig: + sig_line = text[:i].count("\n") + 1 + current_sig.append(ch) + i += 1 + + return definitions + + +def parse_header_declarations(header_path): + """ + Parse non-inline function declarations from a header file. + Returns dict: func_name -> (param_names, line_number, raw_declaration) + """ + with open(header_path, "r") as f: + content = f.read() + + text = strip_c_comments_and_preprocessor(content) + declarations = extract_declarations(text) + + functions = {} + for decl_text, line_num in declarations: + # Skip typedefs + if re.match(r"\s*typedef\b", decl_text): + continue + # Skip extern declarations without function signatures + if re.match(r"\s*extern\b", decl_text) and "(" not in decl_text: + continue + + # Strip decorator prefixes + clean = decl_text.strip() + while True: + m = DECL_PREFIXES.match(clean) + if m: + clean = clean[m.end():] + else: + break + + result = parse_func_signature(clean) + if result: + func_name, param_names = result + # Skip ALL_CAPS names without underscores (macros) + if func_name.isupper() and "_" not in func_name: + continue + if func_name not in functions: + functions[func_name] = (param_names, line_num, decl_text.strip()) + + return functions + + +def parse_source_definitions(src_dir, module): + """ + Parse function definitions from .c source files for a module. + Returns dict: func_name -> (param_names, file_path, line_number) + """ + module_dir = os.path.join(src_dir, module) + if not os.path.isdir(module_dir): + return {} + + functions = {} + + for fname in sorted(os.listdir(module_dir)): + if not fname.endswith(".c"): + continue + # Skip test files + if fname.startswith("t-") or fname == "main.c": + continue + + fpath = os.path.join(module_dir, fname) + with open(fpath, "r") as f: + content = f.read() + + text = strip_c_comments_and_preprocessor(content) + definitions = extract_definitions(text) + + for sig_text, line_num in definitions: + # Strip 'static' qualifier — we skip static functions + clean = sig_text.strip() + if re.match(r"\s*static\b", clean): + continue + + # Strip decorator prefixes + while True: + m = DECL_PREFIXES.match(clean) + if m: + clean = clean[m.end():] + else: + break + + result = parse_func_signature(clean) + if result: + func_name, param_names = result + if func_name not in functions: + functions[func_name] = (param_names, fpath, line_num) + + return functions + + +def parse_rst_functions(rst_path): + """ + Parse function signatures from RST documentation. + Returns dict: func_name -> (param_names, line_number, raw_signature) + """ + functions = {} + + func_directive = re.compile(r"^\.\.\s+(?:c:)?function\s*::\s*(.+)$") + continuation = re.compile(r"^\s{5,}(\S.+)$") + + with open(rst_path, "r") as f: + lines = f.readlines() + + i = 0 + while i < len(lines): + line = lines[i].rstrip() + m = func_directive.match(line) + if m: + sigs = [(m.group(1).strip(), i + 1)] + + j = i + 1 + while j < len(lines): + cm = continuation.match(lines[j].rstrip()) + if cm: + text = cm.group(1).strip() + if "(" in text and ")" in text: + sigs.append((text, j + 1)) + j += 1 + else: + break + else: + break + + for sig_text, line_num in sigs: + result = parse_func_signature(sig_text) + if result: + func_name, param_names = result + if func_name not in functions: + functions[func_name] = ( + param_names, line_num, sig_text.strip() + ) + + i = j + else: + i += 1 + + return functions + + +def compare_params(funcs_a, funcs_b): + """ + Compare parameter names between two function dictionaries. + Returns list of mismatch dicts. + """ + mismatches = [] + common = set(funcs_a.keys()) & set(funcs_b.keys()) + + for func_name in sorted(common): + a_params = funcs_a[func_name][0] + b_params = funcs_b[func_name][0] + + if len(a_params) != len(b_params): + mismatches.append({ + "func": func_name, + "type": "param_count", + "a_params": a_params, + "b_params": b_params, + "a_info": funcs_a[func_name], + "b_info": funcs_b[func_name], + }) + continue + + if a_params != b_params: + diffs = [] + for k, (a, b) in enumerate(zip(a_params, b_params)): + if a != b: + diffs.append((k, a, b)) + mismatches.append({ + "func": func_name, + "type": "param_name", + "diffs": diffs, + "a_params": a_params, + "b_params": b_params, + "a_info": funcs_a[func_name], + "b_info": funcs_b[func_name], + }) + + return mismatches + + +def format_location(info, label): + """Format a location string from function info tuple.""" + if len(info) == 3: + # (params, line_or_path, line_or_raw) + # Header: (params, line_num, raw_decl) + # Doc: (params, line_num, raw_sig) + # Source: (params, file_path, line_num) + if isinstance(info[1], str) and os.sep in info[1]: + return f"{label} ({info[1]}:{info[2]})" + elif isinstance(info[1], int): + return f"{label} (line {info[1]})" + return label + + +def print_mismatches(mismatches, label_a, label_b, path_a="", path_b=""): + """Print mismatch results.""" + for mm in mismatches: + a_info = mm["a_info"] + b_info = mm["b_info"] + + if mm["type"] == "param_count": + print(f"\n {mm['func']}: parameter count mismatch") + loc_a = format_location(a_info, label_a) + loc_b = format_location(b_info, label_b) + if path_a: + print(f" {path_a}:{a_info[1]}: {mm['a_params']}") + else: + print(f" {loc_a}: {mm['a_params']}") + if path_b: + print(f" {path_b}:{b_info[1]}: {mm['b_params']}") + else: + print(f" {loc_b}: {mm['b_params']}") + else: + print(f"\n {mm['func']}:") + for idx, a_name, b_name in mm["diffs"]: + print(f" param {idx}: {label_a} has '{a_name}'" + f", {label_b} has '{b_name}'") + if path_a: + print(f" {label_a} ({path_a}:{a_info[1]})") + else: + loc = format_location(a_info, label_a) + print(f" {loc}") + if path_b: + print(f" {label_b} ({path_b}:{b_info[1]})") + else: + # Source info has (params, filepath, line_num) + if isinstance(b_info[1], str): + print(f" {label_b} ({b_info[1]}:{b_info[2]})") + else: + loc = format_location(b_info, label_b) + print(f" {loc}") + + +def find_function_span(content, func_name): + """ + Find the byte range of a function definition (signature + body) in + C source content. Returns (sig_start, body_end) or None. + sig_start is the index of the start of the signature. + body_end is the index just past the closing '}'. + """ + # Search for the function name followed by '(' + pattern = re.compile(r"\b" + re.escape(func_name) + r"\s*\(") + for m in pattern.finditer(content): + # Walk backward to find the start of the signature + # (return type, possibly on previous line) + sig_start = m.start() + # Walk backward past whitespace and return type tokens + j = sig_start - 1 + while j >= 0 and content[j] in " \t": + j -= 1 + # Walk backward past the return type line + while j >= 0 and content[j] != "\n" and content[j] != ";"\ + and content[j] != "}": + j -= 1 + if j >= 0 and content[j] == "\n": + # Check if previous line is the return type + line_start = j + 1 + prev_text = content[line_start:sig_start].strip() + if prev_text and not prev_text.startswith("/*") \ + and not prev_text.startswith("//") \ + and not prev_text.startswith("#"): + sig_start = line_start + + # Walk forward from the match to find the opening '{' + pos = m.end() + paren_depth = 1 + while pos < len(content) and paren_depth > 0: + if content[pos] == "(": + paren_depth += 1 + elif content[pos] == ")": + paren_depth -= 1 + pos += 1 + + # Now pos is right after the closing ')' of the parameter list. + # Skip whitespace to find '{' or ';' + while pos < len(content) and content[pos] in " \t\n\r": + pos += 1 + + if pos >= len(content): + continue + + if content[pos] == ";": + # This is a declaration, not a definition — skip + continue + + if content[pos] != "{": + continue + + # Found opening brace — find matching close brace + brace_depth = 1 + pos += 1 + while pos < len(content) and brace_depth > 0: + if content[pos] == "{": + brace_depth += 1 + elif content[pos] == "}": + brace_depth -= 1 + pos += 1 + + return (sig_start, pos) + + return None + + +def rename_param_in_range(content, start, end, old_name, new_name): + """ + Rename a parameter within a byte range of content, avoiding + struct member accesses (->old_name, .old_name). + Returns modified content. + """ + region = content[start:end] + + # Replace old_name as a word, but not when preceded by -> or . + # (?.]) excludes the second char of -> and the . of member access + pattern = r"(?.])\b" + re.escape(old_name) + r"\b" + new_region = re.sub(pattern, new_name, region) + + return content[:start] + new_region + content[end:] + + +def fix_source_file(filepath, func_name, renames, dry_run=False): + """ + Rename parameters in a function definition in a .c source file. + renames: list of (old_name, new_name) + Returns True if changes were made. + """ + with open(filepath, "r") as f: + content = f.read() + + span = find_function_span(content, func_name) + if span is None: + print(f" WARNING: could not find {func_name} in {filepath}") + return False + + start, end = span + region = content[start:end] + + # Check for macros above the function that capture variable names. + # If a #define before the function references old_name, renaming + # will break macro expansions. + preamble = content[:start] + safe_renames = [] + for old_name, new_name in renames: + # Check collision: new_name already exists in the function span + new_pat = r"(?.])\b" + re.escape(new_name) + r"\b" + if re.search(new_pat, region): + print(f" SKIP {func_name} rename {old_name}->{new_name}" + f" in {filepath}: collision with existing '{new_name}'") + continue + # Check macro capture: old_name in a #define above. + # Handle multi-line macros (backslash continuation). + macro_pat = (r"#\s*define\s+\w+(?:[^\n]*\\\n)*[^\n]*\b" + + re.escape(old_name) + r"\b") + if re.search(macro_pat, preamble): + print(f" SKIP {func_name} rename {old_name}->{new_name}" + f" in {filepath}: macro captures '{old_name}'") + continue + safe_renames.append((old_name, new_name)) + + if not safe_renames: + return False + + changed = False + for old_name, new_name in safe_renames: + new_content = rename_param_in_range( + content, start, end, old_name, new_name + ) + if new_content != content: + end += len(new_content) - len(content) + content = new_content + changed = True + + if changed: + if dry_run: + print(f" [dry-run] would fix {func_name} in {filepath}") + else: + with open(filepath, "w") as f: + f.write(content) + print(f" fixed {func_name} in {filepath}") + return changed + + +def fix_declaration_file(filepath, func_name, renames, dry_run=False): + """ + Rename parameters in a function declaration in a .h header file + or a .. function:: directive in an .rst doc file. + Only modifies the signature line(s), not any body. + Skips inline function definitions (only fixes pure declarations + ending with ';'). + renames: list of (old_name, new_name) + Returns True if changes were made. + """ + with open(filepath, "r") as f: + content = f.read() + + # Find the function name in the file + pattern = re.compile(r"\b" + re.escape(func_name) + r"\s*\(") + changed = False + + for m in pattern.finditer(content): + # Find the full declaration: from func_name( to the closing ) + pos = m.end() + paren_depth = 1 + while pos < len(content) and paren_depth > 0: + if content[pos] == "(": + paren_depth += 1 + elif content[pos] == ")": + paren_depth -= 1 + pos += 1 + + # For .h files, check if this is a declaration (;) or + # definition ({). Skip inline definitions — they need + # fix_source_file instead. + if filepath.endswith(".h"): + rest = content[pos:pos + 40].lstrip() + if not rest.startswith(";"): + continue # Skip inline definitions + + param_start = m.end() - 1 # the '(' + param_end = pos # just past ')' + + region = content[param_start:param_end] + new_region = region + for old_name, new_name in renames: + pat = r"\b" + re.escape(old_name) + r"\b" + new_region = re.sub(pat, new_name, new_region) + + if new_region != region: + content = content[:param_start] + new_region + content[param_end:] + changed = True + break # Only fix first occurrence + + if changed: + if dry_run: + print(f" [dry-run] would fix {func_name} in {filepath}") + else: + with open(filepath, "w") as f: + f.write(content) + print(f" fixed {func_name} in {filepath}") + return changed + + +def collect_mismatches(modules, src_dir, check_src=False): + """ + Collect all mismatches across modules. + Returns list of (mod, header_path, rst_path, hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs). + """ + results = [] + for mod, header_path, rst_path in modules: + hdr_funcs = parse_header_declarations(header_path) + doc_funcs = parse_rst_functions(rst_path) + src_funcs = {} + + hdr_doc_mm = compare_params(hdr_funcs, doc_funcs) + + hdr_src_mm = [] + if check_src: + src_funcs = parse_source_definitions(src_dir, mod) + hdr_src_mm = compare_params(hdr_funcs, src_funcs) + + results.append((mod, header_path, rst_path, + hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs)) + return results + + +def apply_fixes(results, src_dir, dry_run=False): + """ + Auto-fix mismatches. Strategy: + - If doc and source agree but header differs: fix header + - If doc and header agree but source differs: fix source + - If header and source agree but doc differs: fix header+source + (doc is source of truth for naming) + - If all three differ: skip (needs manual review) + """ + fixed = 0 + skipped = 0 + + for (mod, header_path, rst_path, + hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs) in results: + + if not hdr_doc_mm and not hdr_src_mm: + continue + + # Build a unified view per function + all_funcs = set() + for mm in hdr_doc_mm: + all_funcs.add(mm["func"]) + for mm in hdr_src_mm: + all_funcs.add(mm["func"]) + + hdr_doc_by_func = {mm["func"]: mm for mm in hdr_doc_mm} + hdr_src_by_func = {mm["func"]: mm for mm in hdr_src_mm} + + for func_name in sorted(all_funcs): + hd = hdr_doc_by_func.get(func_name) + hs = hdr_src_by_func.get(func_name) + + hdr_params = hdr_funcs[func_name][0] if func_name in hdr_funcs else None + doc_params = doc_funcs[func_name][0] if func_name in doc_funcs else None + src_params = src_funcs[func_name][0] if func_name in src_funcs else None + + # Skip param count mismatches — need manual review + if (hd and hd["type"] == "param_count") or \ + (hs and hs["type"] == "param_count"): + print(f" SKIP {func_name}: parameter count mismatch" + " (manual review needed)") + skipped += 1 + continue + + # Determine which names to use + renames_hdr = [] # (old, new) for header + renames_src = [] # (old, new) for source file(s) + + if hd and hs: + # Header disagrees with both doc and source + # Check if doc and source agree + if doc_params == src_params: + # doc and source agree — fix header to match + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + else: + # All three differ — use doc as truth + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + # Also fix source to match doc + # Recompute source renames against doc + if src_params and doc_params and \ + len(src_params) == len(doc_params): + for k, (s, d) in enumerate( + zip(src_params, doc_params)): + if s != d: + renames_src.append((s, d)) + else: + print(f" SKIP {func_name}: all three differ" + " (manual review needed)") + skipped += 1 + continue + elif hd and not hs: + # Header vs doc mismatch only (no source mismatch or + # source not checked) + if src_params is not None: + # Source exists — check what it says + if src_params == doc_params: + # Source agrees with doc — fix header + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + elif src_params == hdr_params: + # Source agrees with header — doc is truth, + # fix header + source + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + renames_src.append((old, new)) + else: + print(f" SKIP {func_name}: all three differ" + " (manual review needed)") + skipped += 1 + continue + else: + # No source — just fix header to match doc + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + elif hs and not hd: + # Header vs source mismatch only + # Header agrees with doc (or doc doesn't exist) + # Fix source to match header + # diffs are (idx, hdr_name, src_name) — rename + # src_name -> hdr_name in the source file + for idx, hdr_name, src_name in hs["diffs"]: + renames_src.append((src_name, hdr_name)) + + # Filter out renames that would collide. + # Case 1: new_name is already a param not being renamed. + # Case 2: overlapping renames (swaps) where new_name + # equals old_name of another rename — sequential + # regex can't handle this correctly. + def filter_safe_renames(renames, params, label): + if not renames or not params: + return renames + old_names = {old for old, new in renames} + new_names = {new for old, new in renames} + # Check for overlapping renames (any new is also an old) + if old_names & new_names: + for old, new in renames: + print(f" SKIP {func_name} rename {old}->{new}" + f" in {label}: overlapping renames") + return [] + # Check if new_name collides with an unchanged param + unchanged = set(params) - old_names + safe = [] + for old, new in renames: + if new in unchanged: + print(f" SKIP {func_name} rename {old}->{new}" + f" in {label}: '{new}' already a param") + else: + safe.append((old, new)) + return safe + + renames_hdr = filter_safe_renames( + renames_hdr, hdr_params, "header" + ) + + # Apply fixes + if renames_hdr: + ok = fix_declaration_file( + header_path, func_name, renames_hdr, dry_run + ) + if not ok: + # Declaration not found (inline definition only). + # Try fix_source_file on the header instead. + ok = fix_source_file( + header_path, func_name, renames_hdr, dry_run + ) + if ok: + fixed += 1 + + if renames_src: + # Find the source file + if func_name in src_funcs: + src_path = src_funcs[func_name][1] + ok = fix_source_file( + src_path, func_name, renames_src, dry_run + ) + if ok: + fixed += 1 + + return fixed, skipped + + +def main(): + parser = argparse.ArgumentParser( + description="Check parameter name consistency between " + "FLINT headers, docs, and source files" + ) + parser.add_argument( + "--module", "-m", + help="Only check a specific module (e.g., fmpz, acb_poly)", + ) + parser.add_argument( + "--check-src", action="store_true", + help="Also check .c source files against headers", + ) + parser.add_argument( + "--fix", action="store_true", + help="Auto-fix mismatches (doc is source of truth for naming)", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="With --fix, show what would be changed without modifying files", + ) + parser.add_argument( + "--list-missing", action="store_true", + help="Also list functions missing from source or docs", + ) + args = parser.parse_args() + + if args.fix: + args.check_src = True + + src_dir = "src" + doc_dir = "doc/source" + + if not os.path.isdir(src_dir) or not os.path.isdir(doc_dir): + print("Error: run from the FLINT root directory", file=sys.stderr) + sys.exit(1) + + modules = find_modules(src_dir, doc_dir) + if args.module: + modules = [(m, h, r) for m, h, r in modules if m == args.module] + if not modules: + print(f"Error: module '{args.module}' not found", file=sys.stderr) + sys.exit(1) + + results = collect_mismatches(modules, src_dir, args.check_src) + + if args.fix: + fixed, skipped = apply_fixes(results, src_dir, args.dry_run) + print(f"\n{'='*70}") + print(f"Fixed: {fixed}, Skipped: {skipped}") + if not args.dry_run and fixed > 0: + print("Re-running check to verify...") + results2 = collect_mismatches(modules, src_dir, True) + remaining = sum( + len(r[3]) + len(r[4]) for r in results2 + ) + print(f"Remaining mismatches: {remaining}") + return + + total_hdr_doc = 0 + total_hdr_doc_checked = 0 + total_hdr_src = 0 + total_hdr_src_checked = 0 + + for (mod, header_path, rst_path, + hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs) in results: + + hdr_doc_common = len(set(hdr_funcs.keys()) & set(doc_funcs.keys())) + total_hdr_doc += len(hdr_doc_mm) + total_hdr_doc_checked += hdr_doc_common + + if args.check_src: + hdr_src_common = len( + set(hdr_funcs.keys()) & set(src_funcs.keys()) + ) + total_hdr_src += len(hdr_src_mm) + total_hdr_src_checked += hdr_src_common + + if hdr_doc_mm or hdr_src_mm: + print(f"\n{'='*70}") + print(f"Module: {mod}") + print(f"{'='*70}") + + if hdr_doc_mm: + print(f"\n --- header vs doc ---") + print_mismatches( + hdr_doc_mm, "header", "doc", + path_a=header_path, path_b=rst_path, + ) + + if hdr_src_mm: + print(f"\n --- header vs source ---") + print_mismatches( + hdr_src_mm, "header", "source", + path_a=header_path, + ) + + if args.list_missing: + only_hdr = sorted( + set(hdr_funcs.keys()) - set(doc_funcs.keys()) + ) + only_doc = sorted( + set(doc_funcs.keys()) - set(hdr_funcs.keys()) + ) + if only_hdr or only_doc: + if not hdr_doc_mm and not hdr_src_mm: + print(f"\n{'='*70}") + print(f"Module: {mod}") + print(f"{'='*70}") + if only_hdr: + print(f"\n In header but not in docs ({len(only_hdr)}):") + for fn in only_hdr: + print(f" {fn}") + if only_doc: + print(f"\n In docs but not in header ({len(only_doc)}):") + for fn in only_doc: + print(f" {fn}") + + print(f"\n{'='*70}") + print(f"Header vs doc: {total_hdr_doc} mismatches in" + f" {total_hdr_doc_checked} functions" + f" across {len(modules)} modules") + if args.check_src: + print(f"Header vs src: {total_hdr_src} mismatches in" + f" {total_hdr_src_checked} functions" + f" across {len(modules)} modules") + + if total_hdr_doc > 0 or total_hdr_src > 0: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/dev/test_check_param_names.py b/dev/test_check_param_names.py new file mode 100644 index 0000000000..94ed9fcf68 --- /dev/null +++ b/dev/test_check_param_names.py @@ -0,0 +1,1072 @@ +#!/usr/bin/env python3 +"""Tests for check_param_names.py""" + +import os +import tempfile +import unittest + +from check_param_names import ( + extract_param_name, + split_params, + parse_func_signature, + strip_c_comments_and_preprocessor, + extract_declarations, + extract_definitions, + parse_header_declarations, + parse_source_definitions, + parse_rst_functions, + compare_params, + find_function_span, + rename_param_in_range, + fix_source_file, + fix_declaration_file, + collect_mismatches, + apply_fixes, +) + + +class TestExtractParamName(unittest.TestCase): + def test_simple_types(self): + self.assertEqual(extract_param_name("slong n"), "n") + self.assertEqual(extract_param_name("ulong x"), "x") + self.assertEqual(extract_param_name("int flag"), "flag") + + def test_const_qualified(self): + self.assertEqual(extract_param_name("const fmpz_t x"), "x") + self.assertEqual(extract_param_name("const padic_ctx_t ctx"), "ctx") + + def test_pointer_types(self): + self.assertEqual(extract_param_name("ulong * out"), "out") + self.assertEqual(extract_param_name("FILE * file"), "file") + self.assertEqual(extract_param_name("nn_srcptr xp"), "xp") + self.assertEqual(extract_param_name("char * str"), "str") + + def test_flint_unused(self): + self.assertEqual( + extract_param_name("const padic_ctx_t FLINT_UNUSED(ctx)"), + "ctx", + ) + self.assertEqual( + extract_param_name("slong FLINT_UNUSED(n)"), + "n", + ) + + def test_array_brackets(self): + self.assertEqual(extract_param_name("ulong out[3]"), "out") + self.assertEqual(extract_param_name("slong perm[10]"), "perm") + + def test_void_and_ellipsis(self): + self.assertIsNone(extract_param_name("void")) + self.assertIsNone(extract_param_name("...")) + self.assertIsNone(extract_param_name("")) + + def test_type_only(self): + """When param is just a type with no name, return None.""" + self.assertIsNone(extract_param_name("int")) + self.assertIsNone(extract_param_name("void")) + self.assertIsNone(extract_param_name("FILE")) + + def test_complex_types(self): + self.assertEqual(extract_param_name("flint_bitcnt_t bits"), "bits") + self.assertEqual(extract_param_name("gr_ctx_t ctx"), "ctx") + self.assertEqual(extract_param_name("const fmpz_poly_t poly"), "poly") + + +class TestSplitParams(unittest.TestCase): + def test_simple(self): + self.assertEqual( + split_params("fmpz_t f, const fmpz_t g, const fmpz_t h"), + ["fmpz_t f", "const fmpz_t g", "const fmpz_t h"], + ) + + def test_single_param(self): + self.assertEqual(split_params("fmpz_t f"), ["fmpz_t f"]) + + def test_empty(self): + self.assertEqual(split_params(""), []) + + def test_void(self): + self.assertEqual(split_params("void"), ["void"]) + + def test_function_pointer(self): + result = split_params( + "int (*cmp)(void *, const void *, const void *), slong n" + ) + self.assertEqual(len(result), 2) + self.assertIn("(*cmp)", result[0]) + self.assertEqual(result[1], "slong n") + + def test_nested_parens(self): + result = split_params("acb_calc_func_t func, void * param") + self.assertEqual(result, ["acb_calc_func_t func", "void * param"]) + + +class TestParseFuncSignature(unittest.TestCase): + def test_simple(self): + result = parse_func_signature( + "void fmpz_add(fmpz_t f, const fmpz_t g, const fmpz_t h)" + ) + self.assertEqual(result, ("fmpz_add", ["f", "g", "h"])) + + def test_no_params(self): + result = parse_func_signature("void _fmpz_cleanup(void)") + self.assertEqual(result, ("_fmpz_cleanup", [])) + + def test_return_type_pointer(self): + result = parse_func_signature("mpz_ptr _fmpz_promote(fmpz_t f)") + self.assertEqual(result, ("_fmpz_promote", ["f"])) + + def test_with_semicolon(self): + result = parse_func_signature( + "void fmpz_clear(fmpz_t f);" + ) + self.assertEqual(result, ("fmpz_clear", ["f"])) + + def test_multiline(self): + result = parse_func_signature( + "void fmpz_multi_CRT_ui(fmpz_t output, " + "nn_srcptr residues, const fmpz_comb_t comb, " + "fmpz_comb_temp_t ctemp, int sign)" + ) + self.assertEqual( + result, + ("fmpz_multi_CRT_ui", ["output", "residues", "comb", + "ctemp", "sign"]), + ) + + def test_function_pointer_param(self): + result = parse_func_signature( + "void qsort_r(void * base, slong n, slong size, " + "int (*cmp)(void *, const void *, const void *), void * arg)" + ) + self.assertIsNotNone(result) + self.assertEqual(result[0], "qsort_r") + self.assertIn("cmp", result[1]) + + def test_no_parens(self): + self.assertIsNone(parse_func_signature("int x")) + + def test_typedef_function_pointer(self): + """Function pointer typedefs are filtered by the caller + (parse_header_declarations skips typedefs), so parse_func_signature + may return a result — it just won't be a correct function name. + The important thing is that parse_header_declarations filters these.""" + # Direct call may parse it (incorrectly), but that's OK + # since the caller filters typedefs before calling this function. + pass + + def test_warn_unused_result(self): + """Decorators should be handled by the caller, not parse_func_signature.""" + result = parse_func_signature( + "int gr_poly_set(gr_poly_t res, const gr_poly_t src, gr_ctx_t ctx)" + ) + self.assertEqual(result, ("gr_poly_set", ["res", "src", "ctx"])) + + def test_flint_unused_param(self): + result = parse_func_signature( + "void foo(slong n, const nmod_t FLINT_UNUSED(mod))" + ) + self.assertEqual(result, ("foo", ["n", "mod"])) + + def test_ellipsis(self): + result = parse_func_signature( + "int flint_printf(const char * fmt, ...)" + ) + self.assertEqual(result, ("flint_printf", ["fmt"])) + + +class TestStripCCommentsAndPreprocessor(unittest.TestCase): + def test_block_comment(self): + text = strip_c_comments_and_preprocessor( + "/* comment */\nvoid foo(int x);\n" + ) + self.assertIn("void foo(int x);", text) + self.assertNotIn("comment", text) + + def test_multiline_block_comment(self): + text = strip_c_comments_and_preprocessor( + "/* multi\nline\ncomment */\nvoid bar(void);\n" + ) + self.assertIn("void bar(void);", text) + self.assertNotIn("multi", text) + + def test_preprocessor_directive(self): + text = strip_c_comments_and_preprocessor( + "#include \nvoid baz(int n);\n" + ) + self.assertIn("void baz(int n);", text) + self.assertNotIn("include", text) + + def test_multiline_macro(self): + text = strip_c_comments_and_preprocessor( + "#define FOO(x) \\\n ((x) + 1)\nvoid quux(int a);\n" + ) + self.assertIn("void quux(int a);", text) + self.assertNotIn("FOO", text) + + def test_ifdef_keeps_body(self): + """Code inside #ifdef blocks should be preserved.""" + text = strip_c_comments_and_preprocessor( + "#ifdef HAVE_FEATURE\n" + "void feature_func(int x);\n" + "#endif\n" + ) + self.assertIn("void feature_func(int x);", text) + + def test_extern_c_removed(self): + text = strip_c_comments_and_preprocessor( + '#ifdef __cplusplus\n' + 'extern "C" {\n' + '#endif\n' + 'void foo(int x);\n' + '#ifdef __cplusplus\n' + '}\n' + '#endif\n' + ) + self.assertIn("void foo(int x);", text) + self.assertNotIn('extern "C"', text) + + +class TestExtractDeclarations(unittest.TestCase): + def test_simple_declaration(self): + text = "void foo(int x);\n" + decls = extract_declarations(text) + self.assertEqual(len(decls), 1) + self.assertIn("void foo(int x)", decls[0][0]) + + def test_skips_function_body(self): + text = ( + "void inline_func(int x)\n" + "{\n" + " return;\n" + "}\n" + "void declared_func(int y);\n" + ) + decls = extract_declarations(text) + self.assertEqual(len(decls), 1) + self.assertIn("declared_func", decls[0][0]) + + def test_multiple_declarations(self): + text = ( + "void foo(int a);\n" + "int bar(slong b, ulong c);\n" + ) + decls = extract_declarations(text) + self.assertEqual(len(decls), 2) + + def test_multiline_declaration(self): + text = ( + "void long_func(int a,\n" + " int b,\n" + " int c);\n" + ) + decls = extract_declarations(text) + self.assertEqual(len(decls), 1) + sig = decls[0][0] + self.assertIn("long_func", sig) + + def test_no_declarations_when_no_parens(self): + text = "int x;\nchar * str;\n" + decls = extract_declarations(text) + self.assertEqual(len(decls), 0) + + +class TestExtractDefinitions(unittest.TestCase): + def test_simple_definition(self): + text = ( + "void foo(int x)\n" + "{\n" + " return;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("foo", defs[0][0]) + + def test_skips_declarations(self): + text = ( + "void declared_only(int x);\n" + "void defined(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("defined", defs[0][0]) + + def test_return_type_on_separate_line(self): + text = ( + "void\n" + "foo(int x, int y)\n" + "{\n" + " return;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("foo", defs[0][0]) + + def test_multiple_definitions(self): + text = ( + "void foo(int a)\n" + "{\n" + " return;\n" + "}\n" + "\n" + "int bar(int b)\n" + "{\n" + " return 0;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 2) + + def test_nested_braces(self): + text = ( + "void foo(int x)\n" + "{\n" + " if (x) {\n" + " return;\n" + " }\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("foo", defs[0][0]) + + def test_static_and_nonstatic(self): + """Both static and non-static definitions should be extracted.""" + text = ( + "static void helper(int x)\n" + "{\n" + " return;\n" + "}\n" + "\n" + "void public_func(int y)\n" + "{\n" + " helper(y);\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 2) + + +class TestParseHeaderDeclarations(unittest.TestCase): + def _write_temp(self, content): + fd, path = tempfile.mkstemp(suffix=".h") + os.write(fd, content.encode()) + os.close(fd) + return path + + def test_basic_header(self): + path = self._write_temp( + "/* copyright */\n" + "#ifndef FOO_H\n" + "#define FOO_H\n" + "#ifdef __cplusplus\n" + 'extern "C" {\n' + "#endif\n" + "void foo_init(foo_t x);\n" + "void foo_clear(foo_t x);\n" + "int foo_add(foo_t r, const foo_t a, const foo_t b);\n" + "#ifdef __cplusplus\n" + "}\n" + "#endif\n" + "#endif\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("foo_init", funcs) + self.assertIn("foo_clear", funcs) + self.assertIn("foo_add", funcs) + self.assertEqual(funcs["foo_init"][0], ["x"]) + self.assertEqual(funcs["foo_add"][0], ["r", "a", "b"]) + finally: + os.unlink(path) + + def test_skips_inline(self): + path = self._write_temp( + "void foo_declared(int x);\n" + "static inline void foo_inline(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("foo_declared", funcs) + self.assertNotIn("foo_inline", funcs) + finally: + os.unlink(path) + + def test_skips_typedefs(self): + path = self._write_temp( + "typedef void (*callback_t)(int x);\n" + "void real_func(int y);\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("real_func", funcs) + self.assertNotIn("callback_t", funcs) + finally: + os.unlink(path) + + def test_multiline_declaration(self): + path = self._write_temp( + "void long_func(int a,\n" + " int b,\n" + " int c);\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("long_func", funcs) + self.assertEqual(funcs["long_func"][0], ["a", "b", "c"]) + finally: + os.unlink(path) + + def test_warn_unused_result(self): + path = self._write_temp( + "WARN_UNUSED_RESULT int foo(int x, int y);\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("foo", funcs) + self.assertEqual(funcs["foo"][0], ["x", "y"]) + finally: + os.unlink(path) + + +class TestParseSourceDefinitions(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.src_dir = self.tmpdir + self.mod_dir = os.path.join(self.tmpdir, "mymod") + os.makedirs(self.mod_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write_c(self, name, content): + path = os.path.join(self.mod_dir, name) + with open(path, "w") as f: + f.write(content) + + def test_simple_definition(self): + self._write_c("add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertIn("mymod_add", funcs) + self.assertEqual(funcs["mymod_add"][0], ["r", "a", "b"]) + + def test_skips_static(self): + self._write_c("impl.c", + "static void helper(int x)\n" + "{\n" + " return;\n" + "}\n" + "\n" + "void mymod_func(int y)\n" + "{\n" + " helper(y);\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertNotIn("helper", funcs) + self.assertIn("mymod_func", funcs) + + def test_skips_test_files(self): + self._write_c("t-add.c", + "void test_add(int x)\n" + "{\n" + " return;\n" + "}\n" + ) + self._write_c("add.c", + "void mymod_add(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertNotIn("test_add", funcs) + self.assertIn("mymod_add", funcs) + + def test_return_type_on_separate_line(self): + self._write_c("gcd.c", + "void\n" + "mymod_gcd(foo_t f, const foo_t g, const foo_t h)\n" + "{\n" + " return;\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertIn("mymod_gcd", funcs) + self.assertEqual(funcs["mymod_gcd"][0], ["f", "g", "h"]) + + +class TestParseRstFunctions(unittest.TestCase): + def _write_temp(self, content): + fd, path = tempfile.mkstemp(suffix=".rst") + os.write(fd, content.encode()) + os.close(fd) + return path + + def test_simple_directive(self): + path = self._write_temp( + ".. function:: void foo_init(foo_t x)\n" + "\n" + " Initializes x.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("foo_init", funcs) + self.assertEqual(funcs["foo_init"][0], ["x"]) + finally: + os.unlink(path) + + def test_continuation_lines(self): + path = self._write_temp( + ".. function:: void foo_init(foo_t x)\n" + " void foo_clear(foo_t x)\n" + "\n" + " Init/clear.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("foo_init", funcs) + self.assertIn("foo_clear", funcs) + finally: + os.unlink(path) + + def test_c_function_directive(self): + path = self._write_temp( + ".. c:function:: int bar(slong n, ulong k)\n" + "\n" + " Computes something.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("bar", funcs) + self.assertEqual(funcs["bar"][0], ["n", "k"]) + finally: + os.unlink(path) + + def test_multiple_directives(self): + path = self._write_temp( + ".. function:: void alpha(int a)\n" + "\n" + " Does alpha.\n" + "\n" + ".. function:: void beta(int b)\n" + "\n" + " Does beta.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("alpha", funcs) + self.assertIn("beta", funcs) + finally: + os.unlink(path) + + def test_line_numbers(self): + path = self._write_temp( + "Title\n" + "=====\n" + "\n" + ".. function:: void first(int x)\n" + "\n" + " First func.\n" + "\n" + ".. function:: void second(int y)\n" + "\n" + " Second func.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertEqual(funcs["first"][1], 4) + self.assertEqual(funcs["second"][1], 8) + finally: + os.unlink(path) + + +class TestCompareParams(unittest.TestCase): + def _make_funcs(self, **kwargs): + """Helper: create funcs dict from name -> param_list mapping.""" + return { + name: (params, 1, f"void {name}(...)") + for name, params in kwargs.items() + } + + def test_no_mismatches(self): + a = self._make_funcs(foo=["x", "y"], bar=["a"]) + b = self._make_funcs(foo=["x", "y"], bar=["a"]) + self.assertEqual(compare_params(a, b), []) + + def test_name_mismatch(self): + a = self._make_funcs(foo=["x", "y"]) + b = self._make_funcs(foo=["a", "b"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["func"], "foo") + self.assertEqual(mm[0]["type"], "param_name") + self.assertEqual(mm[0]["diffs"], [(0, "x", "a"), (1, "y", "b")]) + + def test_count_mismatch(self): + a = self._make_funcs(foo=["x", "y"]) + b = self._make_funcs(foo=["x", "y", "z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["type"], "param_count") + + def test_only_common_compared(self): + a = self._make_funcs(foo=["x"], only_a=["y"]) + b = self._make_funcs(foo=["x"], only_b=["z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 0) + + def test_partial_mismatch(self): + a = self._make_funcs(foo=["x", "y", "z"]) + b = self._make_funcs(foo=["x", "b", "z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["diffs"], [(1, "y", "b")]) + + +class TestEndToEnd(unittest.TestCase): + """End-to-end tests with temp header + RST + source files.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.src_dir = os.path.join(self.tmpdir, "src") + self.doc_dir = os.path.join(self.tmpdir, "doc", "source") + os.makedirs(self.src_dir) + os.makedirs(self.doc_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, path, content): + full = os.path.join(self.tmpdir, path) + os.makedirs(os.path.dirname(full), exist_ok=True) + with open(full, "w") as f: + f.write(content) + return full + + def test_header_vs_doc_match(self): + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + rst = self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "\n" + " Adds a and b.\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions(rst) + mm = compare_params(h, d) + self.assertEqual(len(mm), 0) + + def test_header_vs_doc_mismatch(self): + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t f, const foo_t g, const foo_t h);\n" + ) + rst = self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "\n" + " Adds a and b.\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions(rst) + mm = compare_params(h, d) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["func"], "mymod_add") + + def test_header_vs_source_match(self): + self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + hdr = os.path.join(self.tmpdir, "src", "mymod.h") + h = parse_header_declarations(hdr) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + mm = compare_params(h, s) + self.assertEqual(len(mm), 0) + + def test_header_vs_source_mismatch(self): + self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t f, const foo_t g, const foo_t h)\n" + "{\n" + " return;\n" + "}\n" + ) + hdr = os.path.join(self.tmpdir, "src", "mymod.h") + h = parse_header_declarations(hdr) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + mm = compare_params(h, s) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["func"], "mymod_add") + self.assertEqual( + mm[0]["diffs"], + [(0, "r", "f"), (1, "a", "g"), (2, "b", "h")], + ) + + def test_three_way_consistency(self): + """All three sources match — no mismatches.""" + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + "void mymod_neg(foo_t r, const foo_t a);\n" + ) + self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "\n" + " Adds.\n" + "\n" + ".. function:: void mymod_neg(foo_t r, const foo_t a)\n" + "\n" + " Negates.\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + self._write("src/mymod/neg.c", + "void mymod_neg(foo_t r, const foo_t a)\n" + "{\n" + " return;\n" + "}\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions( + os.path.join(self.tmpdir, "doc", "source", "mymod.rst") + ) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + self.assertEqual(compare_params(h, d), []) + self.assertEqual(compare_params(h, s), []) + + def test_three_way_doc_off(self): + """Doc has different names, but header and source agree.""" + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t res, const foo_t x, const foo_t y)\n" + "\n" + " Adds.\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions( + os.path.join(self.tmpdir, "doc", "source", "mymod.rst") + ) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + hd = compare_params(h, d) + hs = compare_params(h, s) + self.assertEqual(len(hd), 1) # header vs doc mismatch + self.assertEqual(len(hs), 0) # header vs source match + + +class TestRealFiles(unittest.TestCase): + """Integration tests against actual FLINT files (skipped if not present).""" + + def setUp(self): + if not os.path.isfile("src/fmpz.h"): + self.skipTest("Not running from FLINT root directory") + + def test_fmpz_header_parses(self): + funcs = parse_header_declarations("src/fmpz.h") + self.assertGreater(len(funcs), 100) + self.assertIn("fmpz_add", funcs) + self.assertEqual(funcs["fmpz_add"][0], ["f", "g", "h"]) + + def test_fmpz_doc_parses(self): + funcs = parse_rst_functions("doc/source/fmpz.rst") + self.assertGreater(len(funcs), 100) + self.assertIn("fmpz_add", funcs) + + def test_fmpz_source_parses(self): + funcs = parse_source_definitions("src", "fmpz") + self.assertGreater(len(funcs), 50) + self.assertIn("fmpz_add", funcs) + + def test_acb_header_skips_inlines(self): + funcs = parse_header_declarations("src/acb.h") + # acb_add is an inline — should be skipped + self.assertNotIn("acb_add", funcs) + # acb_mul is a declaration — should be found + self.assertIn("acb_mul", funcs) + + def test_gr_poly_header_with_warn_unused(self): + funcs = parse_header_declarations("src/gr_poly.h") + self.assertIn("gr_poly_set", funcs) + self.assertIn("gr_poly_mul", funcs) + + def test_padic_mat_header_parsed(self): + """padic_mat header should parse correctly.""" + funcs = parse_header_declarations("src/padic_mat.h") + self.assertIn("padic_mat_is_canonical", funcs, + "padic_mat_is_canonical should be in header") + + +class TestRenameParamInRange(unittest.TestCase): + def test_simple_rename(self): + content = "void foo(int x)\n{\n return x + 1;\n}\n" + result = rename_param_in_range(content, 0, len(content), "x", "n") + self.assertEqual(result, "void foo(int n)\n{\n return n + 1;\n}\n") + + def test_avoids_struct_member(self): + """Must not rename struct->field or obj.field patterns.""" + content = ( + "void foo(int length)\n" + "{\n" + " x = length + cache->length;\n" + "}\n" + ) + result = rename_param_in_range( + content, 0, len(content), "length", "len" + ) + self.assertIn("int len)", result) + self.assertIn("x = len + cache->length", result) + + def test_avoids_dot_member(self): + content = "void foo(int x)\n{\n a = x + obj.x;\n}\n" + result = rename_param_in_range(content, 0, len(content), "x", "n") + self.assertIn("a = n + obj.x", result) + + def test_word_boundary(self): + """Must not rename substrings.""" + content = "void foo(int x)\n{\n int xx = x;\n}\n" + result = rename_param_in_range(content, 0, len(content), "x", "n") + self.assertIn("int xx = n", result) + self.assertIn("int n)", result) + + +class TestFindFunctionSpan(unittest.TestCase): + def test_simple_function(self): + content = "void foo(int x)\n{\n return;\n}\n" + span = find_function_span(content, "foo") + self.assertIsNotNone(span) + start, end = span + self.assertIn("foo", content[start:end]) + self.assertIn("return", content[start:end]) + + def test_skips_declarations(self): + content = ( + "void foo(int x);\n" + "void bar(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + span = find_function_span(content, "foo") + self.assertIsNone(span) + span = find_function_span(content, "bar") + self.assertIsNotNone(span) + + def test_return_type_on_separate_line(self): + content = "void\nfoo(int x)\n{\n return;\n}\n" + span = find_function_span(content, "foo") + self.assertIsNotNone(span) + + +class TestFixSourceFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, name, content): + path = os.path.join(self.tmpdir, name) + with open(path, "w") as f: + f.write(content) + return path + + def test_fix_simple(self): + path = self._write("func.c", + "void mymod_foo(foo_t f, const foo_t g)\n" + "{\n" + " bar(f, g);\n" + "}\n" + ) + fix_source_file(path, "mymod_foo", [("f", "r"), ("g", "a")]) + with open(path) as f: + result = f.read() + self.assertIn("foo_t r, const foo_t a)", result) + self.assertIn("bar(r, a);", result) + + def test_fix_avoids_struct_members(self): + path = self._write("func.c", + "void mymod_foo(slong length)\n" + "{\n" + " x = length + cache->length;\n" + "}\n" + ) + fix_source_file(path, "mymod_foo", [("length", "len")]) + with open(path) as f: + result = f.read() + self.assertIn("slong len)", result) + self.assertIn("x = len + cache->length;", result) + + +class TestFixDeclarationFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, name, content): + path = os.path.join(self.tmpdir, name) + with open(path, "w") as f: + f.write(content) + return path + + def test_fix_header_declaration(self): + path = self._write("mod.h", + "void mymod_foo(foo_t f, const foo_t g);\n" + "void mymod_bar(int x);\n" + ) + fix_declaration_file(path, "mymod_foo", [("f", "r"), ("g", "a")]) + with open(path) as f: + result = f.read() + self.assertIn("foo_t r, const foo_t a)", result) + # mymod_bar should be untouched + self.assertIn("mymod_bar(int x)", result) + + +class TestApplyFixes(unittest.TestCase): + """Test the full fix workflow.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, path, content): + full = os.path.join(self.tmpdir, path) + os.makedirs(os.path.dirname(full), exist_ok=True) + with open(full, "w") as f: + f.write(content) + return full + + def test_fix_source_to_match_header(self): + """When header and doc agree, source gets fixed.""" + hdr_path = self._write("src/mymod.h", + "void mymod_foo(foo_t r, const foo_t a);\n" + ) + rst_path = self._write("doc/source/mymod.rst", + ".. function:: void mymod_foo(foo_t r, const foo_t a)\n" + "\n" + " Does foo.\n" + ) + src_path = self._write("src/mymod/foo.c", + "void mymod_foo(foo_t f, const foo_t g)\n" + "{\n" + " bar(f, g);\n" + "}\n" + ) + src_dir = os.path.join(self.tmpdir, "src") + doc_dir = os.path.join(self.tmpdir, "doc", "source") + modules = [("mymod", hdr_path, rst_path)] + results = collect_mismatches(modules, src_dir, check_src=True) + fixed, skipped = apply_fixes(results, src_dir) + self.assertEqual(fixed, 1) + self.assertEqual(skipped, 0) + with open(src_path) as f: + content = f.read() + self.assertIn("foo_t r, const foo_t a)", content) + self.assertIn("bar(r, a);", content) + + def test_fix_header_and_source_to_match_doc(self): + """When header and source agree but doc differs, + both header and source get fixed to match doc.""" + hdr_path = self._write("src/mymod.h", + "void mymod_foo(foo_t f, const foo_t g);\n" + ) + rst_path = self._write("doc/source/mymod.rst", + ".. function:: void mymod_foo(foo_t r, const foo_t a)\n" + "\n" + " Does foo.\n" + ) + src_path = self._write("src/mymod/foo.c", + "void mymod_foo(foo_t f, const foo_t g)\n" + "{\n" + " bar(f, g);\n" + "}\n" + ) + src_dir = os.path.join(self.tmpdir, "src") + modules = [("mymod", hdr_path, rst_path)] + results = collect_mismatches(modules, src_dir, check_src=True) + fixed, skipped = apply_fixes(results, src_dir) + self.assertEqual(fixed, 2) # header + source + with open(hdr_path) as f: + hdr = f.read() + self.assertIn("foo_t r, const foo_t a)", hdr) + with open(src_path) as f: + src = f.read() + self.assertIn("foo_t r, const foo_t a)", src) + self.assertIn("bar(r, a);", src) + + def test_fix_struct_member_safety(self): + """Renaming param must not affect struct->member access.""" + hdr_path = self._write("src/mymod.h", + "void mymod_foo(slong len);\n" + ) + rst_path = self._write("doc/source/mymod.rst", + ".. function:: void mymod_foo(slong len)\n" + "\n" + " Does foo.\n" + ) + src_path = self._write("src/mymod/foo.c", + "void mymod_foo(slong length)\n" + "{\n" + " x = length + cache->length;\n" + "}\n" + ) + src_dir = os.path.join(self.tmpdir, "src") + modules = [("mymod", hdr_path, rst_path)] + results = collect_mismatches(modules, src_dir, check_src=True) + fixed, skipped = apply_fixes(results, src_dir) + self.assertEqual(fixed, 1) + with open(src_path) as f: + content = f.read() + self.assertIn("slong len)", content) + self.assertIn("x = len + cache->length;", content) + + +if __name__ == "__main__": + unittest.main() From 81b63006eeda6abe662a908bc2f36f84a919e324 Mon Sep 17 00:00:00 2001 From: Edgar Costa Date: Wed, 11 Mar 2026 10:53:36 -0400 Subject: [PATCH 2/2] Detect swapped parameters and warn about possible bugs When two parameters have their names exchanged between header and source (e.g. header has val,len but source has len,val), the tool now prints a WARNING flagging a possible bug. This catches cases like _padic_poly_fprint_pretty where the header and source disagree on parameter order. --- dev/check_param_names.py | 17 +++++++++++++++++ dev/test_check_param_names.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/dev/check_param_names.py b/dev/check_param_names.py index 95c6e33f09..c3fe491ba2 100755 --- a/dev/check_param_names.py +++ b/dev/check_param_names.py @@ -484,13 +484,25 @@ def compare_params(funcs_a, funcs_b): if a_params != b_params: diffs = [] + swaps = [] for k, (a, b) in enumerate(zip(a_params, b_params)): if a != b: diffs.append((k, a, b)) + # Detect swapped parameters: pairs where a[i]=b[j] and + # a[j]=b[i] for i