diff --git a/dev/check_param_names.py b/dev/check_param_names.py new file mode 100755 index 0000000000..c3fe491ba2 --- /dev/null +++ b/dev/check_param_names.py @@ -0,0 +1,1103 @@ +#!/usr/bin/env python3 +""" +Compare parameter names between FLINT header files, RST documentation, +and .c source files. + +Reports mismatches where the same function has different parameter names +across these sources. + +Usage: + python3 dev/check_param_names.py [--module MODULE] [--check-src] + python3 dev/check_param_names.py --check-src -m fmpz + +Run from the FLINT root directory. +""" + +import os +import re +import sys +import argparse +from collections import OrderedDict + + +# Decorators/macros that appear before return types in headers +DECL_PREFIXES = re.compile( + r"^(?:" + r"WARN_UNUSED_RESULT|" + r"FLINT_FORCE_INLINE|" + r"\w+_INLINE|" # FMPZ_INLINE, GR_POLY_INLINE, etc. + r"FLINT_DLL" + r")\s*", +) + + +def find_modules(src_dir, doc_dir): + """Find modules that have both a header and a doc file.""" + modules = [] + for f in sorted(os.listdir(src_dir)): + if f.endswith(".h"): + mod = f[:-2] # strip .h + rst = os.path.join(doc_dir, mod + ".rst") + header = os.path.join(src_dir, f) + if os.path.isfile(rst): + modules.append((mod, header, rst)) + return modules + + +def extract_param_name(param_str): + """ + Extract the parameter name from a C parameter declaration string. + + Examples: + "const fmpz_t x" -> "x" + "slong n" -> "n" + "ulong * out" -> "out" + "FILE * file" -> "file" + "nn_srcptr xp" -> "xp" + "..." -> None + "void" -> None + "const padic_ctx_t FLINT_UNUSED(ctx)" -> "ctx" + """ + param_str = param_str.strip() + if not param_str or param_str == "void" or param_str == "...": + return None + + # Handle FLINT_UNUSED(name) + m = re.search(r"FLINT_UNUSED\((\w+)\)", param_str) + if m: + return m.group(1) + + # Remove array brackets: e.g. "ulong out[3]" -> name is "out" + param_str = re.sub(r"\[.*?\]", "", param_str).strip() + + # Split on spaces and asterisks, take the last word token + tokens = re.findall(r"\w+", param_str) + if not tokens: + return None + + # The last token is the name, unless the entire thing is just type(s) + # with no name (e.g., "void" - already handled above) + name = tokens[-1] + + # Sanity check: if the "name" looks like a type keyword, skip + type_keywords = { + "void", "int", "long", "short", "char", "unsigned", "signed", + "float", "double", "size_t", "ssize_t", "FILE", + } + if name in type_keywords: + return None + + # If the "name" looks like a type (ends with _t, _struct, _ptr, + # _srcptr), it's an unnamed parameter (type only). + type_suffixes = ("_t", "_struct", "_ptr", "_srcptr") + if name.endswith(type_suffixes): + return None + + return name + + +def split_params(params_str): + """ + Split a parameter list string by commas, respecting nested parentheses + and function pointer syntax. + """ + params = [] + depth = 0 + current = [] + for ch in params_str: + if ch in ("(", "["): + depth += 1 + current.append(ch) + elif ch in (")", "]"): + depth -= 1 + current.append(ch) + elif ch == "," and depth == 0: + params.append("".join(current).strip()) + current = [] + else: + current.append(ch) + last = "".join(current).strip() + if last: + params.append(last) + return params + + +def parse_func_signature(sig): + """ + Parse a C function signature string, returning (func_name, [param_names]) + or None if it can't be parsed. + + The signature should be like: + "void fmpz_add(fmpz_t f, const fmpz_t g, const fmpz_t h)" + """ + # Normalize whitespace + sig = re.sub(r"\s+", " ", sig).strip() + + # Remove trailing semicolons + sig = sig.rstrip(";").strip() + + # Must have parentheses + if "(" not in sig or ")" not in sig: + return None + + # Find the first '(' at depth 0 (the parameter list opening) + depth = 0 + paren_start = -1 + for i, ch in enumerate(sig): + if ch == "(": + if depth == 0: + paren_start = i + break + depth += 1 + elif ch == ")": + depth -= 1 + + if paren_start < 0: + return None + + # Find matching close paren + depth = 1 + paren_end = -1 + for i in range(paren_start + 1, len(sig)): + if sig[i] == "(": + depth += 1 + elif sig[i] == ")": + depth -= 1 + if depth == 0: + paren_end = i + break + + if paren_end < 0: + return None + + before_paren = sig[:paren_start].strip() + params_str = sig[paren_start + 1 : paren_end].strip() + + # Skip function pointer declarations: "typedef void (*name)(...)" + if "(*" in before_paren: + return None + + # Extract function name: last word before the paren + name_match = re.search(r"(\w+)\s*$", before_paren) + if not name_match: + return None + func_name = name_match.group(1) + + # Parse parameter names + if not params_str or params_str == "void": + return (func_name, []) + + param_parts = split_params(params_str) + param_names = [] + for p in param_parts: + # Function pointer parameters: "int (*cmp)(void *, const void *)" + fptr_match = re.match(r".*\(\s*\*\s*(\w+)\s*\)", p) + if fptr_match: + param_names.append(fptr_match.group(1)) + continue + + name = extract_param_name(p) + if name is not None: + param_names.append(name) + + return (func_name, param_names) + + +def strip_c_comments_and_preprocessor(content): + """ + Remove block comments, line comments, and preprocessor directive lines + from C source code. Also remove extern "C" { wrappers. + Returns cleaned text. + """ + # Remove block comments + content = re.sub(r"/\*.*?\*/", " ", content, flags=re.DOTALL) + + # Remove line comments (shouldn't appear in FLINT style, but just in case) + content = re.sub(r"//[^\n]*", "", content) + + # Remove preprocessor directive lines (keep code inside #ifdef blocks) + lines = content.split("\n") + clean_lines = [] + in_macro_def = False + for line in lines: + stripped = line.strip() + if in_macro_def: + if not stripped.endswith("\\"): + in_macro_def = False + clean_lines.append("") + continue + if stripped.startswith("#"): + if stripped.endswith("\\"): + in_macro_def = True + clean_lines.append("") + continue + clean_lines.append(line) + + text = "\n".join(clean_lines) + + # Remove extern "C" { wrappers + text = re.sub(r'extern\s+"C"\s*\{', " ", text) + + return text + + +def extract_declarations(text): + """ + Extract top-level declarations (ending with ';') from cleaned C text. + Skips anything inside brace-delimited blocks (function bodies, structs). + Returns list of (declaration_text, line_number). + """ + declarations = [] + current_decl = [] + brace_depth = 0 + + for i, ch in enumerate(text): + if ch == "{": + brace_depth += 1 + current_decl = [] + elif ch == "}": + brace_depth -= 1 + if brace_depth < 0: + brace_depth = 0 + current_decl = [] + elif brace_depth == 0: + if ch == ";": + decl_text = "".join(current_decl).strip() + if decl_text and "(" in decl_text and ")" in decl_text: + pos_in_text = i - len(decl_text) + line_num = text[:pos_in_text].count("\n") + 1 + declarations.append((decl_text, line_num)) + current_decl = [] + else: + if not current_decl and ch == "\n": + pass # skip leading newlines + else: + current_decl.append(ch) + + return declarations + + +def extract_definitions(text): + """ + Extract top-level function definitions from cleaned C text. + A definition is a function signature followed by a '{' body. + Returns list of (signature_text, line_number). + """ + definitions = [] + brace_depth = 0 + current_sig = [] + sig_line = 1 + i = 0 + + while i < len(text): + ch = text[i] + + if ch == "{": + if brace_depth == 0 and current_sig: + sig_text = "".join(current_sig).strip() + if sig_text and "(" in sig_text and ")" in sig_text: + definitions.append((sig_text, sig_line)) + current_sig = [] + brace_depth += 1 + elif ch == "}": + brace_depth -= 1 + if brace_depth < 0: + brace_depth = 0 + if brace_depth == 0: + current_sig = [] + elif brace_depth == 0: + if ch == ";": + # This is a declaration, not a definition — skip it + current_sig = [] + else: + if not current_sig and ch == "\n": + pass + else: + if not current_sig: + sig_line = text[:i].count("\n") + 1 + current_sig.append(ch) + i += 1 + + return definitions + + +def parse_header_declarations(header_path): + """ + Parse non-inline function declarations from a header file. + Returns dict: func_name -> (param_names, line_number, raw_declaration) + """ + with open(header_path, "r") as f: + content = f.read() + + text = strip_c_comments_and_preprocessor(content) + declarations = extract_declarations(text) + + functions = {} + for decl_text, line_num in declarations: + # Skip typedefs + if re.match(r"\s*typedef\b", decl_text): + continue + # Skip extern declarations without function signatures + if re.match(r"\s*extern\b", decl_text) and "(" not in decl_text: + continue + + # Strip decorator prefixes + clean = decl_text.strip() + while True: + m = DECL_PREFIXES.match(clean) + if m: + clean = clean[m.end():] + else: + break + + result = parse_func_signature(clean) + if result: + func_name, param_names = result + # Skip ALL_CAPS names without underscores (macros) + if func_name.isupper() and "_" not in func_name: + continue + if func_name not in functions: + functions[func_name] = (param_names, line_num, decl_text.strip()) + + return functions + + +def parse_source_definitions(src_dir, module): + """ + Parse function definitions from .c source files for a module. + Returns dict: func_name -> (param_names, file_path, line_number) + """ + module_dir = os.path.join(src_dir, module) + if not os.path.isdir(module_dir): + return {} + + functions = {} + + for fname in sorted(os.listdir(module_dir)): + if not fname.endswith(".c"): + continue + # Skip test files + if fname.startswith("t-") or fname == "main.c": + continue + + fpath = os.path.join(module_dir, fname) + with open(fpath, "r") as f: + content = f.read() + + text = strip_c_comments_and_preprocessor(content) + definitions = extract_definitions(text) + + for sig_text, line_num in definitions: + # Strip 'static' qualifier — we skip static functions + clean = sig_text.strip() + if re.match(r"\s*static\b", clean): + continue + + # Strip decorator prefixes + while True: + m = DECL_PREFIXES.match(clean) + if m: + clean = clean[m.end():] + else: + break + + result = parse_func_signature(clean) + if result: + func_name, param_names = result + if func_name not in functions: + functions[func_name] = (param_names, fpath, line_num) + + return functions + + +def parse_rst_functions(rst_path): + """ + Parse function signatures from RST documentation. + Returns dict: func_name -> (param_names, line_number, raw_signature) + """ + functions = {} + + func_directive = re.compile(r"^\.\.\s+(?:c:)?function\s*::\s*(.+)$") + continuation = re.compile(r"^\s{5,}(\S.+)$") + + with open(rst_path, "r") as f: + lines = f.readlines() + + i = 0 + while i < len(lines): + line = lines[i].rstrip() + m = func_directive.match(line) + if m: + sigs = [(m.group(1).strip(), i + 1)] + + j = i + 1 + while j < len(lines): + cm = continuation.match(lines[j].rstrip()) + if cm: + text = cm.group(1).strip() + if "(" in text and ")" in text: + sigs.append((text, j + 1)) + j += 1 + else: + break + else: + break + + for sig_text, line_num in sigs: + result = parse_func_signature(sig_text) + if result: + func_name, param_names = result + if func_name not in functions: + functions[func_name] = ( + param_names, line_num, sig_text.strip() + ) + + i = j + else: + i += 1 + + return functions + + +def compare_params(funcs_a, funcs_b): + """ + Compare parameter names between two function dictionaries. + Returns list of mismatch dicts. + """ + mismatches = [] + common = set(funcs_a.keys()) & set(funcs_b.keys()) + + for func_name in sorted(common): + a_params = funcs_a[func_name][0] + b_params = funcs_b[func_name][0] + + if len(a_params) != len(b_params): + mismatches.append({ + "func": func_name, + "type": "param_count", + "a_params": a_params, + "b_params": b_params, + "a_info": funcs_a[func_name], + "b_info": funcs_b[func_name], + }) + continue + + if a_params != b_params: + diffs = [] + swaps = [] + for k, (a, b) in enumerate(zip(a_params, b_params)): + if a != b: + diffs.append((k, a, b)) + # Detect swapped parameters: pairs where a[i]=b[j] and + # a[j]=b[i] for i= 0 and content[j] in " \t": + j -= 1 + # Walk backward past the return type line + while j >= 0 and content[j] != "\n" and content[j] != ";"\ + and content[j] != "}": + j -= 1 + if j >= 0 and content[j] == "\n": + # Check if previous line is the return type + line_start = j + 1 + prev_text = content[line_start:sig_start].strip() + if prev_text and not prev_text.startswith("/*") \ + and not prev_text.startswith("//") \ + and not prev_text.startswith("#"): + sig_start = line_start + + # Walk forward from the match to find the opening '{' + pos = m.end() + paren_depth = 1 + while pos < len(content) and paren_depth > 0: + if content[pos] == "(": + paren_depth += 1 + elif content[pos] == ")": + paren_depth -= 1 + pos += 1 + + # Now pos is right after the closing ')' of the parameter list. + # Skip whitespace to find '{' or ';' + while pos < len(content) and content[pos] in " \t\n\r": + pos += 1 + + if pos >= len(content): + continue + + if content[pos] == ";": + # This is a declaration, not a definition — skip + continue + + if content[pos] != "{": + continue + + # Found opening brace — find matching close brace + brace_depth = 1 + pos += 1 + while pos < len(content) and brace_depth > 0: + if content[pos] == "{": + brace_depth += 1 + elif content[pos] == "}": + brace_depth -= 1 + pos += 1 + + return (sig_start, pos) + + return None + + +def rename_param_in_range(content, start, end, old_name, new_name): + """ + Rename a parameter within a byte range of content, avoiding + struct member accesses (->old_name, .old_name). + Returns modified content. + """ + region = content[start:end] + + # Replace old_name as a word, but not when preceded by -> or . + # (?.]) excludes the second char of -> and the . of member access + pattern = r"(?.])\b" + re.escape(old_name) + r"\b" + new_region = re.sub(pattern, new_name, region) + + return content[:start] + new_region + content[end:] + + +def fix_source_file(filepath, func_name, renames, dry_run=False): + """ + Rename parameters in a function definition in a .c source file. + renames: list of (old_name, new_name) + Returns True if changes were made. + """ + with open(filepath, "r") as f: + content = f.read() + + span = find_function_span(content, func_name) + if span is None: + print(f" WARNING: could not find {func_name} in {filepath}") + return False + + start, end = span + region = content[start:end] + + # Check for macros above the function that capture variable names. + # If a #define before the function references old_name, renaming + # will break macro expansions. + preamble = content[:start] + safe_renames = [] + for old_name, new_name in renames: + # Check collision: new_name already exists in the function span + new_pat = r"(?.])\b" + re.escape(new_name) + r"\b" + if re.search(new_pat, region): + print(f" SKIP {func_name} rename {old_name}->{new_name}" + f" in {filepath}: collision with existing '{new_name}'") + continue + # Check macro capture: old_name in a #define above. + # Handle multi-line macros (backslash continuation). + macro_pat = (r"#\s*define\s+\w+(?:[^\n]*\\\n)*[^\n]*\b" + + re.escape(old_name) + r"\b") + if re.search(macro_pat, preamble): + print(f" SKIP {func_name} rename {old_name}->{new_name}" + f" in {filepath}: macro captures '{old_name}'") + continue + safe_renames.append((old_name, new_name)) + + if not safe_renames: + return False + + changed = False + for old_name, new_name in safe_renames: + new_content = rename_param_in_range( + content, start, end, old_name, new_name + ) + if new_content != content: + end += len(new_content) - len(content) + content = new_content + changed = True + + if changed: + if dry_run: + print(f" [dry-run] would fix {func_name} in {filepath}") + else: + with open(filepath, "w") as f: + f.write(content) + print(f" fixed {func_name} in {filepath}") + return changed + + +def fix_declaration_file(filepath, func_name, renames, dry_run=False): + """ + Rename parameters in a function declaration in a .h header file + or a .. function:: directive in an .rst doc file. + Only modifies the signature line(s), not any body. + Skips inline function definitions (only fixes pure declarations + ending with ';'). + renames: list of (old_name, new_name) + Returns True if changes were made. + """ + with open(filepath, "r") as f: + content = f.read() + + # Find the function name in the file + pattern = re.compile(r"\b" + re.escape(func_name) + r"\s*\(") + changed = False + + for m in pattern.finditer(content): + # Find the full declaration: from func_name( to the closing ) + pos = m.end() + paren_depth = 1 + while pos < len(content) and paren_depth > 0: + if content[pos] == "(": + paren_depth += 1 + elif content[pos] == ")": + paren_depth -= 1 + pos += 1 + + # For .h files, check if this is a declaration (;) or + # definition ({). Skip inline definitions — they need + # fix_source_file instead. + if filepath.endswith(".h"): + rest = content[pos:pos + 40].lstrip() + if not rest.startswith(";"): + continue # Skip inline definitions + + param_start = m.end() - 1 # the '(' + param_end = pos # just past ')' + + region = content[param_start:param_end] + new_region = region + for old_name, new_name in renames: + pat = r"\b" + re.escape(old_name) + r"\b" + new_region = re.sub(pat, new_name, new_region) + + if new_region != region: + content = content[:param_start] + new_region + content[param_end:] + changed = True + break # Only fix first occurrence + + if changed: + if dry_run: + print(f" [dry-run] would fix {func_name} in {filepath}") + else: + with open(filepath, "w") as f: + f.write(content) + print(f" fixed {func_name} in {filepath}") + return changed + + +def collect_mismatches(modules, src_dir, check_src=False): + """ + Collect all mismatches across modules. + Returns list of (mod, header_path, rst_path, hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs). + """ + results = [] + for mod, header_path, rst_path in modules: + hdr_funcs = parse_header_declarations(header_path) + doc_funcs = parse_rst_functions(rst_path) + src_funcs = {} + + hdr_doc_mm = compare_params(hdr_funcs, doc_funcs) + + hdr_src_mm = [] + if check_src: + src_funcs = parse_source_definitions(src_dir, mod) + hdr_src_mm = compare_params(hdr_funcs, src_funcs) + + results.append((mod, header_path, rst_path, + hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs)) + return results + + +def apply_fixes(results, src_dir, dry_run=False): + """ + Auto-fix mismatches. Strategy: + - If doc and source agree but header differs: fix header + - If doc and header agree but source differs: fix source + - If header and source agree but doc differs: fix header+source + (doc is source of truth for naming) + - If all three differ: skip (needs manual review) + """ + fixed = 0 + skipped = 0 + + for (mod, header_path, rst_path, + hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs) in results: + + if not hdr_doc_mm and not hdr_src_mm: + continue + + # Build a unified view per function + all_funcs = set() + for mm in hdr_doc_mm: + all_funcs.add(mm["func"]) + for mm in hdr_src_mm: + all_funcs.add(mm["func"]) + + hdr_doc_by_func = {mm["func"]: mm for mm in hdr_doc_mm} + hdr_src_by_func = {mm["func"]: mm for mm in hdr_src_mm} + + for func_name in sorted(all_funcs): + hd = hdr_doc_by_func.get(func_name) + hs = hdr_src_by_func.get(func_name) + + hdr_params = hdr_funcs[func_name][0] if func_name in hdr_funcs else None + doc_params = doc_funcs[func_name][0] if func_name in doc_funcs else None + src_params = src_funcs[func_name][0] if func_name in src_funcs else None + + # Skip param count mismatches — need manual review + if (hd and hd["type"] == "param_count") or \ + (hs and hs["type"] == "param_count"): + print(f" SKIP {func_name}: parameter count mismatch" + " (manual review needed)") + skipped += 1 + continue + + # Determine which names to use + renames_hdr = [] # (old, new) for header + renames_src = [] # (old, new) for source file(s) + + if hd and hs: + # Header disagrees with both doc and source + # Check if doc and source agree + if doc_params == src_params: + # doc and source agree — fix header to match + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + else: + # All three differ — use doc as truth + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + # Also fix source to match doc + # Recompute source renames against doc + if src_params and doc_params and \ + len(src_params) == len(doc_params): + for k, (s, d) in enumerate( + zip(src_params, doc_params)): + if s != d: + renames_src.append((s, d)) + else: + print(f" SKIP {func_name}: all three differ" + " (manual review needed)") + skipped += 1 + continue + elif hd and not hs: + # Header vs doc mismatch only (no source mismatch or + # source not checked) + if src_params is not None: + # Source exists — check what it says + if src_params == doc_params: + # Source agrees with doc — fix header + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + elif src_params == hdr_params: + # Source agrees with header — doc is truth, + # fix header + source + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + renames_src.append((old, new)) + else: + print(f" SKIP {func_name}: all three differ" + " (manual review needed)") + skipped += 1 + continue + else: + # No source — just fix header to match doc + for idx, old, new in hd["diffs"]: + renames_hdr.append((old, new)) + elif hs and not hd: + # Header vs source mismatch only + # Header agrees with doc (or doc doesn't exist) + # Fix source to match header + # diffs are (idx, hdr_name, src_name) — rename + # src_name -> hdr_name in the source file + for idx, hdr_name, src_name in hs["diffs"]: + renames_src.append((src_name, hdr_name)) + + # Filter out renames that would collide. + # Case 1: new_name is already a param not being renamed. + # Case 2: overlapping renames (swaps) where new_name + # equals old_name of another rename — sequential + # regex can't handle this correctly. + def filter_safe_renames(renames, params, label): + if not renames or not params: + return renames + old_names = {old for old, new in renames} + new_names = {new for old, new in renames} + # Check for overlapping renames (any new is also an old) + if old_names & new_names: + for old, new in renames: + print(f" SKIP {func_name} rename {old}->{new}" + f" in {label}: overlapping renames") + return [] + # Check if new_name collides with an unchanged param + unchanged = set(params) - old_names + safe = [] + for old, new in renames: + if new in unchanged: + print(f" SKIP {func_name} rename {old}->{new}" + f" in {label}: '{new}' already a param") + else: + safe.append((old, new)) + return safe + + renames_hdr = filter_safe_renames( + renames_hdr, hdr_params, "header" + ) + + # Apply fixes + if renames_hdr: + ok = fix_declaration_file( + header_path, func_name, renames_hdr, dry_run + ) + if not ok: + # Declaration not found (inline definition only). + # Try fix_source_file on the header instead. + ok = fix_source_file( + header_path, func_name, renames_hdr, dry_run + ) + if ok: + fixed += 1 + + if renames_src: + # Find the source file + if func_name in src_funcs: + src_path = src_funcs[func_name][1] + ok = fix_source_file( + src_path, func_name, renames_src, dry_run + ) + if ok: + fixed += 1 + + return fixed, skipped + + +def main(): + parser = argparse.ArgumentParser( + description="Check parameter name consistency between " + "FLINT headers, docs, and source files" + ) + parser.add_argument( + "--module", "-m", + help="Only check a specific module (e.g., fmpz, acb_poly)", + ) + parser.add_argument( + "--check-src", action="store_true", + help="Also check .c source files against headers", + ) + parser.add_argument( + "--fix", action="store_true", + help="Auto-fix mismatches (doc is source of truth for naming)", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="With --fix, show what would be changed without modifying files", + ) + parser.add_argument( + "--list-missing", action="store_true", + help="Also list functions missing from source or docs", + ) + args = parser.parse_args() + + if args.fix: + args.check_src = True + + src_dir = "src" + doc_dir = "doc/source" + + if not os.path.isdir(src_dir) or not os.path.isdir(doc_dir): + print("Error: run from the FLINT root directory", file=sys.stderr) + sys.exit(1) + + modules = find_modules(src_dir, doc_dir) + if args.module: + modules = [(m, h, r) for m, h, r in modules if m == args.module] + if not modules: + print(f"Error: module '{args.module}' not found", file=sys.stderr) + sys.exit(1) + + results = collect_mismatches(modules, src_dir, args.check_src) + + if args.fix: + fixed, skipped = apply_fixes(results, src_dir, args.dry_run) + print(f"\n{'='*70}") + print(f"Fixed: {fixed}, Skipped: {skipped}") + if not args.dry_run and fixed > 0: + print("Re-running check to verify...") + results2 = collect_mismatches(modules, src_dir, True) + remaining = sum( + len(r[3]) + len(r[4]) for r in results2 + ) + print(f"Remaining mismatches: {remaining}") + return + + total_hdr_doc = 0 + total_hdr_doc_checked = 0 + total_hdr_src = 0 + total_hdr_src_checked = 0 + + for (mod, header_path, rst_path, + hdr_doc_mm, hdr_src_mm, + hdr_funcs, doc_funcs, src_funcs) in results: + + hdr_doc_common = len(set(hdr_funcs.keys()) & set(doc_funcs.keys())) + total_hdr_doc += len(hdr_doc_mm) + total_hdr_doc_checked += hdr_doc_common + + if args.check_src: + hdr_src_common = len( + set(hdr_funcs.keys()) & set(src_funcs.keys()) + ) + total_hdr_src += len(hdr_src_mm) + total_hdr_src_checked += hdr_src_common + + if hdr_doc_mm or hdr_src_mm: + print(f"\n{'='*70}") + print(f"Module: {mod}") + print(f"{'='*70}") + + if hdr_doc_mm: + print(f"\n --- header vs doc ---") + print_mismatches( + hdr_doc_mm, "header", "doc", + path_a=header_path, path_b=rst_path, + ) + + if hdr_src_mm: + print(f"\n --- header vs source ---") + print_mismatches( + hdr_src_mm, "header", "source", + path_a=header_path, + ) + + if args.list_missing: + only_hdr = sorted( + set(hdr_funcs.keys()) - set(doc_funcs.keys()) + ) + only_doc = sorted( + set(doc_funcs.keys()) - set(hdr_funcs.keys()) + ) + if only_hdr or only_doc: + if not hdr_doc_mm and not hdr_src_mm: + print(f"\n{'='*70}") + print(f"Module: {mod}") + print(f"{'='*70}") + if only_hdr: + print(f"\n In header but not in docs ({len(only_hdr)}):") + for fn in only_hdr: + print(f" {fn}") + if only_doc: + print(f"\n In docs but not in header ({len(only_doc)}):") + for fn in only_doc: + print(f" {fn}") + + print(f"\n{'='*70}") + print(f"Header vs doc: {total_hdr_doc} mismatches in" + f" {total_hdr_doc_checked} functions" + f" across {len(modules)} modules") + if args.check_src: + print(f"Header vs src: {total_hdr_src} mismatches in" + f" {total_hdr_src_checked} functions" + f" across {len(modules)} modules") + + if total_hdr_doc > 0 or total_hdr_src > 0: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/dev/test_check_param_names.py b/dev/test_check_param_names.py new file mode 100644 index 0000000000..51a87600f5 --- /dev/null +++ b/dev/test_check_param_names.py @@ -0,0 +1,1094 @@ +#!/usr/bin/env python3 +"""Tests for check_param_names.py""" + +import os +import tempfile +import unittest + +from check_param_names import ( + extract_param_name, + split_params, + parse_func_signature, + strip_c_comments_and_preprocessor, + extract_declarations, + extract_definitions, + parse_header_declarations, + parse_source_definitions, + parse_rst_functions, + compare_params, + find_function_span, + rename_param_in_range, + fix_source_file, + fix_declaration_file, + collect_mismatches, + apply_fixes, +) + + +class TestExtractParamName(unittest.TestCase): + def test_simple_types(self): + self.assertEqual(extract_param_name("slong n"), "n") + self.assertEqual(extract_param_name("ulong x"), "x") + self.assertEqual(extract_param_name("int flag"), "flag") + + def test_const_qualified(self): + self.assertEqual(extract_param_name("const fmpz_t x"), "x") + self.assertEqual(extract_param_name("const padic_ctx_t ctx"), "ctx") + + def test_pointer_types(self): + self.assertEqual(extract_param_name("ulong * out"), "out") + self.assertEqual(extract_param_name("FILE * file"), "file") + self.assertEqual(extract_param_name("nn_srcptr xp"), "xp") + self.assertEqual(extract_param_name("char * str"), "str") + + def test_flint_unused(self): + self.assertEqual( + extract_param_name("const padic_ctx_t FLINT_UNUSED(ctx)"), + "ctx", + ) + self.assertEqual( + extract_param_name("slong FLINT_UNUSED(n)"), + "n", + ) + + def test_array_brackets(self): + self.assertEqual(extract_param_name("ulong out[3]"), "out") + self.assertEqual(extract_param_name("slong perm[10]"), "perm") + + def test_void_and_ellipsis(self): + self.assertIsNone(extract_param_name("void")) + self.assertIsNone(extract_param_name("...")) + self.assertIsNone(extract_param_name("")) + + def test_type_only(self): + """When param is just a type with no name, return None.""" + self.assertIsNone(extract_param_name("int")) + self.assertIsNone(extract_param_name("void")) + self.assertIsNone(extract_param_name("FILE")) + + def test_complex_types(self): + self.assertEqual(extract_param_name("flint_bitcnt_t bits"), "bits") + self.assertEqual(extract_param_name("gr_ctx_t ctx"), "ctx") + self.assertEqual(extract_param_name("const fmpz_poly_t poly"), "poly") + + +class TestSplitParams(unittest.TestCase): + def test_simple(self): + self.assertEqual( + split_params("fmpz_t f, const fmpz_t g, const fmpz_t h"), + ["fmpz_t f", "const fmpz_t g", "const fmpz_t h"], + ) + + def test_single_param(self): + self.assertEqual(split_params("fmpz_t f"), ["fmpz_t f"]) + + def test_empty(self): + self.assertEqual(split_params(""), []) + + def test_void(self): + self.assertEqual(split_params("void"), ["void"]) + + def test_function_pointer(self): + result = split_params( + "int (*cmp)(void *, const void *, const void *), slong n" + ) + self.assertEqual(len(result), 2) + self.assertIn("(*cmp)", result[0]) + self.assertEqual(result[1], "slong n") + + def test_nested_parens(self): + result = split_params("acb_calc_func_t func, void * param") + self.assertEqual(result, ["acb_calc_func_t func", "void * param"]) + + +class TestParseFuncSignature(unittest.TestCase): + def test_simple(self): + result = parse_func_signature( + "void fmpz_add(fmpz_t f, const fmpz_t g, const fmpz_t h)" + ) + self.assertEqual(result, ("fmpz_add", ["f", "g", "h"])) + + def test_no_params(self): + result = parse_func_signature("void _fmpz_cleanup(void)") + self.assertEqual(result, ("_fmpz_cleanup", [])) + + def test_return_type_pointer(self): + result = parse_func_signature("mpz_ptr _fmpz_promote(fmpz_t f)") + self.assertEqual(result, ("_fmpz_promote", ["f"])) + + def test_with_semicolon(self): + result = parse_func_signature( + "void fmpz_clear(fmpz_t f);" + ) + self.assertEqual(result, ("fmpz_clear", ["f"])) + + def test_multiline(self): + result = parse_func_signature( + "void fmpz_multi_CRT_ui(fmpz_t output, " + "nn_srcptr residues, const fmpz_comb_t comb, " + "fmpz_comb_temp_t ctemp, int sign)" + ) + self.assertEqual( + result, + ("fmpz_multi_CRT_ui", ["output", "residues", "comb", + "ctemp", "sign"]), + ) + + def test_function_pointer_param(self): + result = parse_func_signature( + "void qsort_r(void * base, slong n, slong size, " + "int (*cmp)(void *, const void *, const void *), void * arg)" + ) + self.assertIsNotNone(result) + self.assertEqual(result[0], "qsort_r") + self.assertIn("cmp", result[1]) + + def test_no_parens(self): + self.assertIsNone(parse_func_signature("int x")) + + def test_typedef_function_pointer(self): + """Function pointer typedefs are filtered by the caller + (parse_header_declarations skips typedefs), so parse_func_signature + may return a result — it just won't be a correct function name. + The important thing is that parse_header_declarations filters these.""" + # Direct call may parse it (incorrectly), but that's OK + # since the caller filters typedefs before calling this function. + pass + + def test_warn_unused_result(self): + """Decorators should be handled by the caller, not parse_func_signature.""" + result = parse_func_signature( + "int gr_poly_set(gr_poly_t res, const gr_poly_t src, gr_ctx_t ctx)" + ) + self.assertEqual(result, ("gr_poly_set", ["res", "src", "ctx"])) + + def test_flint_unused_param(self): + result = parse_func_signature( + "void foo(slong n, const nmod_t FLINT_UNUSED(mod))" + ) + self.assertEqual(result, ("foo", ["n", "mod"])) + + def test_ellipsis(self): + result = parse_func_signature( + "int flint_printf(const char * fmt, ...)" + ) + self.assertEqual(result, ("flint_printf", ["fmt"])) + + +class TestStripCCommentsAndPreprocessor(unittest.TestCase): + def test_block_comment(self): + text = strip_c_comments_and_preprocessor( + "/* comment */\nvoid foo(int x);\n" + ) + self.assertIn("void foo(int x);", text) + self.assertNotIn("comment", text) + + def test_multiline_block_comment(self): + text = strip_c_comments_and_preprocessor( + "/* multi\nline\ncomment */\nvoid bar(void);\n" + ) + self.assertIn("void bar(void);", text) + self.assertNotIn("multi", text) + + def test_preprocessor_directive(self): + text = strip_c_comments_and_preprocessor( + "#include \nvoid baz(int n);\n" + ) + self.assertIn("void baz(int n);", text) + self.assertNotIn("include", text) + + def test_multiline_macro(self): + text = strip_c_comments_and_preprocessor( + "#define FOO(x) \\\n ((x) + 1)\nvoid quux(int a);\n" + ) + self.assertIn("void quux(int a);", text) + self.assertNotIn("FOO", text) + + def test_ifdef_keeps_body(self): + """Code inside #ifdef blocks should be preserved.""" + text = strip_c_comments_and_preprocessor( + "#ifdef HAVE_FEATURE\n" + "void feature_func(int x);\n" + "#endif\n" + ) + self.assertIn("void feature_func(int x);", text) + + def test_extern_c_removed(self): + text = strip_c_comments_and_preprocessor( + '#ifdef __cplusplus\n' + 'extern "C" {\n' + '#endif\n' + 'void foo(int x);\n' + '#ifdef __cplusplus\n' + '}\n' + '#endif\n' + ) + self.assertIn("void foo(int x);", text) + self.assertNotIn('extern "C"', text) + + +class TestExtractDeclarations(unittest.TestCase): + def test_simple_declaration(self): + text = "void foo(int x);\n" + decls = extract_declarations(text) + self.assertEqual(len(decls), 1) + self.assertIn("void foo(int x)", decls[0][0]) + + def test_skips_function_body(self): + text = ( + "void inline_func(int x)\n" + "{\n" + " return;\n" + "}\n" + "void declared_func(int y);\n" + ) + decls = extract_declarations(text) + self.assertEqual(len(decls), 1) + self.assertIn("declared_func", decls[0][0]) + + def test_multiple_declarations(self): + text = ( + "void foo(int a);\n" + "int bar(slong b, ulong c);\n" + ) + decls = extract_declarations(text) + self.assertEqual(len(decls), 2) + + def test_multiline_declaration(self): + text = ( + "void long_func(int a,\n" + " int b,\n" + " int c);\n" + ) + decls = extract_declarations(text) + self.assertEqual(len(decls), 1) + sig = decls[0][0] + self.assertIn("long_func", sig) + + def test_no_declarations_when_no_parens(self): + text = "int x;\nchar * str;\n" + decls = extract_declarations(text) + self.assertEqual(len(decls), 0) + + +class TestExtractDefinitions(unittest.TestCase): + def test_simple_definition(self): + text = ( + "void foo(int x)\n" + "{\n" + " return;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("foo", defs[0][0]) + + def test_skips_declarations(self): + text = ( + "void declared_only(int x);\n" + "void defined(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("defined", defs[0][0]) + + def test_return_type_on_separate_line(self): + text = ( + "void\n" + "foo(int x, int y)\n" + "{\n" + " return;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("foo", defs[0][0]) + + def test_multiple_definitions(self): + text = ( + "void foo(int a)\n" + "{\n" + " return;\n" + "}\n" + "\n" + "int bar(int b)\n" + "{\n" + " return 0;\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 2) + + def test_nested_braces(self): + text = ( + "void foo(int x)\n" + "{\n" + " if (x) {\n" + " return;\n" + " }\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 1) + self.assertIn("foo", defs[0][0]) + + def test_static_and_nonstatic(self): + """Both static and non-static definitions should be extracted.""" + text = ( + "static void helper(int x)\n" + "{\n" + " return;\n" + "}\n" + "\n" + "void public_func(int y)\n" + "{\n" + " helper(y);\n" + "}\n" + ) + defs = extract_definitions(text) + self.assertEqual(len(defs), 2) + + +class TestParseHeaderDeclarations(unittest.TestCase): + def _write_temp(self, content): + fd, path = tempfile.mkstemp(suffix=".h") + os.write(fd, content.encode()) + os.close(fd) + return path + + def test_basic_header(self): + path = self._write_temp( + "/* copyright */\n" + "#ifndef FOO_H\n" + "#define FOO_H\n" + "#ifdef __cplusplus\n" + 'extern "C" {\n' + "#endif\n" + "void foo_init(foo_t x);\n" + "void foo_clear(foo_t x);\n" + "int foo_add(foo_t r, const foo_t a, const foo_t b);\n" + "#ifdef __cplusplus\n" + "}\n" + "#endif\n" + "#endif\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("foo_init", funcs) + self.assertIn("foo_clear", funcs) + self.assertIn("foo_add", funcs) + self.assertEqual(funcs["foo_init"][0], ["x"]) + self.assertEqual(funcs["foo_add"][0], ["r", "a", "b"]) + finally: + os.unlink(path) + + def test_skips_inline(self): + path = self._write_temp( + "void foo_declared(int x);\n" + "static inline void foo_inline(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("foo_declared", funcs) + self.assertNotIn("foo_inline", funcs) + finally: + os.unlink(path) + + def test_skips_typedefs(self): + path = self._write_temp( + "typedef void (*callback_t)(int x);\n" + "void real_func(int y);\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("real_func", funcs) + self.assertNotIn("callback_t", funcs) + finally: + os.unlink(path) + + def test_multiline_declaration(self): + path = self._write_temp( + "void long_func(int a,\n" + " int b,\n" + " int c);\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("long_func", funcs) + self.assertEqual(funcs["long_func"][0], ["a", "b", "c"]) + finally: + os.unlink(path) + + def test_warn_unused_result(self): + path = self._write_temp( + "WARN_UNUSED_RESULT int foo(int x, int y);\n" + ) + try: + funcs = parse_header_declarations(path) + self.assertIn("foo", funcs) + self.assertEqual(funcs["foo"][0], ["x", "y"]) + finally: + os.unlink(path) + + +class TestParseSourceDefinitions(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.src_dir = self.tmpdir + self.mod_dir = os.path.join(self.tmpdir, "mymod") + os.makedirs(self.mod_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write_c(self, name, content): + path = os.path.join(self.mod_dir, name) + with open(path, "w") as f: + f.write(content) + + def test_simple_definition(self): + self._write_c("add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertIn("mymod_add", funcs) + self.assertEqual(funcs["mymod_add"][0], ["r", "a", "b"]) + + def test_skips_static(self): + self._write_c("impl.c", + "static void helper(int x)\n" + "{\n" + " return;\n" + "}\n" + "\n" + "void mymod_func(int y)\n" + "{\n" + " helper(y);\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertNotIn("helper", funcs) + self.assertIn("mymod_func", funcs) + + def test_skips_test_files(self): + self._write_c("t-add.c", + "void test_add(int x)\n" + "{\n" + " return;\n" + "}\n" + ) + self._write_c("add.c", + "void mymod_add(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertNotIn("test_add", funcs) + self.assertIn("mymod_add", funcs) + + def test_return_type_on_separate_line(self): + self._write_c("gcd.c", + "void\n" + "mymod_gcd(foo_t f, const foo_t g, const foo_t h)\n" + "{\n" + " return;\n" + "}\n" + ) + funcs = parse_source_definitions(self.src_dir, "mymod") + self.assertIn("mymod_gcd", funcs) + self.assertEqual(funcs["mymod_gcd"][0], ["f", "g", "h"]) + + +class TestParseRstFunctions(unittest.TestCase): + def _write_temp(self, content): + fd, path = tempfile.mkstemp(suffix=".rst") + os.write(fd, content.encode()) + os.close(fd) + return path + + def test_simple_directive(self): + path = self._write_temp( + ".. function:: void foo_init(foo_t x)\n" + "\n" + " Initializes x.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("foo_init", funcs) + self.assertEqual(funcs["foo_init"][0], ["x"]) + finally: + os.unlink(path) + + def test_continuation_lines(self): + path = self._write_temp( + ".. function:: void foo_init(foo_t x)\n" + " void foo_clear(foo_t x)\n" + "\n" + " Init/clear.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("foo_init", funcs) + self.assertIn("foo_clear", funcs) + finally: + os.unlink(path) + + def test_c_function_directive(self): + path = self._write_temp( + ".. c:function:: int bar(slong n, ulong k)\n" + "\n" + " Computes something.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("bar", funcs) + self.assertEqual(funcs["bar"][0], ["n", "k"]) + finally: + os.unlink(path) + + def test_multiple_directives(self): + path = self._write_temp( + ".. function:: void alpha(int a)\n" + "\n" + " Does alpha.\n" + "\n" + ".. function:: void beta(int b)\n" + "\n" + " Does beta.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertIn("alpha", funcs) + self.assertIn("beta", funcs) + finally: + os.unlink(path) + + def test_line_numbers(self): + path = self._write_temp( + "Title\n" + "=====\n" + "\n" + ".. function:: void first(int x)\n" + "\n" + " First func.\n" + "\n" + ".. function:: void second(int y)\n" + "\n" + " Second func.\n" + ) + try: + funcs = parse_rst_functions(path) + self.assertEqual(funcs["first"][1], 4) + self.assertEqual(funcs["second"][1], 8) + finally: + os.unlink(path) + + +class TestCompareParams(unittest.TestCase): + def _make_funcs(self, **kwargs): + """Helper: create funcs dict from name -> param_list mapping.""" + return { + name: (params, 1, f"void {name}(...)") + for name, params in kwargs.items() + } + + def test_no_mismatches(self): + a = self._make_funcs(foo=["x", "y"], bar=["a"]) + b = self._make_funcs(foo=["x", "y"], bar=["a"]) + self.assertEqual(compare_params(a, b), []) + + def test_name_mismatch(self): + a = self._make_funcs(foo=["x", "y"]) + b = self._make_funcs(foo=["a", "b"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["func"], "foo") + self.assertEqual(mm[0]["type"], "param_name") + self.assertEqual(mm[0]["diffs"], [(0, "x", "a"), (1, "y", "b")]) + + def test_count_mismatch(self): + a = self._make_funcs(foo=["x", "y"]) + b = self._make_funcs(foo=["x", "y", "z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["type"], "param_count") + + def test_only_common_compared(self): + a = self._make_funcs(foo=["x"], only_a=["y"]) + b = self._make_funcs(foo=["x"], only_b=["z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 0) + + def test_partial_mismatch(self): + a = self._make_funcs(foo=["x", "y", "z"]) + b = self._make_funcs(foo=["x", "b", "z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["diffs"], [(1, "y", "b")]) + + def test_swap_detection(self): + a = self._make_funcs(foo=["x", "val", "len", "y"]) + b = self._make_funcs(foo=["x", "len", "val", "y"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["swaps"], [(1, 2, "val", "len")]) + + def test_no_swap_when_names_differ(self): + a = self._make_funcs(foo=["a", "b"]) + b = self._make_funcs(foo=["c", "d"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["swaps"], []) + + def test_swap_among_multiple_diffs(self): + a = self._make_funcs(foo=["w", "val", "len", "z"]) + b = self._make_funcs(foo=["x", "len", "val", "z"]) + mm = compare_params(a, b) + self.assertEqual(len(mm), 1) + # Only positions 1,2 are swapped; position 0 is just different + self.assertEqual(mm[0]["swaps"], [(1, 2, "val", "len")]) + + +class TestEndToEnd(unittest.TestCase): + """End-to-end tests with temp header + RST + source files.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.src_dir = os.path.join(self.tmpdir, "src") + self.doc_dir = os.path.join(self.tmpdir, "doc", "source") + os.makedirs(self.src_dir) + os.makedirs(self.doc_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, path, content): + full = os.path.join(self.tmpdir, path) + os.makedirs(os.path.dirname(full), exist_ok=True) + with open(full, "w") as f: + f.write(content) + return full + + def test_header_vs_doc_match(self): + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + rst = self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "\n" + " Adds a and b.\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions(rst) + mm = compare_params(h, d) + self.assertEqual(len(mm), 0) + + def test_header_vs_doc_mismatch(self): + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t f, const foo_t g, const foo_t h);\n" + ) + rst = self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "\n" + " Adds a and b.\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions(rst) + mm = compare_params(h, d) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["func"], "mymod_add") + + def test_header_vs_source_match(self): + self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + hdr = os.path.join(self.tmpdir, "src", "mymod.h") + h = parse_header_declarations(hdr) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + mm = compare_params(h, s) + self.assertEqual(len(mm), 0) + + def test_header_vs_source_mismatch(self): + self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t f, const foo_t g, const foo_t h)\n" + "{\n" + " return;\n" + "}\n" + ) + hdr = os.path.join(self.tmpdir, "src", "mymod.h") + h = parse_header_declarations(hdr) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + mm = compare_params(h, s) + self.assertEqual(len(mm), 1) + self.assertEqual(mm[0]["func"], "mymod_add") + self.assertEqual( + mm[0]["diffs"], + [(0, "r", "f"), (1, "a", "g"), (2, "b", "h")], + ) + + def test_three_way_consistency(self): + """All three sources match — no mismatches.""" + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + "void mymod_neg(foo_t r, const foo_t a);\n" + ) + self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "\n" + " Adds.\n" + "\n" + ".. function:: void mymod_neg(foo_t r, const foo_t a)\n" + "\n" + " Negates.\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + self._write("src/mymod/neg.c", + "void mymod_neg(foo_t r, const foo_t a)\n" + "{\n" + " return;\n" + "}\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions( + os.path.join(self.tmpdir, "doc", "source", "mymod.rst") + ) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + self.assertEqual(compare_params(h, d), []) + self.assertEqual(compare_params(h, s), []) + + def test_three_way_doc_off(self): + """Doc has different names, but header and source agree.""" + hdr = self._write("src/mymod.h", + "void mymod_add(foo_t r, const foo_t a, const foo_t b);\n" + ) + self._write("doc/source/mymod.rst", + ".. function:: void mymod_add(foo_t res, const foo_t x, const foo_t y)\n" + "\n" + " Adds.\n" + ) + self._write("src/mymod/add.c", + "void mymod_add(foo_t r, const foo_t a, const foo_t b)\n" + "{\n" + " return;\n" + "}\n" + ) + h = parse_header_declarations(hdr) + d = parse_rst_functions( + os.path.join(self.tmpdir, "doc", "source", "mymod.rst") + ) + s = parse_source_definitions( + os.path.join(self.tmpdir, "src"), "mymod" + ) + hd = compare_params(h, d) + hs = compare_params(h, s) + self.assertEqual(len(hd), 1) # header vs doc mismatch + self.assertEqual(len(hs), 0) # header vs source match + + +class TestRealFiles(unittest.TestCase): + """Integration tests against actual FLINT files (skipped if not present).""" + + def setUp(self): + if not os.path.isfile("src/fmpz.h"): + self.skipTest("Not running from FLINT root directory") + + def test_fmpz_header_parses(self): + funcs = parse_header_declarations("src/fmpz.h") + self.assertGreater(len(funcs), 100) + self.assertIn("fmpz_add", funcs) + self.assertEqual(funcs["fmpz_add"][0], ["f", "g", "h"]) + + def test_fmpz_doc_parses(self): + funcs = parse_rst_functions("doc/source/fmpz.rst") + self.assertGreater(len(funcs), 100) + self.assertIn("fmpz_add", funcs) + + def test_fmpz_source_parses(self): + funcs = parse_source_definitions("src", "fmpz") + self.assertGreater(len(funcs), 50) + self.assertIn("fmpz_add", funcs) + + def test_acb_header_skips_inlines(self): + funcs = parse_header_declarations("src/acb.h") + # acb_add is an inline — should be skipped + self.assertNotIn("acb_add", funcs) + # acb_mul is a declaration — should be found + self.assertIn("acb_mul", funcs) + + def test_gr_poly_header_with_warn_unused(self): + funcs = parse_header_declarations("src/gr_poly.h") + self.assertIn("gr_poly_set", funcs) + self.assertIn("gr_poly_mul", funcs) + + def test_padic_mat_header_parsed(self): + """padic_mat header should parse correctly.""" + funcs = parse_header_declarations("src/padic_mat.h") + self.assertIn("padic_mat_is_canonical", funcs, + "padic_mat_is_canonical should be in header") + + +class TestRenameParamInRange(unittest.TestCase): + def test_simple_rename(self): + content = "void foo(int x)\n{\n return x + 1;\n}\n" + result = rename_param_in_range(content, 0, len(content), "x", "n") + self.assertEqual(result, "void foo(int n)\n{\n return n + 1;\n}\n") + + def test_avoids_struct_member(self): + """Must not rename struct->field or obj.field patterns.""" + content = ( + "void foo(int length)\n" + "{\n" + " x = length + cache->length;\n" + "}\n" + ) + result = rename_param_in_range( + content, 0, len(content), "length", "len" + ) + self.assertIn("int len)", result) + self.assertIn("x = len + cache->length", result) + + def test_avoids_dot_member(self): + content = "void foo(int x)\n{\n a = x + obj.x;\n}\n" + result = rename_param_in_range(content, 0, len(content), "x", "n") + self.assertIn("a = n + obj.x", result) + + def test_word_boundary(self): + """Must not rename substrings.""" + content = "void foo(int x)\n{\n int xx = x;\n}\n" + result = rename_param_in_range(content, 0, len(content), "x", "n") + self.assertIn("int xx = n", result) + self.assertIn("int n)", result) + + +class TestFindFunctionSpan(unittest.TestCase): + def test_simple_function(self): + content = "void foo(int x)\n{\n return;\n}\n" + span = find_function_span(content, "foo") + self.assertIsNotNone(span) + start, end = span + self.assertIn("foo", content[start:end]) + self.assertIn("return", content[start:end]) + + def test_skips_declarations(self): + content = ( + "void foo(int x);\n" + "void bar(int y)\n" + "{\n" + " return;\n" + "}\n" + ) + span = find_function_span(content, "foo") + self.assertIsNone(span) + span = find_function_span(content, "bar") + self.assertIsNotNone(span) + + def test_return_type_on_separate_line(self): + content = "void\nfoo(int x)\n{\n return;\n}\n" + span = find_function_span(content, "foo") + self.assertIsNotNone(span) + + +class TestFixSourceFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, name, content): + path = os.path.join(self.tmpdir, name) + with open(path, "w") as f: + f.write(content) + return path + + def test_fix_simple(self): + path = self._write("func.c", + "void mymod_foo(foo_t f, const foo_t g)\n" + "{\n" + " bar(f, g);\n" + "}\n" + ) + fix_source_file(path, "mymod_foo", [("f", "r"), ("g", "a")]) + with open(path) as f: + result = f.read() + self.assertIn("foo_t r, const foo_t a)", result) + self.assertIn("bar(r, a);", result) + + def test_fix_avoids_struct_members(self): + path = self._write("func.c", + "void mymod_foo(slong length)\n" + "{\n" + " x = length + cache->length;\n" + "}\n" + ) + fix_source_file(path, "mymod_foo", [("length", "len")]) + with open(path) as f: + result = f.read() + self.assertIn("slong len)", result) + self.assertIn("x = len + cache->length;", result) + + +class TestFixDeclarationFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, name, content): + path = os.path.join(self.tmpdir, name) + with open(path, "w") as f: + f.write(content) + return path + + def test_fix_header_declaration(self): + path = self._write("mod.h", + "void mymod_foo(foo_t f, const foo_t g);\n" + "void mymod_bar(int x);\n" + ) + fix_declaration_file(path, "mymod_foo", [("f", "r"), ("g", "a")]) + with open(path) as f: + result = f.read() + self.assertIn("foo_t r, const foo_t a)", result) + # mymod_bar should be untouched + self.assertIn("mymod_bar(int x)", result) + + +class TestApplyFixes(unittest.TestCase): + """Test the full fix workflow.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir) + + def _write(self, path, content): + full = os.path.join(self.tmpdir, path) + os.makedirs(os.path.dirname(full), exist_ok=True) + with open(full, "w") as f: + f.write(content) + return full + + def test_fix_source_to_match_header(self): + """When header and doc agree, source gets fixed.""" + hdr_path = self._write("src/mymod.h", + "void mymod_foo(foo_t r, const foo_t a);\n" + ) + rst_path = self._write("doc/source/mymod.rst", + ".. function:: void mymod_foo(foo_t r, const foo_t a)\n" + "\n" + " Does foo.\n" + ) + src_path = self._write("src/mymod/foo.c", + "void mymod_foo(foo_t f, const foo_t g)\n" + "{\n" + " bar(f, g);\n" + "}\n" + ) + src_dir = os.path.join(self.tmpdir, "src") + doc_dir = os.path.join(self.tmpdir, "doc", "source") + modules = [("mymod", hdr_path, rst_path)] + results = collect_mismatches(modules, src_dir, check_src=True) + fixed, skipped = apply_fixes(results, src_dir) + self.assertEqual(fixed, 1) + self.assertEqual(skipped, 0) + with open(src_path) as f: + content = f.read() + self.assertIn("foo_t r, const foo_t a)", content) + self.assertIn("bar(r, a);", content) + + def test_fix_header_and_source_to_match_doc(self): + """When header and source agree but doc differs, + both header and source get fixed to match doc.""" + hdr_path = self._write("src/mymod.h", + "void mymod_foo(foo_t f, const foo_t g);\n" + ) + rst_path = self._write("doc/source/mymod.rst", + ".. function:: void mymod_foo(foo_t r, const foo_t a)\n" + "\n" + " Does foo.\n" + ) + src_path = self._write("src/mymod/foo.c", + "void mymod_foo(foo_t f, const foo_t g)\n" + "{\n" + " bar(f, g);\n" + "}\n" + ) + src_dir = os.path.join(self.tmpdir, "src") + modules = [("mymod", hdr_path, rst_path)] + results = collect_mismatches(modules, src_dir, check_src=True) + fixed, skipped = apply_fixes(results, src_dir) + self.assertEqual(fixed, 2) # header + source + with open(hdr_path) as f: + hdr = f.read() + self.assertIn("foo_t r, const foo_t a)", hdr) + with open(src_path) as f: + src = f.read() + self.assertIn("foo_t r, const foo_t a)", src) + self.assertIn("bar(r, a);", src) + + def test_fix_struct_member_safety(self): + """Renaming param must not affect struct->member access.""" + hdr_path = self._write("src/mymod.h", + "void mymod_foo(slong len);\n" + ) + rst_path = self._write("doc/source/mymod.rst", + ".. function:: void mymod_foo(slong len)\n" + "\n" + " Does foo.\n" + ) + src_path = self._write("src/mymod/foo.c", + "void mymod_foo(slong length)\n" + "{\n" + " x = length + cache->length;\n" + "}\n" + ) + src_dir = os.path.join(self.tmpdir, "src") + modules = [("mymod", hdr_path, rst_path)] + results = collect_mismatches(modules, src_dir, check_src=True) + fixed, skipped = apply_fixes(results, src_dir) + self.assertEqual(fixed, 1) + with open(src_path) as f: + content = f.read() + self.assertIn("slong len)", content) + self.assertIn("x = len + cache->length;", content) + + +if __name__ == "__main__": + unittest.main()