From c45f781771314a71856c9b348c640ba532f54349 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 21:47:44 -0400 Subject: [PATCH 1/4] Cache line lookups in prism compiler --- prism_compile.c | 162 ++++++++++++++++++++++++++++++++++++------------ prism_compile.h | 7 +++ 2 files changed, 129 insertions(+), 40 deletions(-) diff --git a/prism_compile.c b/prism_compile.c index c7facf6e22548f..ad7aee095efe0f 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -140,33 +140,114 @@ pm_iseq_add_setlocal(rb_iseq_t *iseq, LINK_ANCHOR *const seq, int line, int node #define PM_COMPILE_NOT_POPPED(node) \ pm_compile_node(iseq, (node), ret, false, scope_node) +/** + * Cached line lookup that avoids repeated binary searches. Since the compiler + * walks the AST roughly in source order, consecutive lookups tend to be for + * nearby byte offsets. We cache the last result index in the scope node and + * try a short linear probe from there before falling back to binary search. + */ +static inline pm_line_column_t +pm_line_offset_list_line_column_cached(const pm_line_offset_list_t *list, uint32_t cursor, int32_t start_line, size_t *last_line) +{ + size_t hint = *last_line; + size_t size = list->size; + const uint32_t *offsets = list->offsets; + + RUBY_ASSERT(hint < size); + + /* Check if the cursor is on the same line as the hint. */ + if (offsets[hint] <= cursor) { + if (hint + 1 >= size || offsets[hint + 1] > cursor) { + *last_line = hint; + return ((pm_line_column_t) { + .line = ((int32_t) hint) + start_line, + .column = cursor - offsets[hint] + }); + } + + /* Linear scan forward (up to 8 lines before giving up). 
*/ + size_t limit = hint + 9; + if (limit > size) limit = size; + for (size_t idx = hint + 1; idx < limit; idx++) { + if (offsets[idx] > cursor) { + *last_line = idx - 1; + return ((pm_line_column_t) { + .line = ((int32_t) (idx - 1)) + start_line, + .column = cursor - offsets[idx - 1] + }); + } + if (offsets[idx] == cursor) { + *last_line = idx; + return ((pm_line_column_t) { ((int32_t) idx) + start_line, 0 }); + } + } + } + else { + /* Linear scan backward (up to 8 lines before giving up). */ + size_t limit = hint > 8 ? hint - 8 : 0; + for (size_t idx = hint; idx > limit; idx--) { + if (offsets[idx - 1] <= cursor) { + *last_line = idx - 1; + return ((pm_line_column_t) { + .line = ((int32_t) (idx - 1)) + start_line, + .column = cursor - offsets[idx - 1] + }); + } + } + } + + /* Fall back to binary search. */ + pm_line_column_t result = pm_line_offset_list_line_column(list, cursor, start_line); + *last_line = (size_t) (result.line - start_line); + return result; +} + +/** + * The same as pm_line_offset_list_line_column_cached, but returning only the + * line number. 
+ */ +static inline int32_t +pm_line_offset_list_line_cached(const pm_line_offset_list_t *list, uint32_t cursor, int32_t start_line, size_t *last_line) +{ + return pm_line_offset_list_line_column_cached(list, cursor, start_line, last_line).line; +} + #define PM_NODE_START_LOCATION(parser, node) \ - ((pm_node_location_t) { .line = pm_line_offset_list_line(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line), .node_id = ((const pm_node_t *) (node))->node_id }) + ((pm_node_location_t) { .line = pm_line_offset_list_line_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line, &scope_node->last_line), .node_id = ((const pm_node_t *) (node))->node_id }) #define PM_NODE_END_LOCATION(parser, node) \ - ((pm_node_location_t) { .line = pm_line_offset_list_line(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line), .node_id = ((const pm_node_t *) (node))->node_id }) + ((pm_node_location_t) { .line = pm_line_offset_list_line_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line, &scope_node->last_line), .node_id = ((const pm_node_t *) (node))->node_id }) #define PM_LOCATION_START_LOCATION(parser, location, id) \ - ((pm_node_location_t) { .line = pm_line_offset_list_line(&(parser)->line_offsets, (location)->start, (parser)->start_line), .node_id = id }) + ((pm_node_location_t) { .line = pm_line_offset_list_line_cached(&(parser)->line_offsets, (location)->start, (parser)->start_line, &scope_node->last_line), .node_id = id }) #define PM_NODE_START_LINE_COLUMN(parser, node) \ - pm_line_offset_list_line_column(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line) + pm_line_offset_list_line_column_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, 
(parser)->start_line, &scope_node->last_line) #define PM_NODE_END_LINE_COLUMN(parser, node) \ - pm_line_offset_list_line_column(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line) + pm_line_offset_list_line_column_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line, &scope_node->last_line) #define PM_LOCATION_START_LINE_COLUMN(parser, location) \ - pm_line_offset_list_line_column(&(parser)->line_offsets, (location)->start, (parser)->start_line) + pm_line_offset_list_line_column_cached(&(parser)->line_offsets, (location)->start, (parser)->start_line, &scope_node->last_line) static int -pm_node_line_number(const pm_parser_t *parser, const pm_node_t *node) +pm_location_line_number(const pm_parser_t *parser, const pm_location_t *location) { + return (int) pm_line_offset_list_line(&parser->line_offsets, location->start, parser->start_line); +} + +/** + * Cached variants that use the scope node's hint for fast lookups during + * compilation (where access patterns are roughly sequential). 
+ */ +static inline int +pm_node_line_number_cached(const pm_parser_t *parser, const pm_node_t *node, pm_scope_node_t *scope_node) { - return (int) pm_line_offset_list_line(&parser->line_offsets, node->location.start, parser->start_line); + return (int) pm_line_offset_list_line_cached(&parser->line_offsets, node->location.start, parser->start_line, &scope_node->last_line); } -static int -pm_location_line_number(const pm_parser_t *parser, const pm_location_t *location) { - return (int) pm_line_offset_list_line(&parser->line_offsets, location->start, parser->start_line); +static inline int +pm_location_line_number_cached(const pm_parser_t *parser, const pm_location_t *location, pm_scope_node_t *scope_node) { + return (int) pm_line_offset_list_line_cached(&parser->line_offsets, location->start, parser->start_line, &scope_node->last_line); } /** @@ -305,7 +386,7 @@ parse_string_encoded(const pm_node_t *node, const pm_string_t *string, rb_encodi } static inline VALUE -parse_static_literal_string(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *string) +parse_static_literal_string(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *string) { rb_encoding *encoding; @@ -323,7 +404,7 @@ parse_static_literal_string(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, rb_enc_str_coderange(value); if (ISEQ_COMPILE_DATA(iseq)->option->debug_frozen_string_literal || RTEST(ruby_debug)) { - int line_number = pm_node_line_number(scope_node->parser, node); + int line_number = pm_node_line_number_cached(scope_node->parser, node, scope_node); value = rb_ractor_make_shareable(rb_str_with_debug_created_info(value, rb_iseq_path(iseq), line_number)); } @@ -368,7 +449,7 @@ parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...) 
} static VALUE -parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding) +parse_regexp_string_part(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding) { // If we were passed an explicit regexp encoding, then we need to double // check that it's okay here for this fragment of the string. @@ -390,12 +471,12 @@ parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, con VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), encoding); VALUE error = rb_reg_check_preprocess(string); - if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, node), "%" PRIsVALUE, rb_obj_as_string(error)); + if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number_cached(scope_node->parser, node, scope_node), "%" PRIsVALUE, rb_obj_as_string(error)); return string; } static VALUE -pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding, bool top) +pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, pm_scope_node_t *scope_node, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding, bool top) { VALUE current = Qnil; @@ -412,7 +493,7 @@ pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_ else { string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding); VALUE error = rb_reg_check_preprocess(string); - if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, part), "%" PRIsVALUE, rb_obj_as_string(error)); + if (error != Qnil) parse_regexp_error(iseq, 
pm_node_line_number_cached(scope_node->parser, part, scope_node), "%" PRIsVALUE, rb_obj_as_string(error)); } } else { @@ -525,11 +606,11 @@ parse_regexp_encoding(const pm_scope_node_t *scope_node, const pm_node_t *node) } static VALUE -parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, VALUE string) +parse_regexp(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, VALUE string) { VALUE errinfo = rb_errinfo(); - int32_t line_number = pm_node_line_number(scope_node->parser, node); + int32_t line_number = pm_node_line_number_cached(scope_node->parser, node, scope_node); VALUE regexp = rb_reg_compile(string, parse_regexp_flags(node), (const char *) pm_string_source(&scope_node->parser->filepath), line_number); if (NIL_P(regexp)) { @@ -544,7 +625,7 @@ parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t } static inline VALUE -parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped) +parse_regexp_literal(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped) { rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node); if (regexp_encoding == NULL) regexp_encoding = scope_node->encoding; @@ -555,7 +636,7 @@ parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const p } static inline VALUE -parse_regexp_concat(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts) +parse_regexp_concat(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts) { rb_encoding *explicit_regexp_encoding = parse_regexp_encoding(scope_node, node); rb_encoding *implicit_regexp_encoding = explicit_regexp_encoding != NULL ? 
explicit_regexp_encoding : scope_node->encoding; @@ -748,7 +829,7 @@ pm_static_literal_string(rb_iseq_t *iseq, VALUE string, int line_number) * literal values can be compiled into a literal array. */ static VALUE -pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_node_t *scope_node) +pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, pm_scope_node_t *scope_node) { // Every node that comes into this function should already be marked as // static literal. If it's not, then we have a bug somewhere. @@ -804,7 +885,7 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n } case PM_INTERPOLATED_STRING_NODE: { VALUE string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) node)->parts, scope_node, NULL, NULL, false); - int line_number = pm_node_line_number(scope_node->parser, node); + int line_number = pm_node_line_number_cached(scope_node->parser, node, scope_node); return pm_static_literal_string(iseq, string, line_number); } case PM_INTERPOLATED_SYMBOL_NODE: { @@ -832,7 +913,7 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n return pm_source_file_value(cast, scope_node); } case PM_SOURCE_LINE_NODE: - return INT2FIX(pm_node_line_number(scope_node->parser, node)); + return INT2FIX(pm_node_line_number_cached(scope_node->parser, node, scope_node)); case PM_STRING_NODE: { const pm_string_node_t *cast = (const pm_string_node_t *) node; return parse_static_literal_string(iseq, scope_node, node, &cast->unescaped); @@ -851,7 +932,7 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n * A helper for converting a pm_location_t into a rb_code_location_t. 
*/ static rb_code_location_t -pm_code_location(const pm_scope_node_t *scope_node, const pm_node_t *node) +pm_code_location(pm_scope_node_t *scope_node, const pm_node_t *node) { const pm_line_column_t start_location = PM_NODE_START_LINE_COLUMN(scope_node->parser, node); const pm_line_column_t end_location = PM_NODE_END_LINE_COLUMN(scope_node->parser, node); @@ -2404,7 +2485,7 @@ pm_compile_pattern_eqq_error(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const static int pm_compile_pattern_match(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *unmatched_label, bool in_single_pattern, bool use_deconstructed_cache, unsigned int base_index) { - LABEL *matched_label = NEW_LABEL(pm_node_line_number(scope_node->parser, node)); + LABEL *matched_label = NEW_LABEL(pm_node_line_number_cached(scope_node->parser, node, scope_node)); CHECK(pm_compile_pattern(iseq, scope_node, node, ret, matched_label, unmatched_label, in_single_pattern, use_deconstructed_cache, base_index)); PUSH_LABEL(ret, matched_label); return COMPILE_OK; @@ -2493,7 +2574,7 @@ pm_compile_pattern_constant(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const * responsible for compiling in those error raising instructions. 
*/ static void -pm_compile_pattern_error_handler(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *done_label, bool popped) +pm_compile_pattern_error_handler(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *done_label, bool popped) { const pm_node_location_t location = PM_NODE_START_LOCATION(scope_node->parser, node); LABEL *key_error_label = NEW_LABEL(location.line); @@ -3243,6 +3324,7 @@ pm_scope_node_init(const pm_node_t *node, pm_scope_node_t *scope, pm_scope_node_ scope->constants = previous->constants; scope->coverage_enabled = previous->coverage_enabled; scope->script_lines = previous->script_lines; + scope->last_line = previous->last_line; } switch (PM_NODE_TYPE(node)) { @@ -3411,7 +3493,7 @@ pm_iseq_builtin_function_name(const pm_scope_node_t *scope_node, const pm_node_t // Compile Primitive.attr! :leaf, ... static int -pm_compile_builtin_attr(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_arguments_node_t *arguments, const pm_node_location_t *node_location) +pm_compile_builtin_attr(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_arguments_node_t *arguments, const pm_node_location_t *node_location) { if (arguments == NULL) { COMPILE_ERROR(iseq, node_location->line, "attr!: no argument"); @@ -3699,7 +3781,7 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c } const pm_line_column_t start_location = PM_NODE_START_LINE_COLUMN(scope_node->parser, call_node); - const pm_line_column_t end_location = pm_line_offset_list_line_column(&scope_node->parser->line_offsets, end_cursor, scope_node->parser->start_line); + const pm_line_column_t end_location = pm_line_offset_list_line_column_cached(&scope_node->parser->line_offsets, end_cursor, scope_node->parser->start_line, &scope_node->last_line); code_location = (rb_code_location_t) { .beg_pos = { .lineno = start_location.line, .column = start_location.column }, 
@@ -3729,7 +3811,7 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c pm_scope_node_t next_scope_node; pm_scope_node_init(call_node->block, &next_scope_node, scope_node); - block_iseq = NEW_CHILD_ISEQ(&next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, pm_node_line_number(scope_node->parser, call_node->block)); + block_iseq = NEW_CHILD_ISEQ(&next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, pm_node_line_number_cached(scope_node->parser, call_node->block, scope_node)); pm_scope_node_destroy(&next_scope_node); ISEQ_COMPILE_DATA(iseq)->current_block = block_iseq; } @@ -4783,7 +4865,7 @@ pm_compile_destructured_param_locals(const pm_multi_target_node_t *node, st_tabl * as a positional parameter in a method, block, or lambda definition. */ static inline void -pm_compile_destructured_param_write(rb_iseq_t *iseq, const pm_required_parameter_node_t *node, LINK_ANCHOR *const ret, const pm_scope_node_t *scope_node) +pm_compile_destructured_param_write(rb_iseq_t *iseq, const pm_required_parameter_node_t *node, LINK_ANCHOR *const ret, pm_scope_node_t *scope_node) { const pm_node_location_t location = PM_NODE_START_LOCATION(scope_node->parser, node); pm_local_index_t index = pm_lookup_local_index(iseq, scope_node, node->name, 0); @@ -4799,7 +4881,7 @@ pm_compile_destructured_param_write(rb_iseq_t *iseq, const pm_required_parameter * for this simplified case. 
*/ static void -pm_compile_destructured_param_writes(rb_iseq_t *iseq, const pm_multi_target_node_t *node, LINK_ANCHOR *const ret, const pm_scope_node_t *scope_node) +pm_compile_destructured_param_writes(rb_iseq_t *iseq, const pm_multi_target_node_t *node, LINK_ANCHOR *const ret, pm_scope_node_t *scope_node) { const pm_node_location_t location = PM_NODE_START_LOCATION(scope_node->parser, node); bool has_rest = (node->rest && PM_NODE_TYPE_P(node->rest, PM_SPLAT_NODE) && (((const pm_splat_node_t *) node->rest)->expression) != NULL); @@ -5435,7 +5517,7 @@ pm_compile_rescue(rb_iseq_t *iseq, const pm_begin_node_t *cast, const pm_node_lo &rescue_scope_node, rb_str_concat(rb_str_new2("rescue in "), ISEQ_BODY(iseq)->location.label), ISEQ_TYPE_RESCUE, - pm_node_line_number(parser, (const pm_node_t *) cast->rescue_clause) + pm_node_line_number_cached(parser, (const pm_node_t *) cast->rescue_clause, scope_node) ); pm_scope_node_destroy(&rescue_scope_node); @@ -5567,7 +5649,7 @@ pm_opt_str_freeze_p(const rb_iseq_t *iseq, const pm_call_node_t *node) * of the current iseq. */ static void -pm_compile_constant_read(rb_iseq_t *iseq, VALUE name, const pm_location_t *name_loc, uint32_t node_id, LINK_ANCHOR *const ret, const pm_scope_node_t *scope_node) +pm_compile_constant_read(rb_iseq_t *iseq, VALUE name, const pm_location_t *name_loc, uint32_t node_id, LINK_ANCHOR *const ret, pm_scope_node_t *scope_node) { const pm_node_location_t location = PM_LOCATION_START_LOCATION(scope_node->parser, name_loc, node_id); @@ -5667,7 +5749,7 @@ pm_compile_constant_path(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *co * Return the object that will be pushed onto the stack for the given node. 
*/ static VALUE -pm_compile_shareable_constant_literal(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_node_t *scope_node) +pm_compile_shareable_constant_literal(rb_iseq_t *iseq, const pm_node_t *node, pm_scope_node_t *scope_node) { switch (PM_NODE_TYPE(node)) { case PM_TRUE_NODE: @@ -7507,7 +7589,7 @@ pm_compile_call_operator_write_node(rb_iseq_t *iseq, const pm_call_operator_writ * optimization entirely. */ static VALUE -pm_compile_case_node_dispatch(rb_iseq_t *iseq, VALUE dispatch, const pm_node_t *node, LABEL *label, const pm_scope_node_t *scope_node) +pm_compile_case_node_dispatch(rb_iseq_t *iseq, VALUE dispatch, const pm_node_t *node, LABEL *label, pm_scope_node_t *scope_node) { VALUE key = Qundef; switch (PM_NODE_TYPE(node)) { @@ -7590,7 +7672,7 @@ pm_compile_case_node(rb_iseq_t *iseq, const pm_case_node_t *cast, const pm_node_ const pm_when_node_t *clause = (const pm_when_node_t *) conditions->nodes[clause_index]; const pm_node_list_t *conditions = &clause->conditions; - int clause_lineno = pm_node_line_number(parser, (const pm_node_t *) clause); + int clause_lineno = pm_node_line_number_cached(parser, (const pm_node_t *) clause, scope_node); LABEL *label = NEW_LABEL(clause_lineno); PUSH_LABEL(body_seq, label); @@ -7623,7 +7705,7 @@ pm_compile_case_node(rb_iseq_t *iseq, const pm_case_node_t *cast, const pm_node_ PUSH_INSNL(cond_seq, cond_location, branchif, label); } else { - LABEL *next_label = NEW_LABEL(pm_node_line_number(parser, condition)); + LABEL *next_label = NEW_LABEL(pm_node_line_number_cached(parser, condition, scope_node)); pm_compile_branch_condition(iseq, cond_seq, condition, label, next_label, false, scope_node); PUSH_LABEL(cond_seq, next_label); } @@ -9705,7 +9787,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, pm_scope_node_t next_scope_node; pm_scope_node_init(node, &next_scope_node, scope_node); - int opening_lineno = pm_location_line_number(parser, &cast->opening_loc); + int opening_lineno = 
pm_location_line_number_cached(parser, &cast->opening_loc, scope_node); const rb_iseq_t *block = NEW_CHILD_ISEQ(&next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, opening_lineno); pm_scope_node_destroy(&next_scope_node); @@ -10183,7 +10265,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, &rescue_scope_node, rb_str_concat(rb_str_new2("rescue in "), ISEQ_BODY(iseq)->location.label), ISEQ_TYPE_RESCUE, - pm_node_line_number(parser, cast->rescue_expression) + pm_node_line_number_cached(parser, cast->rescue_expression, scope_node) ); pm_scope_node_destroy(&rescue_scope_node); diff --git a/prism_compile.h b/prism_compile.h index ae1ec6bfd4f607..0b5fe685273a1e 100644 --- a/prism_compile.h +++ b/prism_compile.h @@ -61,6 +61,13 @@ typedef struct pm_scope_node { * the instructions pertaining to BEGIN{} nodes. */ struct iseq_link_anchor *pre_execution_anchor; + + /** + * Cached line hint for line offset list lookups. Since the compiler walks + * the AST roughly in source order, consecutive lookups tend to be for + * nearby byte offsets. This avoids repeated binary searches. 
+ */ + size_t last_line; } pm_scope_node_t; void pm_scope_node_init(const pm_node_t *node, pm_scope_node_t *scope, pm_scope_node_t *previous); From 3f0e61fe12bb8717ccf3318269b31df66fc6e016 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 22:04:48 -0400 Subject: [PATCH 2/4] Speed up integer parsing in prism_compile.c --- prism_compile.c | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/prism_compile.c b/prism_compile.c index ad7aee095efe0f..6bc1da58d0bab0 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -262,24 +262,21 @@ parse_integer_value(const pm_integer_t *integer) result = UINT2NUM(integer->value); } else { - VALUE string = rb_str_new(NULL, integer->length * 8); - unsigned char *bytes = (unsigned char *) RSTRING_PTR(string); - - size_t offset = integer->length * 8; - for (size_t value_index = 0; value_index < integer->length; value_index++) { - uint32_t value = integer->values[value_index]; - - for (int index = 0; index < 8; index++) { - int byte = (value >> (4 * index)) & 0xf; - bytes[--offset] = byte < 10 ? byte + '0' : byte - 10 + 'a'; - } - } - - result = rb_funcall(string, rb_intern("to_i"), 1, UINT2NUM(16)); + // The pm_integer_t stores values as an array of uint32_t in + // least-significant-word-first order (base 2^32). We can convert + // directly to a Ruby Integer using rb_integer_unpack, avoiding the + // overhead of constructing a hex string and calling rb_funcall. 
+ result = rb_integer_unpack( + integer->values, + integer->length, + sizeof(uint32_t), + 0, + INTEGER_PACK_LSWORD_FIRST | INTEGER_PACK_NATIVE_BYTE_ORDER + ); } if (integer->negative) { - result = rb_funcall(result, rb_intern("-@"), 0); + result = rb_int_uminus(result); } if (!SPECIAL_CONST_P(result)) { From 1a0e67d3e3c32627b66a3da1a675da828b6aa45c Mon Sep 17 00:00:00 2001 From: Alan Wu Date: Mon, 9 Mar 2026 18:57:02 -0400 Subject: [PATCH 3/4] ZJIT: `::RubyVM::ZJIT.induce_side_exit!` and `induce_compile_failure!` Tells ZJIT to do a side exit or to fail to compile, useful for testing and for bug reports. We are picky about the syntactic form so we can tell where the call lands early in the compiler pipeline. The `::` prefix allows us to interpret it without needing to know under what lexical scope the iseq will run. Special semantics with ZJIT don't interfere with Ruby-level semantics; running these methods does nothing in all modes. So it's no problem to call these without ZJIT. --- zjit.c | 1 + zjit.rb | 14 ++++ zjit/bindgen/src/main.rs | 2 + zjit/src/cruby.rs | 6 +- zjit/src/cruby_bindings.inc.rs | 16 +++++ zjit/src/hir.rs | 57 ++++++++++++++- zjit/src/hir/tests.rs | 127 ++++++++++++++++++++++++++++++++- zjit/src/state.rs | 41 +++++++++++ zjit/src/stats.rs | 4 ++ 9 files changed, 263 insertions(+), 5 deletions(-) diff --git a/zjit.c b/zjit.c index 0c463334cde42a..68336858e958b1 100644 --- a/zjit.c +++ b/zjit.c @@ -19,6 +19,7 @@ #include "vm_insnhelper.h" #include "probes.h" #include "probes_helper.h" +#include "constant.h" #include "iseq.h" #include "ruby/debug.h" #include "internal/cont.h" diff --git a/zjit.rb b/zjit.rb index 82630612749f44..bdaf557426955c 100644 --- a/zjit.rb +++ b/zjit.rb @@ -44,6 +44,20 @@ def trace_exit_locations_enabled? Primitive.rb_zjit_trace_exit_locations_enabled_p end + # A directive for the compiler to fail to compile the call to this method. + # To show this to ZJIT, say `::RubyVM::ZJIT.induce_compile_failure!` verbatim. 
+ # Other forms are too dynamic to detect during compilation. + # + # Actually running this method does nothing, whether ZJIT sees the call or not. + def induce_compile_failure! = nil + + # A directive for the compiler to exit out of compiled code at the call site of this method. + # To show this to ZJIT, say `::RubyVM::ZJIT.induce_side_exit!` verbatim. + # Other forms are too dynamic to detect during compilation. + # + # Actually running this method does nothing, whether ZJIT sees the call or not. + def induce_side_exit! = nil + # If --zjit-trace-exits is enabled parse the hashes from # Primitive.rb_zjit_get_exit_locations into a format readable # by Stackprof. This will allow us to find the exact location of a diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs index 18436270d74ba2..7144e73737ba03 100644 --- a/zjit/bindgen/src/main.rs +++ b/zjit/bindgen/src/main.rs @@ -145,6 +145,7 @@ fn main() { .allowlist_function("rb_gc_location") .allowlist_function("rb_gc_writebarrier") .allowlist_function("rb_gc_writebarrier_remember") + .allowlist_function("rb_gc_register_mark_object") .allowlist_function("rb_zjit_writebarrier_check_immediate") // VALUE variables for Ruby class objects @@ -432,6 +433,7 @@ fn main() { .allowlist_function("rb_vm_base_ptr") .allowlist_function("rb_ec_stack_check") .allowlist_function("rb_vm_top_self") + .allowlist_function("rb_const_lookup") // We define these manually, don't import them .blocklist_type("VALUE") diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 38b99edb4bd3f6..ed59dd7d9f23c2 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -1544,6 +1544,10 @@ pub(crate) mod ids { name: RUBY_FL_FREEZE name: RUBY_ELTS_SHARED name: VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM + name: RubyVM + name: ZJIT + name: induce_side_exit_bang content: b"induce_side_exit!" + name: induce_compile_failure_bang content: b"induce_compile_failure!" } /// Get an CRuby `ID` to an interned string, e.g. a particular method name. 
@@ -1551,7 +1555,7 @@ pub(crate) mod ids { ($id_name:ident) => {{ let id = $crate::cruby::ids::$id_name.load(std::sync::atomic::Ordering::Relaxed); debug_assert_ne!(0, id, "ids module should be initialized"); - ID(id) + $crate::cruby::ID(id) }} } pub(crate) use ID; diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index baf39100f48d71..76fdcd93f7a1be 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -1471,6 +1471,20 @@ pub const SHAPE_ID_FL_TOO_COMPLEX: shape_id_fl_type = 67108864; pub const SHAPE_ID_FL_NON_CANONICAL_MASK: shape_id_fl_type = 50331648; pub const SHAPE_ID_FLAGS_MASK: shape_id_fl_type = 133693440; pub type shape_id_fl_type = u32; +pub const CONST_DEPRECATED: rb_const_flag_t = 256; +pub const CONST_VISIBILITY_MASK: rb_const_flag_t = 255; +pub const CONST_PUBLIC: rb_const_flag_t = 0; +pub const CONST_PRIVATE: rb_const_flag_t = 1; +pub const CONST_VISIBILITY_MAX: rb_const_flag_t = 2; +pub type rb_const_flag_t = u32; +#[repr(C)] +pub struct rb_const_entry_struct { + pub flag: rb_const_flag_t, + pub line: ::std::os::raw::c_int, + pub value: VALUE, + pub file: VALUE, +} +pub type rb_const_entry_t = rb_const_entry_struct; #[repr(C)] pub struct rb_cvar_class_tbl_entry { pub index: u32, @@ -1912,6 +1926,7 @@ unsafe extern "C" { pub fn rb_gc_location(obj: VALUE) -> VALUE; pub fn rb_gc_enable() -> VALUE; pub fn rb_gc_disable() -> VALUE; + pub fn rb_gc_register_mark_object(object: VALUE); pub fn rb_gc_writebarrier(old: VALUE, young: VALUE); pub fn rb_class_get_superclass(klass: VALUE) -> VALUE; pub fn rb_funcallv( @@ -2057,6 +2072,7 @@ unsafe extern "C" { original_shape_id: shape_id_t, id: ID, ) -> shape_id_t; + pub fn rb_const_lookup(klass: VALUE, id: ID) -> *mut rb_const_entry_t; pub fn rb_ivar_get_at_no_ractor_check(obj: VALUE, index: attr_index_t) -> VALUE; pub fn rb_gvar_get(arg1: ID) -> VALUE; pub fn rb_gvar_set(arg1: ID, arg2: VALUE) -> VALUE; diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 
090bf579661eec..46363f9e9d7064 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -7,10 +7,12 @@ #![allow(clippy::match_like_matches_macro)] use crate::{ backend::lir::C_ARG_OPNDS, - cast::IntoUsize, codegen::local_idx_to_ep_offset, cruby::*, invariants::{self, has_singleton_class_of}, payload::{get_or_create_iseq_payload, IseqPayload}, options::{debug, get_option, DumpHIR}, state::ZJITState, json::Json + cast::IntoUsize, codegen::local_idx_to_ep_offset, cruby::*, invariants::{self, has_singleton_class_of}, payload::{get_or_create_iseq_payload, IseqPayload}, options::{debug, get_option, DumpHIR}, state::ZJITState, json::Json, + state, }; use std::{ - cell::RefCell, collections::{BTreeSet, HashMap, HashSet, VecDeque}, ffi::{c_void, c_uint, c_int, CStr}, fmt::Display, mem::{align_of, size_of}, ptr, slice::Iter + cell::RefCell, collections::{BTreeSet, HashMap, HashSet, VecDeque}, ffi::{c_void, c_uint, c_int, CStr}, fmt::Display, mem::{align_of, size_of}, ptr, slice::Iter, + sync::atomic::Ordering, }; use crate::hir_type::{Type, types}; use crate::hir_effect::{Effect, abstract_heaps, effects}; @@ -529,6 +531,7 @@ pub enum SideExitReason { SplatKwNotNilOrHash, SplatKwPolymorphic, SplatKwNotProfiled, + DirectiveInduced, } #[derive(Debug, Clone, Copy)] @@ -6590,6 +6593,7 @@ pub enum ParseError { MalformedIseq(u32), // insn_idx into iseq_encoded Validation(ValidationError), NotAllowed, + DirectiveInduced, } /// Return the number of locals in the current ISEQ (includes parameters) @@ -7120,7 +7124,32 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } YARVINSN_opt_getconstant_path => { let ic = get_arg(pc, 0).as_ptr(); - state.stack_push(fun.push_insn(block, Insn::GetConstantPath { ic, state: exit_id })); + let get_const_path = fun.push_insn(block, Insn::GetConstantPath { ic, state: exit_id }); + state.stack_push(get_const_path); + + // Check for `::RubyVM::ZJIT` for directives + unsafe { + let mut current_segment = (*ic).segments; + let mut segments = [ID(0); 4 
/* expected segment length */]; + for segment in segments.iter_mut() { + *segment = current_segment.read(); + if *segment == ID(0) { + break; + } + current_segment = current_segment.add(1); + } + if [ID!(NULL), ID!(RubyVM), ID!(ZJIT), ID(0)] == segments { + debug_assert_ne!(ID!(NULL), ID(0)); + let ruby_vm_mod = rb_const_lookup(rb_cObject, ID!(RubyVM)); + if !ruby_vm_mod.is_null() && (*ruby_vm_mod).value == rb_cRubyVM { + let zjit_module = VALUE(state::ZJIT_MODULE.load(Ordering::Relaxed)); + let lookedup_module = rb_const_lookup(rb_cRubyVM, ID!(ZJIT)); + if !lookedup_module.is_null() && (*lookedup_module).value == zjit_module { + fun.insn_types[get_const_path.0] = Type::from_value(zjit_module); + } + } + } + } } YARVINSN_branchunless | YARVINSN_branchunless_without_ints => { if opcode == YARVINSN_branchunless { @@ -7581,6 +7610,28 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; + let mid = unsafe { rb_vm_ci_mid(call_info) }; + + // Check for calls to directives + if argc == 0 + && (mid == ID!(induce_side_exit_bang) || mid == ID!(induce_compile_failure_bang)) + && fun.type_of(state.stack_top()?) 
+ .ruby_object() + .is_some_and(|obj| obj == VALUE(state::ZJIT_MODULE.load(Ordering::Relaxed))) + { + + if mid == ID!(induce_side_exit_bang) + && state::zjit_module_method_match_serial(ID!(induce_side_exit_bang), &state::INDUCE_SIDE_EXIT_SERIAL) + { + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::DirectiveInduced }); + break; // End the block + } + if mid == ID!(induce_compile_failure_bang) + && state::zjit_module_method_match_serial(ID!(induce_compile_failure_bang), &state::INDUCE_COMPILE_FAILURE_SERIAL) + { + return Err(ParseError::DirectiveInduced); + } + } { fn new_branch_block( diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index b26f5a989c1374..c6538a9eda1623 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -4791,7 +4791,132 @@ pub mod hir_build_tests { Jump bb8(v67, v78) "); } - } + + #[test] + fn test_induce_side_exit() { + eval(" + class NonTopLexicalScope + RubyVM = 0 + def test + RubyVM::ZJIT.induce_side_exit! # lexical scope dependant -- should not recognize + ::RubyVM::ZJIT.induce_side_exit! + end + end + "); + assert_snapshot!(hir_string_proc("NonTopLexicalScope.instance_method(:test)"), @" + fn test@:5: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_side_exit! # SendFallbackReason: Uncategorized(opt_send_without_block) + v16:BasicObject = GetConstantPath 0x1000 + SideExit DirectiveInduced + "); + } + + #[test] + fn test_induce_side_exit_sensitive_to_constant_state() { + eval(" + def test = ::RubyVM::ZJIT.induce_side_exit! 
+ "); + assert!(hir_string("test").contains("SideExit DirectiveInduced")); + eval(" + class RubyVM + remove_const(:ZJIT) + end + "); + let hir_after_removal = hir_string("test"); + assert_eq!(false, hir_string("test").contains("SideExit DirectiveInduced"), "should not work when the constant lookup would fail"); + assert_snapshot!(hir_after_removal, @" + fn test@:2: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_side_exit! # SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_induce_side_exit_doesnt_work_when_method_after_undef() { + eval(" + class << RubyVM::ZJIT + undef :induce_side_exit! + end + def test = ::RubyVM::ZJIT.induce_side_exit! + "); + assert_eq!(false, hir_string("test").contains("SideExit DirectiveInduced"), "should not work after undef"); + } + + #[test] + fn test_induce_compile_failure_does_not_trigger_autoload() { + eval(" + class RubyVM + remove_const(:ZJIT) + autoload :ZJIT, 'a-file-that-does-not-exist-as-a-trap' + end + def test = ::RubyVM::ZJIT.induce_compile_failure! + "); + assert_snapshot!(hir_string("test"), @" + fn test@:6: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_compile_failure! 
# SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_induce_compile_failure_checks_full_const_path() { + eval("def test = ::RubyVM::ZJIT::TooDeep.induce_compile_failure!"); + assert_snapshot!(hir_string("test"), @" + fn test@:1: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_compile_failure! # SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_induce_compile_failure() { + eval("def test = ::RubyVM::ZJIT.induce_compile_failure!"); + assert_compile_fails("test", ParseError::DirectiveInduced); + } +} /// Test successor and predecessor set computations. #[cfg(test)] diff --git a/zjit/src/state.rs b/zjit/src/state.rs index df8239699ff5e2..7fa1eb862dc163 100644 --- a/zjit/src/state.rs +++ b/zjit/src/state.rs @@ -3,11 +3,14 @@ use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline}; use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, src_loc, EcPtr, Qnil, Qtrue, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark, with_vm_lock, rust_str_to_id, rb_funcallv, rb_const_get, rb_cRubyVM}; use crate::cruby_methods; +use cruby::{ID, rb_callable_method_entry, get_def_method_serial, rb_gc_register_mark_object}; +use std::sync::atomic::Ordering; use crate::invariants::Invariants; use crate::asm::CodeBlock; use crate::options::{get_option, rb_zjit_prepare_options}; use crate::stats::{Counters, InsnCounters, SideExitLocations}; use crate::virtualmem::CodePtr; +use std::sync::atomic::AtomicUsize; use std::collections::HashMap; use std::ptr::null; @@ -301,6 +304,25 @@ impl ZJITState { } } +/// The `::RubyVM::ZJIT` 
module. +pub static ZJIT_MODULE: AtomicUsize = AtomicUsize::new(!0); +/// Serial of the canonical version of `RubyVM::ZJIT.induce_side_exit!` right after VM boot. +pub static INDUCE_SIDE_EXIT_SERIAL: AtomicUsize = AtomicUsize::new(!0); +/// Serial of the canonical version of `RubyVM::ZJIT.induce_compile_failure!` right after VM boot. +pub static INDUCE_COMPILE_FAILURE_SERIAL: AtomicUsize = AtomicUsize::new(!0); + +/// Check if a method, `method_id`, currently exists on `ZJIT.singleton_class` and has the `expected_serial`. +pub fn zjit_module_method_match_serial(method_id: ID, expected_serial: &AtomicUsize) -> bool { + let zjit_module_singleton = VALUE(ZJIT_MODULE.load(Ordering::Relaxed)).class_of(); + let cme = unsafe { rb_callable_method_entry(zjit_module_singleton, method_id) }; + if cme.is_null() { + false + } else { + let serial = unsafe { get_def_method_serial((*cme).def) }; + serial == expected_serial.load(std::sync::atomic::Ordering::Relaxed) + } +} + /// Initialize IDs and annotate builtin C method entries. /// Must be called at boot before ruby_init_prelude() since the prelude /// could redefine core methods (e.g. Kernel.prepend via bundler). @@ -318,6 +340,25 @@ pub extern "C" fn rb_zjit_init_builtin_cmes() { let method_annotations = cruby_methods::init(); unsafe { ZJIT_STATE = Initialized(method_annotations); } + + // Boot time setup for compiler directives + unsafe { + let zjit_module = rb_const_get(rb_cRubyVM, rust_str_to_id("ZJIT")); + + let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_side_exit_bang)); + assert!(! cme.is_null(), "RubyVM::ZJIT.induce_side_exit! should exist on boot"); + let serial = get_def_method_serial((*cme).def) ; + INDUCE_SIDE_EXIT_SERIAL.store(serial, Ordering::Relaxed); + + let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_compile_failure_bang)); + assert!(! cme.is_null(), "RubyVM::ZJIT.induce_compile_failure! 
should exist on boot"); + let serial = get_def_method_serial((*cme).def) ; + INDUCE_COMPILE_FAILURE_SERIAL.store(serial, Ordering::Relaxed); + + // Root and pin the module since we'll be doing object identity comparisons. + ZJIT_MODULE.store(zjit_module.0, Ordering::Relaxed); + rb_gc_register_mark_object(zjit_module); + } } /// Initialize ZJIT at boot. This is called even if ZJIT is disabled. diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index db204f032f9340..47a434360da5ff 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -231,6 +231,7 @@ make_counters! { exit_splatkw_not_nil_or_hash, exit_splatkw_polymorphic, exit_splatkw_not_profiled, + exit_directive_induced, } // Send fallback counters that are summed as dynamic_send_count @@ -335,6 +336,7 @@ make_counters! { compile_error_parse_stack_underflow, compile_error_parse_malformed_iseq, compile_error_parse_not_allowed, + compile_error_parse_directive_induced, compile_error_validation_block_has_no_terminator, compile_error_validation_terminator_not_at_end, compile_error_validation_mismatched_block_arity, @@ -521,6 +523,7 @@ pub fn exit_counter_for_compile_error(compile_error: &CompileError) -> Counter { StackUnderflow(_) => compile_error_parse_stack_underflow, MalformedIseq(_) => compile_error_parse_malformed_iseq, NotAllowed => compile_error_parse_not_allowed, + DirectiveInduced => compile_error_parse_directive_induced, Validation(validation) => match validation { BlockHasNoTerminator(_) => compile_error_validation_block_has_no_terminator, TerminatorNotAtEnd(_, _, _) => compile_error_validation_terminator_not_at_end, @@ -596,6 +599,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter { SplatKwNotNilOrHash => exit_splatkw_not_nil_or_hash, SplatKwPolymorphic => exit_splatkw_polymorphic, SplatKwNotProfiled => exit_splatkw_not_profiled, + DirectiveInduced => exit_directive_induced, PatchPoint(Invariant::BOPRedefined { .. 
}) => exit_patchpoint_bop_redefined, PatchPoint(Invariant::MethodRedefined { .. }) From b0754fbd4319eaa585cf0e446aeccbb45fbcf5bf Mon Sep 17 00:00:00 2001 From: Nozomi Hijikata <121233810+nozomemein@users.noreply.github.com> Date: Fri, 13 Mar 2026 02:00:19 +0900 Subject: [PATCH 4/4] ZJIT: Deduplicate polymorphic send branches by receiver type (#16335) When lowering polymorphic `opt_send_without_block`, emit at most one branch per receiver type instead of one branch per profile bucket (`class, shape, flags`). This deduplicates redundant `HasType`/`IfTrue` chains while preserving immediate/heap splits under the same class (for example, `Fixnum`/`Bignum` and `StaticSymbol`/`DynamicSymbol`), reducing branch/codegen overhead without changing dispatch semantics. Stats diff (`railsbench`): - `code_region_bytes`: `15,171,584` -> `14,974,976` (`-1.30%`) - `side_exit_size`: `4,999,792` -> `4,955,020` (`-0.90%`) --- zjit/src/hir.rs | 8 +- zjit/src/hir/opt_tests.rs | 224 +++++++++++++++++++++++++++++++++++--- 2 files changed, 217 insertions(+), 15 deletions(-) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 46363f9e9d7064..7fda53f3c9c371 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -7670,10 +7670,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { let entry_args = state.as_args(self_param); if let Some(summary) = fun.polymorphic_summary(&profiles, recv, exit_state.insn_idx) { let join_block = insn_idx_to_block.get(&insn_idx).copied().unwrap_or_else(|| fun.new_block(insn_idx)); - // TODO(max): Only iterate over unique classes, not unique (class, shape) pairs. + // Dedup by expected type so immediate/heap variants + // under the same Ruby class can still get separate branches. 
+ let mut seen_types = Vec::with_capacity(summary.buckets().len()); for &profiled_type in summary.buckets() { if profiled_type.is_empty() { break; } let expected = Type::from_profiled_type(profiled_type); + if seen_types.iter().any(|ty: &Type| ty.bit_equal(expected)) { + continue; + } + seen_types.push(expected); let has_type = fun.push_insn(block, Insn::HasType { val: recv, expected }); let iftrue_block = new_branch_block(&mut fun, cd, argc as usize, opcode, expected, branch_insn_idx, &exit_state, locals_count, stack_count, join_block); diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index b644936c0c4e36..5df49c5ae6ba8f 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -7406,25 +7406,17 @@ mod hir_opt_tests { bb3(v9:BasicObject, v10:BasicObject): v15:CBool = HasType v10, ObjectSubclass[class_exact:C] IfTrue v15, bb5(v9, v10, v10) - v24:CBool = HasType v10, ObjectSubclass[class_exact:C] - IfTrue v24, bb6(v9, v10, v10) - v33:BasicObject = Send v10, :foo # SendFallbackReason: SendWithoutBlock: polymorphic fallback - Jump bb4(v9, v10, v33) + v24:BasicObject = Send v10, :foo # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v24) bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): v20:ObjectSubclass[class_exact:C] = RefineType v18, ObjectSubclass[class_exact:C] PatchPoint NoSingletonClass(C@0x1008) PatchPoint MethodRedefined(C@0x1008, foo@0x1010, cme:0x1018) - v46:BasicObject = GetIvar v20, :@foo - Jump bb4(v16, v17, v46) - bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): - v29:ObjectSubclass[class_exact:C] = RefineType v27, ObjectSubclass[class_exact:C] - PatchPoint NoSingletonClass(C@0x1008) - PatchPoint MethodRedefined(C@0x1008, foo@0x1010, cme:0x1018) - v49:BasicObject = GetIvar v29, :@foo - Jump bb4(v25, v26, v49) - bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + v37:BasicObject = GetIvar v20, :@foo + Jump bb4(v16, v17, v37) + bb4(v26:BasicObject, v27:BasicObject, 
v28:BasicObject): CheckInterrupts - Return v37 + Return v28 "); } @@ -13834,6 +13826,210 @@ mod hir_opt_tests { "); } + #[test] + fn specialize_polymorphic_send_fixnum_and_bignum() { + // Fixnum and Bignum both have class Integer, but they should be + // treated as different types for polymorphic dispatch because + // Fixnum is an immediate and Bignum is a heap object. + set_call_threshold(4); + eval(" + def test x + x.to_s + end + + fixnum = 1 + bignum = 10**100 + test(fixnum) + test(bignum) + test(fixnum) + test(bignum) + "); + assert_snapshot!(hir_string("test"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :x@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :x@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, Fixnum + IfTrue v15, bb5(v9, v10, v10) + v24:CBool = HasType v10, Bignum + IfTrue v24, bb6(v9, v10, v10) + v33:BasicObject = Send v10, :to_s # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v33) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + v20:Fixnum = RefineType v18, Fixnum + PatchPoint MethodRedefined(Integer@0x1008, to_s@0x1010, cme:0x1018) + v46:StringExact = CCallVariadic v20, :Integer#to_s@0x1040 + Jump bb4(v16, v17, v46) + bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): + v29:Bignum = RefineType v27, Bignum + PatchPoint MethodRedefined(Integer@0x1008, to_s@0x1010, cme:0x1018) + v49:StringExact = CCallVariadic v29, :Integer#to_s@0x1040 + Jump bb4(v25, v26, v49) + bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + CheckInterrupts + Return v37 + "); + } + + #[test] + fn specialize_polymorphic_send_flonum_and_heap_float() { + set_call_threshold(4); + eval(" + def test x + x.to_s + end + + flonum = 1.5 + heap_float = 1.7976931348623157e+308 + test(flonum) + test(heap_float) + test(flonum) + 
test(heap_float) + "); + assert_snapshot!(hir_string("test"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :x@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :x@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, Flonum + IfTrue v15, bb5(v9, v10, v10) + v24:CBool = HasType v10, HeapFloat + IfTrue v24, bb6(v9, v10, v10) + v33:BasicObject = Send v10, :to_s # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v33) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + v20:Flonum = RefineType v18, Flonum + PatchPoint MethodRedefined(Float@0x1008, to_s@0x1010, cme:0x1018) + v46:BasicObject = CCallWithFrame v20, :Float#to_s@0x1040 + Jump bb4(v16, v17, v46) + bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): + v29:HeapFloat = RefineType v27, HeapFloat + PatchPoint MethodRedefined(Float@0x1008, to_s@0x1010, cme:0x1018) + v49:BasicObject = CCallWithFrame v29, :Float#to_s@0x1040 + Jump bb4(v25, v26, v49) + bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + CheckInterrupts + Return v37 + "); + } + + #[test] + fn specialize_polymorphic_send_static_and_dynamic_symbol() { + set_call_threshold(4); + eval(" + def test x + x.to_s + end + + static_sym = :foo + dynamic_sym = (\"zjit_dynamic_\" + Object.new.object_id.to_s).to_sym + test static_sym + test dynamic_sym + test static_sym + test dynamic_sym + "); + assert_snapshot!(hir_string("test"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :x@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :x@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, StaticSymbol + IfTrue v15, bb5(v9, v10, v10) + v24:CBool = HasType v10, 
DynamicSymbol + IfTrue v24, bb6(v9, v10, v10) + v33:BasicObject = Send v10, :to_s # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v33) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + v20:StaticSymbol = RefineType v18, StaticSymbol + PatchPoint MethodRedefined(Symbol@0x1008, to_s@0x1010, cme:0x1018) + v48:StringExact = InvokeBuiltin leaf , v20 + Jump bb4(v16, v17, v48) + bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): + v29:DynamicSymbol = RefineType v27, DynamicSymbol + PatchPoint MethodRedefined(Symbol@0x1008, to_s@0x1010, cme:0x1018) + v49:StringExact = InvokeBuiltin leaf , v29 + Jump bb4(v25, v26, v49) + bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + CheckInterrupts + Return v37 + "); + } + + #[test] + fn specialize_polymorphic_send_iseq_duplicate_class_profiles() { + set_call_threshold(4); + eval(" + class C + def foo = 3 + end + + O1 = C.new + O1.instance_variable_set(:@foo, 1) + O2 = C.new + O2.instance_variable_set(:@bar, 2) + + def test o + o.foo + end + + test O1; test O2; test O1; test O2 + "); + assert_snapshot!(hir_string("test"), @" + fn test@:12: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :o@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :o@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, ObjectSubclass[class_exact:C] + IfTrue v15, bb5(v9, v10, v10) + v24:BasicObject = Send v10, :foo # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v24) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + PatchPoint NoSingletonClass(C@0x1008) + PatchPoint MethodRedefined(C@0x1008, foo@0x1010, cme:0x1018) + v38:Fixnum[3] = Const Value(3) + Jump bb4(v16, v17, v38) + bb4(v26:BasicObject, v27:BasicObject, v28:BasicObject): + CheckInterrupts + Return v28 + "); + } + #[test] fn 
upgrade_self_type_to_heap_after_setivar() { eval("