Skip to content

Commit a651add

Browse files
Merge branch 'multikernel:main' into perf_pr
2 parents a0133a4 + 4031506 commit a651add

4 files changed

Lines changed: 555 additions & 84 deletions

File tree

examples/struct_ops_simple.ks

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ impl minimal_congestion_control {
1212
return 16
1313
}
1414

15+
fn undo_cwnd(sk: *u8) -> u32 {
16+
return ssthresh(sk)
17+
}
18+
1519
fn cong_avoid(sk: *u8, ack: u32, acked: u32) -> void {
1620
// Minimal TCP congestion avoidance implementation
1721
// In a real implementation, this would adjust the congestion window

src/ir_generator.ml

Lines changed: 126 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ type ir_context = {
6767
map_origin_variables: (string, (string * ir_value * (ir_value_desc * ir_type))) Hashtbl.t; (* var_name -> (map_name, key, underlying_info) *)
6868
(* Track inferred variable types for proper lookups *)
6969
variable_types: (string, ir_type) Hashtbl.t; (* var_name -> ir_type *)
70+
mutable current_program_type: program_type option;
7071
}
7172

7273
(** Create new IR generation context *)
@@ -91,6 +92,7 @@ let create_context ?(global_variables = []) ?(helper_functions = []) symbol_tabl
9192
tbl);
9293
map_origin_variables = Hashtbl.create 32;
9394
variable_types = Hashtbl.create 32;
95+
current_program_type = None;
9496
helper_functions = (let tbl = Hashtbl.create 16 in
9597
List.iter (fun helper_name -> Hashtbl.add tbl helper_name ()) helper_functions;
9698
tbl);
@@ -349,6 +351,103 @@ let extract_struct_ops_kernel_name attributes =
349351
| _ -> acc
350352
) "" attributes
351353

354+
let ast_struct_has_field ast struct_name field_name =
355+
List.exists (function
356+
| Ast.StructDecl struct_def when struct_def.Ast.struct_name = struct_name ->
357+
List.exists (fun (name, _) -> name = field_name) struct_def.Ast.struct_fields
358+
| _ -> false
359+
) ast
360+
361+
let impl_block_has_static_field impl_block field_name =
362+
List.exists (function
363+
| Ast.ImplStaticField (name, _) when name = field_name -> true
364+
| _ -> false
365+
) impl_block.Ast.impl_items
366+
367+
let normalize_struct_ops_instance_name name =
368+
let buffer = Buffer.create (String.length name * 2) in
369+
let is_uppercase ch = ch >= 'A' && ch <= 'Z' in
370+
let is_lowercase ch = ch >= 'a' && ch <= 'z' in
371+
let is_digit ch = ch >= '0' && ch <= '9' in
372+
let add_separator_if_needed idx ch =
373+
if idx > 0 && is_uppercase ch then
374+
let prev = name.[idx - 1] in
375+
let next_is_lowercase = idx + 1 < String.length name && is_lowercase name.[idx + 1] in
376+
if is_lowercase prev || is_digit prev || (is_uppercase prev && next_is_lowercase) then
377+
Buffer.add_char buffer '_'
378+
in
379+
String.iteri (fun idx ch ->
380+
add_separator_if_needed idx ch;
381+
let normalized =
382+
if is_uppercase ch then Char.lowercase_ascii ch
383+
else if is_lowercase ch || is_digit ch || ch = '_' then ch
384+
else '_'
385+
in
386+
Buffer.add_char buffer normalized
387+
) name;
388+
Buffer.contents buffer
389+
390+
let generate_default_struct_ops_name instance_name =
391+
(* BPF_OBJ_NAME_LEN is 16 bytes including the NUL terminator, so the
392+
usable name length is 15 characters. *)
393+
let max_len = 15 in
394+
let normalized = normalize_struct_ops_instance_name instance_name in
395+
if String.length normalized <= max_len then normalized
396+
else
397+
let parts = List.filter (fun part -> part <> "") (String.split_on_char '_' normalized) in
398+
match parts with
399+
| [] -> String.sub normalized 0 max_len
400+
| first :: rest ->
401+
let abbreviated =
402+
match rest with
403+
| [] -> first
404+
| _ ->
405+
let initials = rest |> List.map (fun part -> String.make 1 part.[0]) |> String.concat "" in
406+
first ^ "_" ^ initials
407+
in
408+
if String.length abbreviated <= max_len then abbreviated
409+
else String.sub abbreviated 0 max_len
410+
411+
(* Decide whether a tail-call return (IRReturnCall) should be emitted for a
412+
call to [name] in the current context.
413+
414+
Two intentional behaviour changes vs. the previous per-site inline logic:
415+
416+
1. [is_function_pointer] now checks for [IRFunctionPointer] specifically
417+
instead of [Hashtbl.mem ctx.variable_types name]. The old check was
418+
too broad: any local variable (int, pointer, …) with the same name
419+
would be treated as a function pointer and block tail-call lowering.
420+
421+
2. A tail call is only emitted when [current_program_type] is set to a
422+
known attributed type (e.g. XDP, TC, kprobe). Helper functions that
423+
are lowered outside of an attributed program context therefore never
424+
produce tail calls, which is correct because they have no prog_array
425+
to dispatch into. struct_ops methods are explicitly excluded via the
426+
[StructOps] branch. *)
427+
let should_lower_as_implicit_tail_call ctx name =
428+
let is_function_pointer =
429+
Hashtbl.mem ctx.function_parameters name ||
430+
match Hashtbl.find_opt ctx.variable_types name with
431+
| Some (IRFunctionPointer _) -> true
432+
| _ -> false
433+
in
434+
if is_function_pointer || Hashtbl.mem ctx.helper_functions name then
435+
false
436+
else
437+
match ctx.current_function, ctx.current_program_type with
438+
| Some _, Some Ast.StructOps -> false
439+
| Some current_func_name, Some _ ->
440+
let caller_is_attributed =
441+
try Symbol_table.lookup_function ctx.symbol_table current_func_name <> None
442+
with _ -> false
443+
in
444+
let target_is_attributed =
445+
try Symbol_table.lookup_function ctx.symbol_table name <> None
446+
with _ -> false
447+
in
448+
caller_is_attributed && target_is_attributed
449+
| _ -> false
450+
352451

353452
(** Map struct names to their corresponding context types *)
354453
let struct_name_to_context_type = function
@@ -1666,14 +1765,12 @@ and lower_statement ctx stmt =
16661765
(* Check if this is a simple function call that could be a tail call *)
16671766
(match callee_expr.expr_desc with
16681767
| Ast.Identifier name ->
1669-
(* Check if this is a helper function - if so, treat as regular call *)
1670-
if Hashtbl.mem ctx.helper_functions name then
1671-
let ret_val = lower_expression ctx expr in
1672-
IRReturnValue ret_val
1673-
else
1674-
(* This will be converted to tail call by tail call analyzer *)
1768+
if should_lower_as_implicit_tail_call ctx name then
16751769
let arg_vals = List.map (lower_expression ctx) args in
16761770
IRReturnCall (name, arg_vals)
1771+
else
1772+
let ret_val = lower_expression ctx expr in
1773+
IRReturnValue ret_val
16771774
| _ ->
16781775
(* Function pointer call - treat as regular return *)
16791776
let ret_val = lower_expression ctx expr in
@@ -1696,13 +1793,12 @@ and lower_statement ctx stmt =
16961793
(* Check if this is a simple function call that could be a tail call *)
16971794
(match callee_expr.expr_desc with
16981795
| Ast.Identifier name ->
1699-
(* Check if this is a helper function - if so, treat as regular call *)
1700-
if Hashtbl.mem ctx.helper_functions name then
1701-
let ret_val = lower_expression ctx return_expr in
1702-
IRReturnValue ret_val
1703-
else
1796+
if should_lower_as_implicit_tail_call ctx name then
17041797
let arg_vals = List.map (lower_expression ctx) args in
17051798
IRReturnCall (name, arg_vals)
1799+
else
1800+
let ret_val = lower_expression ctx return_expr in
1801+
IRReturnValue ret_val
17061802
| _ ->
17071803
(* Function pointer call - treat as regular return *)
17081804
let ret_val = lower_expression ctx return_expr in
@@ -1719,13 +1815,12 @@ and lower_statement ctx stmt =
17191815
| Ast.Call (callee_expr, args) ->
17201816
(match callee_expr.expr_desc with
17211817
| Ast.Identifier name ->
1722-
(* Check if this is a helper function - if so, treat as regular call *)
1723-
if Hashtbl.mem ctx.helper_functions name then
1724-
let ret_val = lower_expression ctx expr in
1725-
IRReturnValue ret_val
1726-
else
1818+
if should_lower_as_implicit_tail_call ctx name then
17271819
let arg_vals = List.map (lower_expression ctx) args in
17281820
IRReturnCall (name, arg_vals)
1821+
else
1822+
let ret_val = lower_expression ctx expr in
1823+
IRReturnValue ret_val
17291824
| _ ->
17301825
let ret_val = lower_expression ctx expr in
17311826
IRReturnValue ret_val)
@@ -1768,47 +1863,7 @@ and lower_statement ctx stmt =
17681863
(* Check if this is a simple function call that could be a tail call *)
17691864
(match callee_expr.expr_desc with
17701865
| Ast.Identifier name ->
1771-
(* Check if this should be a tail call *)
1772-
let should_be_tail_call =
1773-
(* First check if the identifier is a function parameter or variable (function pointer) *)
1774-
let is_function_pointer =
1775-
Hashtbl.mem ctx.function_parameters name ||
1776-
Hashtbl.mem ctx.variable_types name
1777-
in
1778-
1779-
if is_function_pointer then
1780-
(* Function pointer calls should never be tail calls *)
1781-
false
1782-
else
1783-
(* Check if we're in an attributed function context *)
1784-
match ctx.current_function with
1785-
| Some current_func_name ->
1786-
(* Check if caller is attributed (has eBPF attributes) *)
1787-
let caller_is_attributed =
1788-
try
1789-
let caller_symbol = Symbol_table.lookup_function ctx.symbol_table current_func_name in
1790-
(* TODO: Check if caller has eBPF attributes like @xdp, @tc, etc. *)
1791-
(* For now, assume attributed functions are defined in symbol table *)
1792-
caller_symbol <> None
1793-
with _ -> false
1794-
in
1795-
1796-
(* Check if target function is an attributed function *)
1797-
let target_is_attributed =
1798-
try
1799-
let target_symbol = Symbol_table.lookup_function ctx.symbol_table name in
1800-
(* TODO: Check if target has eBPF attributes like @xdp, @tc, etc. *)
1801-
(* For now, assume attributed functions are defined in symbol table *)
1802-
target_symbol <> None
1803-
with _ -> false
1804-
in
1805-
1806-
(* Only allow tail calls between attributed functions *)
1807-
caller_is_attributed && target_is_attributed
1808-
| None -> false
1809-
in
1810-
1811-
if should_be_tail_call then
1866+
if should_lower_as_implicit_tail_call ctx name then
18121867
(* Generate tail call instruction *)
18131868
let arg_vals = List.map (lower_expression ctx) args in
18141869
let tail_call_index = 0 in (* This will be set by tail call analyzer *)
@@ -2363,6 +2418,7 @@ let convert_match_return_calls_to_tail_calls ir_function =
23632418
(** Lower AST function to IR function *)
23642419
let lower_function ctx prog_name ?(program_type : program_type option = None) ?(func_target = None) (func_def : Ast.function_def) =
23652420
ctx.current_function <- Some func_def.func_name;
2421+
ctx.current_program_type <- program_type;
23662422

23672423
(* Reset for new function *)
23682424
Hashtbl.clear ctx.variable_types;
@@ -3133,6 +3189,19 @@ let lower_multi_program ast symbol_table source_name =
31333189
in
31343190
Some (field_name, field_val)
31353191
) impl_block.impl_items in
3192+
let ir_instance_fields =
3193+
if ast_struct_has_field ast kernel_struct_name "name" && not (impl_block_has_static_field impl_block "name") then
3194+
let generated_name = generate_default_struct_ops_name impl_block.impl_name in
3195+
let generated_name_val =
3196+
make_ir_value
3197+
(IRLiteral (StringLit generated_name))
3198+
(IRStr (String.length generated_name + 1))
3199+
impl_block.impl_pos
3200+
in
3201+
ir_instance_fields @ [("name", generated_name_val)]
3202+
else
3203+
ir_instance_fields
3204+
in
31363205
let ir_instance = make_ir_struct_ops_instance
31373206
impl_block.impl_name
31383207
kernel_struct_name

0 commit comments

Comments
 (0)