diff --git a/examples/struct_ops_simple.ks b/examples/struct_ops_simple.ks index 9fff737..90ae54b 100644 --- a/examples/struct_ops_simple.ks +++ b/examples/struct_ops_simple.ks @@ -12,6 +12,10 @@ impl minimal_congestion_control { return 16 } + fn undo_cwnd(sk: *u8) -> u32 { + return ssthresh(sk) + } + fn cong_avoid(sk: *u8, ack: u32, acked: u32) -> void { // Minimal TCP congestion avoidance implementation // In a real implementation, this would adjust the congestion window diff --git a/src/ir_generator.ml b/src/ir_generator.ml index 3c05ffd..b42d8cb 100644 --- a/src/ir_generator.ml +++ b/src/ir_generator.ml @@ -67,6 +67,7 @@ type ir_context = { map_origin_variables: (string, (string * ir_value * (ir_value_desc * ir_type))) Hashtbl.t; (* var_name -> (map_name, key, underlying_info) *) (* Track inferred variable types for proper lookups *) variable_types: (string, ir_type) Hashtbl.t; (* var_name -> ir_type *) + mutable current_program_type: program_type option; } (** Create new IR generation context *) @@ -91,6 +92,7 @@ let create_context ?(global_variables = []) ?(helper_functions = []) symbol_tabl tbl); map_origin_variables = Hashtbl.create 32; variable_types = Hashtbl.create 32; + current_program_type = None; helper_functions = (let tbl = Hashtbl.create 16 in List.iter (fun helper_name -> Hashtbl.add tbl helper_name ()) helper_functions; tbl); @@ -349,6 +351,85 @@ let extract_struct_ops_kernel_name attributes = | _ -> acc ) "" attributes +let ast_struct_has_field ast struct_name field_name = + List.exists (function + | Ast.StructDecl struct_def when struct_def.Ast.struct_name = struct_name -> + List.exists (fun (name, _) -> name = field_name) struct_def.Ast.struct_fields + | _ -> false + ) ast + +let impl_block_has_static_field impl_block field_name = + List.exists (function + | Ast.ImplStaticField (name, _) when name = field_name -> true + | _ -> false + ) impl_block.Ast.impl_items + +let normalize_struct_ops_instance_name name = + let buffer = Buffer.create (String.length name * 2) in + let is_uppercase ch = ch >= 'A' && ch <= 'Z' in + let is_lowercase ch = ch >= 'a' && ch <= 'z' in + let is_digit ch = ch >= '0' && ch <= '9' in + let add_separator_if_needed idx ch = + if idx > 0 && is_uppercase ch then + let prev = name.[idx - 1] in + let next_is_lowercase = idx + 1 < String.length name && is_lowercase name.[idx + 1] in + if is_lowercase prev || is_digit prev || (is_uppercase prev && next_is_lowercase) then + Buffer.add_char buffer '_' + in + String.iteri (fun idx ch -> + add_separator_if_needed idx ch; + let normalized = + if is_uppercase ch then Char.lowercase_ascii ch + else if is_lowercase ch || is_digit ch || ch = '_' then ch + else '_' + in + Buffer.add_char buffer normalized + ) name; + Buffer.contents buffer + +let generate_default_struct_ops_name instance_name = + let max_len = 15 in + let normalized = normalize_struct_ops_instance_name instance_name in + if String.length normalized <= max_len then normalized + else + let parts = List.filter (fun part -> part <> "") (String.split_on_char '_' normalized) in + match parts with + | [] -> String.sub normalized 0 max_len + | first :: rest -> + let abbreviated = + match rest with + | [] -> first + | _ -> + let initials = rest |> List.map (fun part -> String.make 1 part.[0]) |> String.concat "" in + first ^ "_" ^ initials + in + if String.length abbreviated <= max_len then abbreviated + else String.sub abbreviated 0 max_len + +let should_lower_as_implicit_tail_call ctx name = + let is_function_pointer = + Hashtbl.mem ctx.function_parameters name || + match Hashtbl.find_opt ctx.variable_types name with + | Some (IRFunctionPointer _) -> true + | _ -> false + in + if is_function_pointer || Hashtbl.mem ctx.helper_functions name then + false + else + match ctx.current_function, ctx.current_program_type with + | Some _, Some Ast.StructOps -> false + | Some current_func_name, Some _ -> + let caller_is_attributed = + try Symbol_table.lookup_function ctx.symbol_table current_func_name <> None + with _ -> false + in + let target_is_attributed = + try Symbol_table.lookup_function ctx.symbol_table name <> None + with _ -> false + in + caller_is_attributed && target_is_attributed + | _ -> false + (** Map struct names to their corresponding context types *) let struct_name_to_context_type = function @@ -1659,14 +1740,12 @@ and lower_statement ctx stmt = (* Check if this is a simple function call that could be a tail call *) (match callee_expr.expr_desc with | Ast.Identifier name -> - (* Check if this is a helper function - if so, treat as regular call *) - if Hashtbl.mem ctx.helper_functions name then - let ret_val = lower_expression ctx expr in - IRReturnValue ret_val - else - (* This will be converted to tail call by tail call analyzer *) + if should_lower_as_implicit_tail_call ctx name then let arg_vals = List.map (lower_expression ctx) args in IRReturnCall (name, arg_vals) + else + let ret_val = lower_expression ctx expr in + IRReturnValue ret_val | _ -> (* Function pointer call - treat as regular return *) let ret_val = lower_expression ctx expr in @@ -1689,13 +1768,12 @@ and lower_statement ctx stmt = (* Check if this is a simple function call that could be a tail call *) (match callee_expr.expr_desc with | Ast.Identifier name -> - (* Check if this is a helper function - if so, treat as regular call *) - if Hashtbl.mem ctx.helper_functions name then - let ret_val = lower_expression ctx return_expr in - IRReturnValue ret_val - else + if should_lower_as_implicit_tail_call ctx name then let arg_vals = List.map (lower_expression ctx) args in IRReturnCall (name, arg_vals) + else + let ret_val = lower_expression ctx return_expr in + IRReturnValue ret_val | _ -> (* Function pointer call - treat as regular return *) let ret_val = lower_expression ctx return_expr in @@ -1712,13 +1790,12 @@ and lower_statement ctx stmt = | Ast.Call (callee_expr, args) -> (match callee_expr.expr_desc with | Ast.Identifier name -> - (* Check if this is a helper function - if so, treat as regular call *) - if Hashtbl.mem ctx.helper_functions name then - let ret_val = lower_expression ctx expr in - IRReturnValue ret_val - else + if should_lower_as_implicit_tail_call ctx name then let arg_vals = List.map (lower_expression ctx) args in IRReturnCall (name, arg_vals) + else + let ret_val = lower_expression ctx expr in + IRReturnValue ret_val | _ -> let ret_val = lower_expression ctx expr in IRReturnValue ret_val) @@ -1761,47 +1838,7 @@ and lower_statement ctx stmt = (* Check if this is a simple function call that could be a tail call *) (match callee_expr.expr_desc with | Ast.Identifier name -> - (* Check if this should be a tail call *) - let should_be_tail_call = - (* First check if the identifier is a function parameter or variable (function pointer) *) - let is_function_pointer = - Hashtbl.mem ctx.function_parameters name || - Hashtbl.mem ctx.variable_types name - in - - if is_function_pointer then - (* Function pointer calls should never be tail calls *) - false - else - (* Check if we're in an attributed function context *) - match ctx.current_function with - | Some current_func_name -> - (* Check if caller is attributed (has eBPF attributes) *) - let caller_is_attributed = - try - let caller_symbol = Symbol_table.lookup_function ctx.symbol_table current_func_name in - (* TODO: Check if caller has eBPF attributes like @xdp, @tc, etc. *) - (* For now, assume attributed functions are defined in symbol table *) - caller_symbol <> None - with _ -> false - in - - (* Check if target function is an attributed function *) - let target_is_attributed = - try - let target_symbol = Symbol_table.lookup_function ctx.symbol_table name in - (* TODO: Check if target has eBPF attributes like @xdp, @tc, etc. *) - (* For now, assume attributed functions are defined in symbol table *) - target_symbol <> None - with _ -> false - in - - (* Only allow tail calls between attributed functions *) - caller_is_attributed && target_is_attributed - | None -> false - in - - if should_be_tail_call then + if should_lower_as_implicit_tail_call ctx name then (* Generate tail call instruction *) let arg_vals = List.map (lower_expression ctx) args in let tail_call_index = 0 in (* This will be set by tail call analyzer *) @@ -2356,6 +2393,7 @@ let convert_match_return_calls_to_tail_calls ir_function = (** Lower AST function to IR function *) let lower_function ctx prog_name ?(program_type : program_type option = None) ?(func_target = None) (func_def : Ast.function_def) = ctx.current_function <- Some func_def.func_name; + ctx.current_program_type <- program_type; (* Reset for new function *) Hashtbl.clear ctx.variable_types; @@ -3125,6 +3163,19 @@ let lower_multi_program ast symbol_table source_name = in Some (field_name, field_val) ) impl_block.impl_items in + let ir_instance_fields = + if ast_struct_has_field ast kernel_struct_name "name" && not (impl_block_has_static_field impl_block "name") then + let generated_name = generate_default_struct_ops_name impl_block.impl_name in + let generated_name_val = + make_ir_value + (IRLiteral (StringLit generated_name)) + (IRStr (String.length generated_name + 1)) + impl_block.impl_pos + in + ir_instance_fields @ [("name", generated_name_val)] + else + ir_instance_fields + in let ir_instance = make_ir_struct_ops_instance impl_block.impl_name kernel_struct_name diff --git a/src/userspace_codegen.ml b/src/userspace_codegen.ml index 97d9c6c..0f3a66b 100644 --- a/src/userspace_codegen.ml +++ b/src/userspace_codegen.ml @@ -2223,24 +2223,11 @@ let rec generate_c_instruction_from_ir ctx instruction = | IRStruct (name, _) -> name | _ -> failwith "struct_ops register() argument must be an impl block instance") in - (* Generate struct_ops registration code *) - sprintf {|({ - if (!obj) { - fprintf(stderr, "eBPF skeleton not loaded for struct_ops registration\n"); - %s = -1; - } else { - struct bpf_map *map = bpf_object__find_map_by_name(obj->obj, "%s"); - if (!map) { - fprintf(stderr, "Failed to find struct_ops map '%s'\n"); - %s = -1; - } else { - struct bpf_link *link = bpf_map__attach_struct_ops(map); - %s = (link != NULL) ? 0 : -1; - if (link) bpf_link__destroy(link); - } - } - %s; -});|} result_str instance_name instance_name result_str result_str result_str + (* Generate struct_ops registration code via the generated helper to keep the link alive *) + sprintf {|({ + %s = attach_struct_ops_%s(); + %s; + });|} result_str instance_name result_str (** Generate C struct from IR struct definition *) let generate_c_struct_from_ir ir_struct = @@ -2355,6 +2342,123 @@ let collect_function_usage_from_ir_function ?(global_variables = []) ir_func = ) ir_func.basic_blocks; ctx.function_usage +type struct_ops_main_registration = { + result_value: ir_value; + result_name: string; + instance_name: string; +} + +let ir_value_variable_name ir_value = + match ir_value.value_desc with + | IRVariable name | IRTempVariable name -> Some name + | _ -> None + +let struct_ops_instance_name ir_value = + match ir_value.value_desc with + | IRVariable name -> Some name + | IRTempVariable name -> Some name + | _ -> + (match ir_value.val_type with + | IRStruct (name, _) -> Some name + | _ -> None) + +let find_struct_ops_main_registration ir_func = + let registrations = List.fold_left (fun acc block -> + List.fold_left (fun inner_acc instr -> + match instr.instr_desc with + | IRStructOpsRegister (result_val, struct_ops_val) -> + (match ir_value_variable_name result_val, struct_ops_instance_name struct_ops_val with + | Some result_name, Some instance_name -> + { result_value = result_val; result_name; instance_name } :: inner_acc + | _ -> inner_acc) + | _ -> inner_acc + ) acc block.instructions + ) [] ir_func.basic_blocks in + match List.rev ir_func.basic_blocks, registrations with + | last_block :: _, [registration] -> + (match List.rev last_block.instructions with + | { instr_desc = IRReturn (Some return_val); _ } :: _ -> + if ir_value_variable_name return_val = Some registration.result_name then + Some registration + else + None + | _ -> None) + | _ -> None + +let is_c_identifier value = + let is_ident_start = function + | 'a' .. 'z' | 'A' .. 'Z' | '_' -> true + | _ -> false + in + let is_ident_char = function + | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' | '_' -> true + | _ -> false + in + String.length value > 0 + && is_ident_start value.[0] + && + let rec check index = + if index >= String.length value then true + else if is_ident_char value.[index] then check (index + 1) + else false + in + check 1 + +let extract_terminal_return_identifier body_c = + let lines = Array.of_list (String.split_on_char '\n' body_c) in + let rec drop_leading_blank_lines = function + | line :: rest when String.trim line = "" -> drop_leading_blank_lines rest + | remaining -> remaining + in + let rec find_last_nonempty index = + if index < 0 then None + else if String.trim lines.(index) = "" then find_last_nonempty (index - 1) + else Some index + in + match find_last_nonempty (Array.length lines - 1) with + | None -> None + | Some index -> + let trimmed_line = String.trim lines.(index) in + let prefix = "return " in + if String.length trimmed_line > String.length prefix + && String.sub trimmed_line 0 (String.length prefix) = prefix + && trimmed_line.[String.length trimmed_line - 1] = ';' then + let expr = String.sub trimmed_line (String.length prefix) (String.length trimmed_line - String.length prefix - 1) |> String.trim in + if is_c_identifier expr then + let kept_lines = + Array.to_list (Array.sub lines 0 index) + |> List.rev + |> drop_leading_blank_lines + |> List.rev + in + Some (String.concat "\n" kept_lines, expr) + else + None + else + None + +let extract_attach_result_identifier body_c instance_name = + let attach_call = sprintf "attach_struct_ops_%s();" instance_name in + let extract_identifier_from_lhs line = + match String.index_opt line '=' with + | None -> None + | Some eq_index -> + let lhs = String.sub line 0 eq_index |> String.trim in + if is_c_identifier lhs then Some lhs else None + in + String.split_on_char '\n' body_c + |> List.find_map (fun line -> + if String.contains line '=' && String.contains line 'a' && String.trim line <> "" then + let trimmed_line = String.trim line in + if String.length trimmed_line >= String.length attach_call + && String.contains trimmed_line '=' + && String.ends_with ~suffix:attach_call trimmed_line then + extract_identifier_from_lhs trimmed_line + else + None + else + None) + (** Generate config initialization from declaration defaults *) let generate_config_initialization (config_decl : Ast.config_declaration) = let config_name = config_decl.config_name in @@ -2568,6 +2672,13 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con let adjusted_return_type = if ir_func.func_name = "main" then "int" else return_type_str in if ir_func.func_name = "main" then + let has_struct_ops_instances = match ir_multi_prog with + | Some multi_prog -> Ir.get_struct_ops_instances multi_prog <> [] + | None -> false + in + let struct_ops_main_registration = + if has_struct_ops_instances then find_struct_ops_main_registration ir_func else None + in let args_parsing_code = if List.length ir_func.parameters > 0 then (* Generate argument parsing for struct parameter *) @@ -2593,9 +2704,13 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con obj = %s_ebpf__open_and_load(); if (!obj) { fprintf(stderr, "Failed to open and load eBPF skeleton\n"); - return 1; +%s return 1; } }|} base_name + (if has_struct_ops_instances then + " if (errno == EPERM) {\n fprintf(stderr, \"The kernel rejected BPF loading with EPERM. Make sure you run as root and the kernel supports struct_ops.\\n\");\n }\n" + else + "") else "" in @@ -2605,6 +2720,12 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con let auto_init_call = if needs_auto_init then " \n // Auto-initialize BPF maps\n atexit(cleanup_bpf_maps);\n if (init_bpf_maps() < 0) {\n return 1;\n }" else "" in + + let struct_ops_init_code = match ir_multi_prog with + | Some _ when has_struct_ops_instances -> + sprintf " if (bump_memlock_rlimit() < 0) {\n return 1;\n }\n\n if (ensure_struct_ops_privileges() < 0) {\n return 1;\n }\n\n atexit(cleanup_%s);" base_name + | _ -> "" + in (* Include setup code when object is loaded in main() *) let pinned_globals_vars = List.filter (fun gv -> gv.is_pinned) global_variables in @@ -2661,6 +2782,7 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con (* Combine skeleton loading with other initialization *) let initialization_code = String.concat "\n" (List.filter (fun s -> s <> "") [ + struct_ops_init_code; skeleton_loading_code; setup_call; auto_init_call; @@ -2668,6 +2790,65 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con error_handling_notice; ]) in + let body_parts = List.mapi (fun index block -> + let label_part = if block.label <> "entry" then [sprintf "%s:" block.label] else [] in + let instructions = + if index = List.length ir_func.basic_blocks - 1 then + match struct_ops_main_registration, List.rev block.instructions with + | Some registration, { instr_desc = IRReturn (Some return_val); _ } :: rest_rev + when ir_value_variable_name return_val = Some registration.result_name -> + List.rev rest_rev + | _ -> block.instructions + else + block.instructions + in + let instr_parts = List.map (generate_c_instruction_from_ir ctx) instructions in + let combined_parts = label_part @ instr_parts in + String.concat "\n " combined_parts + ) ir_func.basic_blocks in + + let body_c = String.concat "\n " body_parts in + let body_c = + let lifecycle_info = match struct_ops_main_registration with + | Some registration -> + let result_name = generate_c_value_from_ir ctx registration.result_value in + Some (body_c, result_name, registration.instance_name, result_name) + | None -> + (match ir_multi_prog with + | Some multi_prog -> + (match Ir.get_struct_ops_instances multi_prog with + | [instance] -> + (match extract_terminal_return_identifier body_c with + | Some (body_prefix, result_name) -> + let attach_result_name = match extract_attach_result_identifier body_prefix instance.ir_instance_name with + | Some name -> name + | None -> result_name + in + Some (body_prefix, result_name, instance.ir_instance_name, attach_result_name) + | None -> None) + | _ -> None) + | None -> None) + in + match lifecycle_info with + | Some (body_prefix, result_str, instance_name, attach_status_str) -> + let lifecycle_code = sprintf {|if (%s != 0) { + %s = %s; + return %s; + } + + wait_for_unregister_request(); + + %s = detach_struct_ops_%s(); + if (%s != 0) { + return %s; + } + + %s = 0; + return %s;|} attach_status_str result_str attach_status_str result_str result_str instance_name result_str result_str result_str result_str in + if body_prefix = "" then lifecycle_code else body_prefix ^ "\n \n " ^ lifecycle_code + | None -> body_c + in + (* Generate ONLY what the user explicitly wrote with skeleton loading at the beginning *) sprintf {|%s %s(%s) { %s%s%s @@ -2707,11 +2888,138 @@ let generate_struct_ops_attach_functions ir_multi_program = else let attach_functions = List.map (fun struct_ops_inst -> let instance_name = struct_ops_inst.ir_instance_name in - sprintf "int attach_struct_ops_%s(void) { return 0; }\nint detach_struct_ops_%s(void) { return 0; }" + sprintf {|int attach_struct_ops_%s(void) { + struct bpf_map *map; + + if (!obj) { + fprintf(stderr, "eBPF skeleton not loaded for struct_ops registration\n"); + return -1; + } + + if (%s_link) { + return 0; + } + + map = bpf_object__find_map_by_name(obj->obj, "%s"); + if (!map) { + fprintf(stderr, "Failed to find struct_ops map '%s'\n"); + return -1; + } + + %s_link = bpf_map__attach_struct_ops(map); + if (!%s_link) { + fprintf(stderr, "Failed to register struct_ops instance '%s': %%s\n", strerror(errno)); + return -1; + } + + printf("Registered struct_ops instance: %s\n"); + return 0; +} + +int detach_struct_ops_%s(void) { + if (!%s_link) { + return 0; + } + + bpf_link__destroy(%s_link); + %s_link = NULL; + printf("Detached struct_ops instance: %s\n"); + return 0; +}|} + instance_name + instance_name instance_name instance_name + instance_name instance_name instance_name + instance_name + instance_name + instance_name + instance_name + instance_name + instance_name ) (Ir.get_struct_ops_instances ir_multi_program) in String.concat "\n" attach_functions +let generate_struct_ops_runtime_helpers base_name ir_multi_program = + let struct_ops_instances = Ir.get_struct_ops_instances ir_multi_program in + if struct_ops_instances = [] then + "" + else + let link_declarations = + struct_ops_instances + |> List.map (fun struct_ops_inst -> + sprintf "static struct bpf_link *%s_link = NULL;" struct_ops_inst.ir_instance_name) + |> String.concat "\n" + in + let cleanup_lines = + struct_ops_instances + |> List.map (fun struct_ops_inst -> + let instance_name = struct_ops_inst.ir_instance_name in + sprintf {| if (%s_link) { + bpf_link__destroy(%s_link); + %s_link = NULL; + }|} instance_name instance_name instance_name) + |> String.concat "\n\n" + in + sprintf {|%s + +static int bump_memlock_rlimit(void) { + struct rlimit rlim = { + .rlim_cur = RLIM_INFINITY, + .rlim_max = RLIM_INFINITY, + }; + + if (setrlimit(RLIMIT_MEMLOCK, &rlim) == 0) { + return 0; + } + + if (errno == EPERM) { + fprintf(stderr, "Warning: failed to raise RLIMIT_MEMLOCK: %%s\n", strerror(errno)); + fprintf(stderr, "Continuing anyway because newer kernels may use memcg accounting instead of memlock.\n"); + return 0; + } + + fprintf(stderr, "Failed to raise RLIMIT_MEMLOCK: %%s\n", strerror(errno)); + return -1; +} + +static int ensure_struct_ops_privileges(void) { + if (geteuid() == 0) { + return 0; + } + + fprintf(stderr, "Warning: struct_ops loading typically requires root privileges or CAP_BPF/CAP_SYS_ADMIN.\n"); + fprintf(stderr, "Continuing anyway; loading may still succeed if this process has the required capabilities.\n"); + fprintf(stderr, "If it fails with a permission error, try: sudo ./%s\n"); + return 0; +} + +static void cleanup_%s(void) { +%s + + if (obj) { + %s_ebpf__destroy(obj); + obj = NULL; + } +} + +static void wait_for_unregister_request(void) { + int ch; + + printf("struct_ops instance is active in the kernel.\n"); + printf("Inspect it from another shell with:\n"); + printf(" sudo bpftool struct_ops show\n"); + printf("Press Enter to unregister it and exit.\n"); + + do { + ch = getchar(); + } while (ch != '\n' && ch != EOF); +}|} + link_declarations + base_name + base_name + cleanup_lines + base_name + (** Generate command line argument parsing for struct parameter *) let generate_getopt_parsing (struct_name : string) (param_name : string) (struct_fields : (string * ir_type) list) = (* Generate option struct array for getopt_long *) @@ -4237,6 +4545,8 @@ static void handle_signal(int sig) { |} else "" in + let struct_ops_runtime_helpers = generate_struct_ops_runtime_helpers base_name ir_multi_prog in + (* Generate struct_ops attach functions *) let struct_ops_attach_functions = generate_struct_ops_attach_functions ir_multi_prog in @@ -4270,7 +4580,7 @@ static void handle_signal(int sig) { %s %s -|} includes string_typedefs unified_declarations string_helpers daemon_globals "" structs_with_pinned skeleton_code all_fd_declarations map_operation_functions ringbuf_handlers ringbuf_dispatch_functions auto_bpf_init_code getopt_parsing_code bpf_helper_functions struct_ops_attach_functions functions +|} includes string_typedefs unified_declarations string_helpers daemon_globals "" structs_with_pinned skeleton_code all_fd_declarations map_operation_functions ringbuf_handlers ringbuf_dispatch_functions auto_bpf_init_code getopt_parsing_code bpf_helper_functions (struct_ops_runtime_helpers ^ (if struct_ops_runtime_helpers <> "" && struct_ops_attach_functions <> "" then "\n\n" else "") ^ struct_ops_attach_functions) functions (** Generate userspace C code from IR multi-program *) let generate_userspace_code_from_ir ?(config_declarations = []) ?(tail_call_analysis = {Tail_call_analyzer.dependencies = []; prog_array_size = 0; index_mapping = Hashtbl.create 16; errors = []}) ?(kfunc_dependencies = {kfunc_definitions = []; private_functions = []; program_dependencies = []; module_name = ""}) ?(resolved_imports = []) (ir_multi_prog : ir_multi_program) ?(output_dir = ".") source_filename = diff --git a/tests/test_struct_ops.ml b/tests/test_struct_ops.ml index c47ca56..bed6154 100644 --- a/tests/test_struct_ops.ml +++ b/tests/test_struct_ops.ml @@ -331,7 +331,49 @@ let test_userspace_struct_ops_codegen () = (* Check that struct_ops setup is included *) check bool "Contains struct_ops setup" true - (try ignore (Str.search_forward (Str.regexp "MyTcpCong") userspace_code 0); true with Not_found -> false) + (try ignore (Str.search_forward (Str.regexp "MyTcpCong") userspace_code 0); true with Not_found -> false); + + check bool "Contains memlock helper for struct_ops" true + (contains_substr userspace_code "static int bump_memlock_rlimit(void)"); + + check bool "Contains privilege helper for struct_ops" true + (contains_substr userspace_code "static int ensure_struct_ops_privileges(void)"); + + check bool "Main calls struct_ops runtime checks" true + (contains_substr userspace_code "if (bump_memlock_rlimit() < 0)" && + contains_substr userspace_code "if (ensure_struct_ops_privileges() < 0)"); + + check bool "Contains struct_ops link global" true + (contains_substr userspace_code "static struct bpf_link *MyTcpCong_link = NULL;"); + + check bool "Contains struct_ops cleanup helper" true + (contains_substr userspace_code "static void cleanup_test(void)"); + + check bool "Contains wait helper for struct_ops" true + (contains_substr userspace_code "static void wait_for_unregister_request(void)"); + + check bool "Contains real attach helper for struct_ops" true + (contains_substr userspace_code "int attach_struct_ops_MyTcpCong(void)" && + contains_substr userspace_code "MyTcpCong_link = bpf_map__attach_struct_ops(map);"); + + check bool "Contains real detach helper for struct_ops" true + (contains_substr userspace_code "int detach_struct_ops_MyTcpCong(void)" && + contains_substr userspace_code "bpf_link__destroy(MyTcpCong_link);"); + + check bool "register() uses attach helper" true + (contains_substr userspace_code "attach_struct_ops_MyTcpCong()"); + + check bool "Struct_ops load failure includes EPERM hint" true + (contains_substr userspace_code "The kernel rejected BPF loading with EPERM. Make sure you run as root and the kernel supports struct_ops."); + + check bool "Main waits for unregister request" true + (contains_substr userspace_code "wait_for_unregister_request();"); + + check bool "Main detaches struct_ops before exit" true + (contains_substr userspace_code "detach_struct_ops_MyTcpCong();"); + + check bool "Main registers struct_ops cleanup" true + (contains_substr userspace_code "atexit(cleanup_test);") (** Test that malformed struct_ops attributes are parsed but should be caught *) let test_malformed_struct_ops_attribute () = @@ -837,7 +879,7 @@ let test_selective_struct_inclusion_in_ebpf () = let test_struct_ops_compilation_completeness () = let program = {| @struct_ops("tcp_congestion_ops") - impl MinimalCongestion { + impl minimal_congestion_control { fn ssthresh(sk: *u8) -> u32 { return 16 } @@ -845,8 +887,7 @@ let test_struct_ops_compilation_completeness () = fn cong_avoid(sk: *u8, ack: u32, acked: u32) -> void { // Implementation } - - name: "minimal_cc", + owner: null, } @@ -855,7 +896,7 @@ let test_struct_ops_compilation_completeness () = } fn main() -> i32 { - var result = register(MinimalCongestion) + var result = register(minimal_congestion_control) return result } |} in @@ -875,16 +916,47 @@ let test_struct_ops_compilation_completeness () = (* Check that the struct_ops instance can be instantiated (key thing that was failing) *) check bool "Contains struct_ops instance instantiation" true - (contains_substr c_code "MinimalCongestion" && contains_substr c_code "struct tcp_congestion_ops"); + (contains_substr c_code "minimal_congestion_control" && contains_substr c_code "struct tcp_congestion_ops"); (* Verify SEC annotations are present *) check bool "Contains .struct_ops section" true (contains_substr c_code "SEC(\".struct_ops\")"); + + (* Verify the compiler synthesizes a safe default name when omitted *) + check bool "Contains synthesized tcp_congestion_ops name" true + (contains_substr c_code ".name = \"minimal_cc\""); (* Verify individual function SEC annotations *) check bool "Contains struct_ops function sections" true (contains_substr c_code "SEC(\"struct_ops/") +(** Test struct_ops internal calls stay as direct calls instead of tail calls *) +let test_struct_ops_internal_calls_are_direct () = + let program = {| + @struct_ops("tcp_congestion_ops") + impl minimal_congestion_control { + fn ssthresh(sk: *u8) -> u32 { + return 16 + } + + fn undo_cwnd(sk: *u8) -> u32 { + return ssthresh(sk) + } + } + |} in + + let ast = Parse.parse_string program in + let ast_with_structs = ast @ Test_utils.StructOps.builtin_ast in + let symbol_table = Symbol_table.build_symbol_table ast_with_structs in + let (typed_ast, _) = Type_checker.type_check_and_annotate_ast ast_with_structs in + let ir = Ir_generator.generate_ir typed_ast symbol_table "test" in + let (c_code, _) = Ebpf_c_codegen.compile_multi_to_c_with_analysis ir in + + check bool "struct_ops direct call emitted" true + (contains_substr c_code "ssthresh(sk)"); + check bool "struct_ops tail call not emitted" false + (contains_substr c_code "bpf_tail_call(ctx, &prog_array") + (** NEW: Test struct inclusion logic with mixed struct types *) let test_mixed_struct_types_inclusion () = let program = {| @@ -1247,6 +1319,7 @@ let tests = [ (* NEW: Regression tests for struct inclusion bugs *) "selective struct inclusion in eBPF", `Quick, test_selective_struct_inclusion_in_ebpf; "struct_ops compilation completeness", `Quick, test_struct_ops_compilation_completeness; + "struct_ops internal direct calls", `Quick, test_struct_ops_internal_calls_are_direct; "mixed struct types inclusion", `Quick, test_mixed_struct_types_inclusion; "malformed struct_ops attribute", `Quick, test_malformed_struct_ops_attribute; "register() with non-struct", `Quick, test_register_with_non_struct;