Skip to content

Commit f612bf5

Browse files
committed
Widen match result type and emit safe str copies for mixed-length arms
The type checker computed unify_types between match arm types but discarded the result and used the first arm's type. With arms returning strings of different lengths (e.g. "high_priority": str(13) and "default_priority": str(16)), the match result was sized to the first arm, and codegen then did __builtin_memcpy(&dst, &src, sizeof(src)) across the differently laid-out str_N_t types, corrupting the destination because src's .len field ended up overwriting dst's .data[] bytes.

Fold unify_types over all arms to compute the actual join, consistent with the str(N) <: str(M) subtyping (for N <= M) that the language already defines.

Replace four unsafe struct-memcpy sites in the codegen (IRAssign on temps, IRAssign on named vars, IRMatch result assembly, and IRVariableDecl init, covering both the hoisted-redeclaration and fresh-declaration paths) with a single emit_str_copy helper that does length-respecting field copies, safe regardless of layout.

Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 8238e03 commit f612bf5

4 files changed

Lines changed: 148 additions & 57 deletions

File tree

src/ebpf_c_codegen.ml

Lines changed: 87 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,17 @@ let initialize_context_generators () =
260260
Kernelscript_context.Fprobe_codegen.register ();
261261
Kernelscript_context.Perf_event_codegen.register ()
262262

263+
(** Emit a safe str_N_t to str_M_t copy. The naive
264+
`__builtin_memcpy(&dst, &src, sizeof(src))` is wrong whenever the two
265+
str_N_t types have different sizes: layouts diverge, so the source's
266+
`.len` field gets memcpy'd into the destination's `.data[]` array and the
267+
destination's `.len` is left uninitialized (or filled with garbage past
268+
the end of src). Field-level copy respects the length-prefixed semantics
269+
regardless of either side's declared capacity.

Two C statements are emitted through [emit_line]: an assignment of the
[.len] field, then a [__builtin_memcpy] of [src.len] bytes from [src.data]
into [dest.data].

@param ctx codegen context the lines are emitted into.
@param dest C lvalue expression for the destination string, spliced
verbatim into the output.
@param src C expression for the source string, spliced verbatim; it is
written three times in the emitted code, so it should be free of side
effects.

NOTE(review): the emitted length is [src.len] with no clamp against the
destination, so this presumably relies on callers only widening (source
capacity <= destination capacity). Some call sites only test that the two
sizes differ - confirm a narrowing copy cannot reach this helper.
NOTE(review): the memcpy length is a runtime value; confirm the target
compiler accepts a non-constant size for [__builtin_memcpy]. *)
270+
let emit_str_copy ctx ~dest ~src =
271+
(* Length first, then exactly that many payload bytes. *)
emit_line ctx (sprintf "%s.len = %s.len;" dest src);
272+
emit_line ctx (sprintf "__builtin_memcpy(%s.data, %s.data, %s.len);" dest src src)
273+
263274
(** Emit all pending string literal declarations *)
264275
let emit_pending_string_literals ctx =
265276
List.iter (fun (var_name, content, size) ->
@@ -1720,28 +1731,46 @@ let generate_c_expression ctx ir_expr =
17201731
(* Generate regular if-else chain with temporary variable for complex cases *)
17211732
let temp_var = fresh_var ctx "match_result" in
17221733
let result_type = ebpf_type_from_ir_type ir_expr.expr_type in
1723-
1734+
let result_str_size = match ir_expr.expr_type with IRStr n -> Some n | _ -> None in
1735+
17241736
(* Generate temporary variable for the result *)
17251737
emit_line ctx (sprintf "%s %s;" result_type temp_var);
1726-
1727-
(* Defer string literals during match arm value generation *)
1728-
ctx.defer_string_literals <- true;
1729-
let arm_values = List.map (fun arm ->
1730-
(arm, generate_c_value ctx arm.ir_arm_value)
1731-
) arms in
1732-
ctx.defer_string_literals <- false;
1733-
1734-
(* Emit collected string literals before if-else chain *)
1735-
emit_pending_string_literals ctx;
1736-
1738+
1739+
(* For str-typed results, the type checker has widened the match's type
1740+
to the LUB of all arm types (str(max N)). Emit each arm value directly
1741+
into the result-typed slot - literals as compound literals at the result
1742+
size, runtime values via field-level copy that respects .len. This
1743+
avoids the unsafe `memcpy(&dst, &src, sizeof(src))` across different
1744+
str_N_t layouts, which used to write src's .len bytes into dst's .data.
1745+
For non-str results, keep the existing deferred-literal flow. *)
1746+
let arm_values = match result_str_size with
1747+
| Some _ -> List.map (fun arm -> (arm, "")) arms
1748+
| None ->
1749+
ctx.defer_string_literals <- true;
1750+
let vals = List.map (fun arm ->
1751+
(arm, generate_c_value ctx arm.ir_arm_value)
1752+
) arms in
1753+
ctx.defer_string_literals <- false;
1754+
emit_pending_string_literals ctx;
1755+
vals
1756+
in
1757+
17371758
(* Generate if-else chain *)
1738-
let needs_memcpy = match ir_expr.expr_type with IRStr _ -> true | _ -> false in
17391759
let generate_match_arm is_first (arm, arm_val_str) =
17401760
let emit_assignment () =
1741-
if needs_memcpy then
1742-
emit_line ctx (sprintf "__builtin_memcpy(&%s, &%s, sizeof(%s));" temp_var arm_val_str arm_val_str)
1743-
else
1744-
emit_line ctx (sprintf "%s = %s;" temp_var arm_val_str)
1761+
match result_str_size with
1762+
| Some result_size ->
1763+
(match arm.ir_arm_value.value_desc with
1764+
| IRLiteral (Ast.StringLit s) ->
1765+
let len = min (String.length s) result_size in
1766+
let truncated = if len < String.length s then String.sub s 0 len else s in
1767+
emit_line ctx (sprintf "%s = (str_%d_t){ .data = \"%s\", .len = %d };"
1768+
temp_var result_size (String.escaped truncated) len)
1769+
| _ ->
1770+
let val_str = generate_c_value ctx arm.ir_arm_value in
1771+
emit_str_copy ctx ~dest:temp_var ~src:val_str)
1772+
| None ->
1773+
emit_line ctx (sprintf "%s = %s;" temp_var arm_val_str)
17451774
in
17461775
match arm.ir_arm_pattern with
17471776
| IRConstantPattern const_val ->
@@ -1759,14 +1788,14 @@ let generate_c_expression ctx ir_expr =
17591788
decrease_indent ctx;
17601789
emit_line ctx "}"
17611790
in
1762-
1791+
17631792
(* Generate all arms *)
17641793
(match arm_values with
17651794
| [] -> () (* No arms - should not happen *)
17661795
| first_arm :: rest_arms ->
17671796
generate_match_arm true first_arm;
17681797
List.iter (generate_match_arm false) rest_arms);
1769-
1798+
17701799
(* Return the temporary variable *)
17711800
temp_var
17721801

@@ -1931,11 +1960,21 @@ and generate_c_instruction ctx ir_instr =
19311960
(* Check if variable is already declared (e.g., in callback functions) *)
19321961
let var_hash = Hashtbl.hash var_name in
19331962
if Hashtbl.mem ctx.declared_registers var_hash then
1934-
(* Variable already declared, just generate assignment if there's an initializer *)
1963+
(* Variable already declared (typically hoisted to function top), emit
1964+
the initializer's assignment at the original position. Cross-size
1965+
str-to-str needs length-respecting field copy; see emit_str_copy. *)
19351966
(match init_expr_opt with
19361967
| Some init_expr ->
1968+
let cross_size_str = match typ, init_expr.expr_desc with
1969+
| IRStr d, IRValue src_val ->
1970+
(match src_val.val_type with IRStr s -> s <> d | _ -> false)
1971+
| _ -> false
1972+
in
19371973
let init_str = generate_c_expression ctx init_expr in
1938-
emit_line ctx (sprintf "%s = %s;" var_name init_str)
1974+
if cross_size_str then
1975+
emit_str_copy ctx ~dest:var_name ~src:init_str
1976+
else
1977+
emit_line ctx (sprintf "%s = %s;" var_name init_str)
19391978
| None -> (* No initializer, no need to emit anything *) ())
19401979
else
19411980
(* Variable not declared yet, generate full declaration *)
@@ -1945,10 +1984,10 @@ and generate_c_instruction ctx ir_instr =
19451984
(* Check if this is a string assignment that needs special handling *)
19461985
(match typ, init_expr.expr_desc with
19471986
| IRStr dest_size, IRValue src_val when (match src_val.val_type with IRStr src_size -> src_size <= dest_size | _ -> false) ->
1948-
(* String to string assignment with compatible sizes - regenerate src with dest size *)
1987+
(* String to string with compatible sizes. *)
19491988
(match src_val.value_desc with
19501989
| IRLiteral (StringLit s) ->
1951-
(* Generate direct struct assignment for string literals *)
1990+
(* Literal: emit at the destination type via designated initializer. *)
19521991
let len = String.length s in
19531992
let max_content_len = dest_size in
19541993
let actual_len = min len max_content_len in
@@ -1957,10 +1996,17 @@ and generate_c_instruction ctx ir_instr =
19571996
emit_line ctx (sprintf " .data = \"%s\"," (String.escaped truncated_s));
19581997
emit_line ctx (sprintf " .len = %d" actual_len);
19591998
emit_line ctx "};"
1960-
| _ ->
1961-
(* For non-literal strings, use regular assignment *)
1962-
let init_str = generate_c_expression ctx init_expr in
1963-
emit_line ctx (sprintf "%s %s = %s;" type_str var_name init_str))
1999+
| _ ->
2000+
let src_size = match src_val.val_type with IRStr s -> s | _ -> dest_size in
2001+
if src_size = dest_size then
2002+
(* Same type: plain struct assignment. *)
2003+
(let init_str = generate_c_expression ctx init_expr in
2004+
emit_line ctx (sprintf "%s %s = %s;" type_str var_name init_str))
2005+
else
2006+
(* Cross-size: declare-then-field-copy (see emit_str_copy). *)
2007+
(emit_line ctx (sprintf "%s %s;" type_str var_name);
2008+
let src_str = generate_c_expression ctx init_expr in
2009+
emit_str_copy ctx ~dest:var_name ~src:src_str))
19642010
| IRStr _, _ ->
19652011
(* Other string expressions (concatenation, etc.) *)
19662012
let init_str = generate_c_expression ctx init_expr in
@@ -2570,13 +2616,14 @@ and generate_assignment ctx dest_val expr is_const =
25702616
| None -> ())
25712617
| _ -> ());
25722618

2573-
(* Use memcpy for cross-size string assignments *)
2574-
let use_memcpy = match dest_val.val_type, expr.expr_desc with
2619+
(* Cross-size string assignments need length-respecting field copy
2620+
(see emit_str_copy). Same-size stays as plain struct assignment. *)
2621+
let cross_size_str = match dest_val.val_type, expr.expr_desc with
25752622
| IRStr d, IRValue src_val -> (match src_val.val_type with IRStr s -> s <> d | _ -> false)
25762623
| _ -> false
25772624
in
2578-
if use_memcpy then
2579-
emit_line ctx (sprintf "__builtin_memcpy(&%s, &%s, sizeof(%s));" dest_str expr_str expr_str)
2625+
if cross_size_str then
2626+
emit_str_copy ctx ~dest:dest_str ~src:expr_str
25802627
else
25812628
emit_line ctx (sprintf "%s%s = %s;" assignment_prefix dest_str expr_str)
25822629
)
@@ -2596,12 +2643,14 @@ and generate_assignment ctx dest_val expr is_const =
25962643
(* Check if this is a string assignment *)
25972644
(match dest_val.val_type, expr.expr_desc with
25982645
| IRStr dest_size, IRValue src_val when (match src_val.val_type with IRStr src_size -> src_size <= dest_size | _ -> false) ->
2599-
(* String to string assignment - use memcpy for cross-size compatibility *)
2646+
(* String to string assignment - length-respecting field copy for
2647+
cross-size cases (see emit_str_copy); plain struct assignment when
2648+
sizes match. *)
26002649
let dest_str = generate_c_value ctx dest_val in
26012650
let src_str = generate_c_value ctx src_val in
26022651
(match src_val.val_type with
26032652
| IRStr src_size when src_size <> dest_size ->
2604-
emit_line ctx (sprintf "__builtin_memcpy(&%s, &%s, sizeof(%s));" dest_str src_str src_str)
2653+
emit_str_copy ctx ~dest:dest_str ~src:src_str
26052654
| _ ->
26062655
emit_line ctx (sprintf "%s = %s;" dest_str src_str))
26072656
| IRStr _, _ ->
@@ -3119,13 +3168,14 @@ let generate_assignment ctx dest_val expr is_const =
31193168
| None -> ())
31203169
| _ -> ());
31213170

3122-
(* Use memcpy for cross-size string assignments *)
3123-
let use_memcpy = match dest_val.val_type, expr.expr_desc with
3171+
(* Cross-size string assignments need length-respecting field copy
3172+
(see emit_str_copy). Same-size stays as plain struct assignment. *)
3173+
let cross_size_str = match dest_val.val_type, expr.expr_desc with
31243174
| IRStr d, IRValue src_val -> (match src_val.val_type with IRStr s -> s <> d | _ -> false)
31253175
| _ -> false
31263176
in
3127-
if use_memcpy then
3128-
emit_line ctx (sprintf "__builtin_memcpy(&%s, &%s, sizeof(%s));" dest_str expr_str expr_str)
3177+
if cross_size_str then
3178+
emit_str_copy ctx ~dest:dest_str ~src:expr_str
31293179
else
31303180
emit_line ctx (sprintf "%s%s = %s;" assignment_prefix dest_str expr_str)
31313181
)

src/type_checker.ml

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,29 +1484,28 @@ and type_check_expression ctx expr =
14841484
{ tarm_pattern = arm.arm_pattern; tarm_body = typed_arm_body; tarm_pos = arm.arm_pos }
14851485
) arms in
14861486

1487-
(* Determine the result type - all arms must have compatible types *)
1487+
(* Result type is the least upper bound of all arm types under the language's
1488+
subtyping relation (e.g. str(N) ⊔ str(M) = str(max N M)). unify_types
1489+
already implements this LUB; the match checker must use the unified type
1490+
it returns rather than discard it - otherwise a narrower first arm forces
1491+
the whole expression to be too narrow for later arms. *)
1492+
let extract_arm_type arm =
1493+
match arm.tarm_body with
1494+
| TSingleExpr expr -> expr.texpr_type
1495+
| TBlock stmts -> extract_block_return_type stmts arm.tarm_pos
1496+
in
14881497
let result_type = match typed_arms with
14891498
| [] -> type_error "Match expression must have at least one arm" expr.expr_pos
14901499
| first_arm :: rest_arms ->
1491-
let first_type = match first_arm.tarm_body with
1492-
| TSingleExpr expr -> expr.texpr_type
1493-
| TBlock stmts ->
1494-
extract_block_return_type stmts first_arm.tarm_pos
1495-
in
1496-
List.iter (fun arm ->
1497-
let arm_type = match arm.tarm_body with
1498-
| TSingleExpr expr -> expr.texpr_type
1499-
| TBlock stmts ->
1500-
extract_block_return_type stmts arm.tarm_pos
1501-
in
1502-
match unify_types first_type arm_type with
1503-
| Some _ -> () (* Compatible *)
1504-
| None ->
1505-
type_error ("All match arms must return compatible types. Expected " ^
1506-
string_of_bpf_type first_type ^ " but got " ^
1500+
List.fold_left (fun acc arm ->
1501+
let arm_type = extract_arm_type arm in
1502+
match unify_types acc arm_type with
1503+
| Some unified -> unified
1504+
| None ->
1505+
type_error ("All match arms must return compatible types. Expected " ^
1506+
string_of_bpf_type acc ^ " but got " ^
15071507
string_of_bpf_type arm_type) arm.tarm_pos
1508-
) rest_arms;
1509-
first_type
1508+
) (extract_arm_type first_arm) rest_arms
15101509
in
15111510

15121511
{ texpr_desc = TMatch (typed_matched_expr, typed_arms); texpr_type = result_type; texpr_pos = expr.expr_pos }

tests/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@
284284
(executable
285285
(name test_match)
286286
(modules test_match)
287-
(libraries kernelscript alcotest test_utils))
287+
(libraries kernelscript alcotest test_utils str))
288288

289289
(executable
290290
(name test_pinned_globals)

tests/test_match.ml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,47 @@ let test_enum_constant_resolution_in_match () =
634634
should be resolved to their actual values, not hardcoded to 0 *)
635635
()
636636

637+
(** Mixed-length string match arms: the result type is the least upper bound
638+
str(max N_i) of all arm types, and codegen does length-respecting field
639+
copies. The unsafe `memcpy(&dst, &src, sizeof(src))` pattern (which writes
640+
src.len into dst.data on cross-size str_N_t copies) must not appear.

The test drives the whole pipeline on an inline program (parse, symbol
table, type check, IR generation, C generation) and then inspects the
generated C text with plain substring searches. *)
641+
let test_match_mixed_length_string_arms () =
642+
let input = {|
643+
@helper fn use_label(s: str(20)) -> u32 { return 0 }
644+
645+
@xdp fn classify(ctx: *xdp_md) -> xdp_action {
646+
var n: u32 = 1
647+
var label = match (n) {
648+
1: "short",
649+
2: "medium_str",
650+
default: "longer_default_str"
651+
}
652+
var _ = use_label(label)
653+
return 2
654+
}
655+
656+
fn main() -> i32 { return 0 }
657+
|} in
658+
let ast = Parse.parse_string input in
659+
let symbol_table = Symbol_table.build_symbol_table ast in
660+
let (typed_ast, _) = Type_checker.type_check_and_annotate_ast ast in
661+
let ir = Ir_generator.generate_ir typed_ast symbol_table "test" in
662+
let (generated_code, _) = Ebpf_c_codegen.compile_multi_to_c_with_analysis ir in
663+
(* [contains s] is true when [s] occurs verbatim (no regexp
interpretation) anywhere in the generated C. *)
let contains substr =
664+
try ignore (Str.search_forward (Str.regexp_string substr) generated_code 0); true
665+
with Not_found -> false
666+
in
667+
(* Widest arm is "longer_default_str" (18). With type-checker LUB widening,
668+
the result temp must be at least str_18_t (or wider if propagated from
669+
the use_label parameter type).
NOTE(review): this is a whole-output substring check, and the program
above also declares a str(20) parameter on use_label; if the codegen emits
a str_20_t type for that parameter regardless of the match, this assertion
holds even without widening - confirm, or assert on the declaration of the
match temporary instead. *)
670+
check bool "match result widened to a wide-enough str type" true
671+
(contains "str_18_t" || contains "str_20_t");
672+
(* Cross-size struct memcpy with sizeof(src) corrupts the destination - the
673+
match codegen must not emit it. Field copies (`.len = ...` + memcpy of
674+
`.data` with `.len` bytes) are safe.
NOTE(review): the needle assumes the match temporary is spelled with a
[__match_result] prefix in the generated C - verify against the naming
used by the codegen, otherwise this check cannot fail. *)
675+
check bool "no unsafe struct-memcpy for str match arms" false
676+
(contains "memcpy(&__match_result")
677+
637678
let suite = [
638679
"test_basic_match_parsing", `Quick, test_basic_match_parsing;
639680
"test_match_with_enums", `Quick, test_match_with_enums;
@@ -646,6 +687,7 @@ let suite = [
646687
"test_nested_match_structures", `Quick, test_nested_match_structures;
647688
"test_match_block_implicit_returns", `Quick, test_match_block_implicit_returns;
648689
"test_enum_constant_resolution_in_match", `Quick, test_enum_constant_resolution_in_match;
690+
"test_match_mixed_length_string_arms", `Quick, test_match_mixed_length_string_arms;
649691
]
650692

651693
let () = run "Match Construct Tests" [

0 commit comments

Comments
 (0)