diff --git a/prism_compile.c b/prism_compile.c index c7facf6e22548f..6bc1da58d0bab0 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -140,33 +140,114 @@ pm_iseq_add_setlocal(rb_iseq_t *iseq, LINK_ANCHOR *const seq, int line, int node #define PM_COMPILE_NOT_POPPED(node) \ pm_compile_node(iseq, (node), ret, false, scope_node) +/** + * Cached line lookup that avoids repeated binary searches. Since the compiler + * walks the AST roughly in source order, consecutive lookups tend to be for + * nearby byte offsets. We cache the last result index in the scope node and + * try a short linear probe from there before falling back to binary search. + */ +static inline pm_line_column_t +pm_line_offset_list_line_column_cached(const pm_line_offset_list_t *list, uint32_t cursor, int32_t start_line, size_t *last_line) +{ + size_t hint = *last_line; + size_t size = list->size; + const uint32_t *offsets = list->offsets; + + RUBY_ASSERT(hint < size); + + /* Check if the cursor is on the same line as the hint. */ + if (offsets[hint] <= cursor) { + if (hint + 1 >= size || offsets[hint + 1] > cursor) { + *last_line = hint; + return ((pm_line_column_t) { + .line = ((int32_t) hint) + start_line, + .column = cursor - offsets[hint] + }); + } + + /* Linear scan forward (up to 8 lines before giving up). */ + size_t limit = hint + 9; + if (limit > size) limit = size; + for (size_t idx = hint + 1; idx < limit; idx++) { + if (offsets[idx] > cursor) { + *last_line = idx - 1; + return ((pm_line_column_t) { + .line = ((int32_t) (idx - 1)) + start_line, + .column = cursor - offsets[idx - 1] + }); + } + if (offsets[idx] == cursor) { + *last_line = idx; + return ((pm_line_column_t) { ((int32_t) idx) + start_line, 0 }); + } + } + } + else { + /* Linear scan backward (up to 8 lines before giving up). */ + size_t limit = hint > 8 ? 
hint - 8 : 0; + for (size_t idx = hint; idx > limit; idx--) { + if (offsets[idx - 1] <= cursor) { + *last_line = idx - 1; + return ((pm_line_column_t) { + .line = ((int32_t) (idx - 1)) + start_line, + .column = cursor - offsets[idx - 1] + }); + } + } + } + + /* Fall back to binary search. */ + pm_line_column_t result = pm_line_offset_list_line_column(list, cursor, start_line); + *last_line = (size_t) (result.line - start_line); + return result; +} + +/** + * The same as pm_line_offset_list_line_column_cached, but returning only the + * line number. + */ +static inline int32_t +pm_line_offset_list_line_cached(const pm_line_offset_list_t *list, uint32_t cursor, int32_t start_line, size_t *last_line) +{ + return pm_line_offset_list_line_column_cached(list, cursor, start_line, last_line).line; +} + #define PM_NODE_START_LOCATION(parser, node) \ - ((pm_node_location_t) { .line = pm_line_offset_list_line(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line), .node_id = ((const pm_node_t *) (node))->node_id }) + ((pm_node_location_t) { .line = pm_line_offset_list_line_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line, &scope_node->last_line), .node_id = ((const pm_node_t *) (node))->node_id }) #define PM_NODE_END_LOCATION(parser, node) \ - ((pm_node_location_t) { .line = pm_line_offset_list_line(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line), .node_id = ((const pm_node_t *) (node))->node_id }) + ((pm_node_location_t) { .line = pm_line_offset_list_line_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line, &scope_node->last_line), .node_id = ((const pm_node_t *) (node))->node_id }) #define PM_LOCATION_START_LOCATION(parser, location, id) \ - ((pm_node_location_t) { .line = 
pm_line_offset_list_line(&(parser)->line_offsets, (location)->start, (parser)->start_line), .node_id = id }) + ((pm_node_location_t) { .line = pm_line_offset_list_line_cached(&(parser)->line_offsets, (location)->start, (parser)->start_line, &scope_node->last_line), .node_id = id }) #define PM_NODE_START_LINE_COLUMN(parser, node) \ - pm_line_offset_list_line_column(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line) + pm_line_offset_list_line_column_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start, (parser)->start_line, &scope_node->last_line) #define PM_NODE_END_LINE_COLUMN(parser, node) \ - pm_line_offset_list_line_column(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line) + pm_line_offset_list_line_column_cached(&(parser)->line_offsets, ((const pm_node_t *) (node))->location.start + ((const pm_node_t *) (node))->location.length, (parser)->start_line, &scope_node->last_line) #define PM_LOCATION_START_LINE_COLUMN(parser, location) \ - pm_line_offset_list_line_column(&(parser)->line_offsets, (location)->start, (parser)->start_line) + pm_line_offset_list_line_column_cached(&(parser)->line_offsets, (location)->start, (parser)->start_line, &scope_node->last_line) static int -pm_node_line_number(const pm_parser_t *parser, const pm_node_t *node) +pm_location_line_number(const pm_parser_t *parser, const pm_location_t *location) { + return (int) pm_line_offset_list_line(&parser->line_offsets, location->start, parser->start_line); +} + +/** + * Cached variants that use the scope node's hint for fast lookups during + * compilation (where access patterns are roughly sequential). 
+ */ +static inline int +pm_node_line_number_cached(const pm_parser_t *parser, const pm_node_t *node, pm_scope_node_t *scope_node) { - return (int) pm_line_offset_list_line(&parser->line_offsets, node->location.start, parser->start_line); + return (int) pm_line_offset_list_line_cached(&parser->line_offsets, node->location.start, parser->start_line, &scope_node->last_line); } -static int -pm_location_line_number(const pm_parser_t *parser, const pm_location_t *location) { - return (int) pm_line_offset_list_line(&parser->line_offsets, location->start, parser->start_line); +static inline int +pm_location_line_number_cached(const pm_parser_t *parser, const pm_location_t *location, pm_scope_node_t *scope_node) { + return (int) pm_line_offset_list_line_cached(&parser->line_offsets, location->start, parser->start_line, &scope_node->last_line); } /** @@ -181,24 +262,21 @@ parse_integer_value(const pm_integer_t *integer) result = UINT2NUM(integer->value); } else { - VALUE string = rb_str_new(NULL, integer->length * 8); - unsigned char *bytes = (unsigned char *) RSTRING_PTR(string); - - size_t offset = integer->length * 8; - for (size_t value_index = 0; value_index < integer->length; value_index++) { - uint32_t value = integer->values[value_index]; - - for (int index = 0; index < 8; index++) { - int byte = (value >> (4 * index)) & 0xf; - bytes[--offset] = byte < 10 ? byte + '0' : byte - 10 + 'a'; - } - } - - result = rb_funcall(string, rb_intern("to_i"), 1, UINT2NUM(16)); + // The pm_integer_t stores values as an array of uint32_t in + // least-significant-word-first order (base 2^32). We can convert + // directly to a Ruby Integer using rb_integer_unpack, avoiding the + // overhead of constructing a hex string and calling rb_funcall. 
+ result = rb_integer_unpack( + integer->values, + integer->length, + sizeof(uint32_t), + 0, + INTEGER_PACK_LSWORD_FIRST | INTEGER_PACK_NATIVE_BYTE_ORDER + ); } if (integer->negative) { - result = rb_funcall(result, rb_intern("-@"), 0); + result = rb_int_uminus(result); } if (!SPECIAL_CONST_P(result)) { @@ -305,7 +383,7 @@ parse_string_encoded(const pm_node_t *node, const pm_string_t *string, rb_encodi } static inline VALUE -parse_static_literal_string(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *string) +parse_static_literal_string(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *string) { rb_encoding *encoding; @@ -323,7 +401,7 @@ parse_static_literal_string(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, rb_enc_str_coderange(value); if (ISEQ_COMPILE_DATA(iseq)->option->debug_frozen_string_literal || RTEST(ruby_debug)) { - int line_number = pm_node_line_number(scope_node->parser, node); + int line_number = pm_node_line_number_cached(scope_node->parser, node, scope_node); value = rb_ractor_make_shareable(rb_str_with_debug_created_info(value, rb_iseq_path(iseq), line_number)); } @@ -368,7 +446,7 @@ parse_regexp_error(rb_iseq_t *iseq, int32_t line_number, const char *fmt, ...) } static VALUE -parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding) +parse_regexp_string_part(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding) { // If we were passed an explicit regexp encoding, then we need to double // check that it's okay here for this fragment of the string. 
@@ -390,12 +468,12 @@ parse_regexp_string_part(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, con VALUE string = rb_enc_str_new((const char *) pm_string_source(unescaped), pm_string_length(unescaped), encoding); VALUE error = rb_reg_check_preprocess(string); - if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, node), "%" PRIsVALUE, rb_obj_as_string(error)); + if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number_cached(scope_node->parser, node, scope_node), "%" PRIsVALUE, rb_obj_as_string(error)); return string; } static VALUE -pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_scope_node_t *scope_node, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding, bool top) +pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, pm_scope_node_t *scope_node, rb_encoding *implicit_regexp_encoding, rb_encoding *explicit_regexp_encoding, bool top) { VALUE current = Qnil; @@ -412,7 +490,7 @@ pm_static_literal_concat(rb_iseq_t *iseq, const pm_node_list_t *nodes, const pm_ else { string = parse_string_encoded(part, &((const pm_string_node_t *) part)->unescaped, scope_node->encoding); VALUE error = rb_reg_check_preprocess(string); - if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number(scope_node->parser, part), "%" PRIsVALUE, rb_obj_as_string(error)); + if (error != Qnil) parse_regexp_error(iseq, pm_node_line_number_cached(scope_node->parser, part, scope_node), "%" PRIsVALUE, rb_obj_as_string(error)); } } else { @@ -525,11 +603,11 @@ parse_regexp_encoding(const pm_scope_node_t *scope_node, const pm_node_t *node) } static VALUE -parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, VALUE string) +parse_regexp(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, VALUE string) { VALUE errinfo = rb_errinfo(); - int32_t line_number = pm_node_line_number(scope_node->parser, node); + int32_t line_number = 
pm_node_line_number_cached(scope_node->parser, node, scope_node); VALUE regexp = rb_reg_compile(string, parse_regexp_flags(node), (const char *) pm_string_source(&scope_node->parser->filepath), line_number); if (NIL_P(regexp)) { @@ -544,7 +622,7 @@ parse_regexp(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t } static inline VALUE -parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped) +parse_regexp_literal(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_string_t *unescaped) { rb_encoding *regexp_encoding = parse_regexp_encoding(scope_node, node); if (regexp_encoding == NULL) regexp_encoding = scope_node->encoding; @@ -555,7 +633,7 @@ parse_regexp_literal(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const p } static inline VALUE -parse_regexp_concat(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts) +parse_regexp_concat(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, const pm_node_list_t *parts) { rb_encoding *explicit_regexp_encoding = parse_regexp_encoding(scope_node, node); rb_encoding *implicit_regexp_encoding = explicit_regexp_encoding != NULL ? explicit_regexp_encoding : scope_node->encoding; @@ -748,7 +826,7 @@ pm_static_literal_string(rb_iseq_t *iseq, VALUE string, int line_number) * literal values can be compiled into a literal array. */ static VALUE -pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_node_t *scope_node) +pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, pm_scope_node_t *scope_node) { // Every node that comes into this function should already be marked as // static literal. If it's not, then we have a bug somewhere. 
@@ -804,7 +882,7 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n } case PM_INTERPOLATED_STRING_NODE: { VALUE string = pm_static_literal_concat(iseq, &((const pm_interpolated_string_node_t *) node)->parts, scope_node, NULL, NULL, false); - int line_number = pm_node_line_number(scope_node->parser, node); + int line_number = pm_node_line_number_cached(scope_node->parser, node, scope_node); return pm_static_literal_string(iseq, string, line_number); } case PM_INTERPOLATED_SYMBOL_NODE: { @@ -832,7 +910,7 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n return pm_source_file_value(cast, scope_node); } case PM_SOURCE_LINE_NODE: - return INT2FIX(pm_node_line_number(scope_node->parser, node)); + return INT2FIX(pm_node_line_number_cached(scope_node->parser, node, scope_node)); case PM_STRING_NODE: { const pm_string_node_t *cast = (const pm_string_node_t *) node; return parse_static_literal_string(iseq, scope_node, node, &cast->unescaped); @@ -851,7 +929,7 @@ pm_static_literal_value(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_n * A helper for converting a pm_location_t into a rb_code_location_t. 
*/ static rb_code_location_t -pm_code_location(const pm_scope_node_t *scope_node, const pm_node_t *node) +pm_code_location(pm_scope_node_t *scope_node, const pm_node_t *node) { const pm_line_column_t start_location = PM_NODE_START_LINE_COLUMN(scope_node->parser, node); const pm_line_column_t end_location = PM_NODE_END_LINE_COLUMN(scope_node->parser, node); @@ -2404,7 +2482,7 @@ pm_compile_pattern_eqq_error(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const static int pm_compile_pattern_match(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *unmatched_label, bool in_single_pattern, bool use_deconstructed_cache, unsigned int base_index) { - LABEL *matched_label = NEW_LABEL(pm_node_line_number(scope_node->parser, node)); + LABEL *matched_label = NEW_LABEL(pm_node_line_number_cached(scope_node->parser, node, scope_node)); CHECK(pm_compile_pattern(iseq, scope_node, node, ret, matched_label, unmatched_label, in_single_pattern, use_deconstructed_cache, base_index)); PUSH_LABEL(ret, matched_label); return COMPILE_OK; @@ -2493,7 +2571,7 @@ pm_compile_pattern_constant(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const * responsible for compiling in those error raising instructions. 
*/ static void -pm_compile_pattern_error_handler(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *done_label, bool popped) +pm_compile_pattern_error_handler(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *done_label, bool popped) { const pm_node_location_t location = PM_NODE_START_LOCATION(scope_node->parser, node); LABEL *key_error_label = NEW_LABEL(location.line); @@ -3243,6 +3321,7 @@ pm_scope_node_init(const pm_node_t *node, pm_scope_node_t *scope, pm_scope_node_ scope->constants = previous->constants; scope->coverage_enabled = previous->coverage_enabled; scope->script_lines = previous->script_lines; + scope->last_line = previous->last_line; } switch (PM_NODE_TYPE(node)) { @@ -3411,7 +3490,7 @@ pm_iseq_builtin_function_name(const pm_scope_node_t *scope_node, const pm_node_t // Compile Primitive.attr! :leaf, ... static int -pm_compile_builtin_attr(rb_iseq_t *iseq, const pm_scope_node_t *scope_node, const pm_arguments_node_t *arguments, const pm_node_location_t *node_location) +pm_compile_builtin_attr(rb_iseq_t *iseq, pm_scope_node_t *scope_node, const pm_arguments_node_t *arguments, const pm_node_location_t *node_location) { if (arguments == NULL) { COMPILE_ERROR(iseq, node_location->line, "attr!: no argument"); @@ -3699,7 +3778,7 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c } const pm_line_column_t start_location = PM_NODE_START_LINE_COLUMN(scope_node->parser, call_node); - const pm_line_column_t end_location = pm_line_offset_list_line_column(&scope_node->parser->line_offsets, end_cursor, scope_node->parser->start_line); + const pm_line_column_t end_location = pm_line_offset_list_line_column_cached(&scope_node->parser->line_offsets, end_cursor, scope_node->parser->start_line, &scope_node->last_line); code_location = (rb_code_location_t) { .beg_pos = { .lineno = start_location.line, .column = start_location.column }, 
@@ -3729,7 +3808,7 @@ pm_compile_call(rb_iseq_t *iseq, const pm_call_node_t *call_node, LINK_ANCHOR *c pm_scope_node_t next_scope_node; pm_scope_node_init(call_node->block, &next_scope_node, scope_node); - block_iseq = NEW_CHILD_ISEQ(&next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, pm_node_line_number(scope_node->parser, call_node->block)); + block_iseq = NEW_CHILD_ISEQ(&next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, pm_node_line_number_cached(scope_node->parser, call_node->block, scope_node)); pm_scope_node_destroy(&next_scope_node); ISEQ_COMPILE_DATA(iseq)->current_block = block_iseq; } @@ -4783,7 +4862,7 @@ pm_compile_destructured_param_locals(const pm_multi_target_node_t *node, st_tabl * as a positional parameter in a method, block, or lambda definition. */ static inline void -pm_compile_destructured_param_write(rb_iseq_t *iseq, const pm_required_parameter_node_t *node, LINK_ANCHOR *const ret, const pm_scope_node_t *scope_node) +pm_compile_destructured_param_write(rb_iseq_t *iseq, const pm_required_parameter_node_t *node, LINK_ANCHOR *const ret, pm_scope_node_t *scope_node) { const pm_node_location_t location = PM_NODE_START_LOCATION(scope_node->parser, node); pm_local_index_t index = pm_lookup_local_index(iseq, scope_node, node->name, 0); @@ -4799,7 +4878,7 @@ pm_compile_destructured_param_write(rb_iseq_t *iseq, const pm_required_parameter * for this simplified case. 
*/ static void -pm_compile_destructured_param_writes(rb_iseq_t *iseq, const pm_multi_target_node_t *node, LINK_ANCHOR *const ret, const pm_scope_node_t *scope_node) +pm_compile_destructured_param_writes(rb_iseq_t *iseq, const pm_multi_target_node_t *node, LINK_ANCHOR *const ret, pm_scope_node_t *scope_node) { const pm_node_location_t location = PM_NODE_START_LOCATION(scope_node->parser, node); bool has_rest = (node->rest && PM_NODE_TYPE_P(node->rest, PM_SPLAT_NODE) && (((const pm_splat_node_t *) node->rest)->expression) != NULL); @@ -5435,7 +5514,7 @@ pm_compile_rescue(rb_iseq_t *iseq, const pm_begin_node_t *cast, const pm_node_lo &rescue_scope_node, rb_str_concat(rb_str_new2("rescue in "), ISEQ_BODY(iseq)->location.label), ISEQ_TYPE_RESCUE, - pm_node_line_number(parser, (const pm_node_t *) cast->rescue_clause) + pm_node_line_number_cached(parser, (const pm_node_t *) cast->rescue_clause, scope_node) ); pm_scope_node_destroy(&rescue_scope_node); @@ -5567,7 +5646,7 @@ pm_opt_str_freeze_p(const rb_iseq_t *iseq, const pm_call_node_t *node) * of the current iseq. */ static void -pm_compile_constant_read(rb_iseq_t *iseq, VALUE name, const pm_location_t *name_loc, uint32_t node_id, LINK_ANCHOR *const ret, const pm_scope_node_t *scope_node) +pm_compile_constant_read(rb_iseq_t *iseq, VALUE name, const pm_location_t *name_loc, uint32_t node_id, LINK_ANCHOR *const ret, pm_scope_node_t *scope_node) { const pm_node_location_t location = PM_LOCATION_START_LOCATION(scope_node->parser, name_loc, node_id); @@ -5667,7 +5746,7 @@ pm_compile_constant_path(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *co * Return the object that will be pushed onto the stack for the given node. 
*/ static VALUE -pm_compile_shareable_constant_literal(rb_iseq_t *iseq, const pm_node_t *node, const pm_scope_node_t *scope_node) +pm_compile_shareable_constant_literal(rb_iseq_t *iseq, const pm_node_t *node, pm_scope_node_t *scope_node) { switch (PM_NODE_TYPE(node)) { case PM_TRUE_NODE: @@ -7507,7 +7586,7 @@ pm_compile_call_operator_write_node(rb_iseq_t *iseq, const pm_call_operator_writ * optimization entirely. */ static VALUE -pm_compile_case_node_dispatch(rb_iseq_t *iseq, VALUE dispatch, const pm_node_t *node, LABEL *label, const pm_scope_node_t *scope_node) +pm_compile_case_node_dispatch(rb_iseq_t *iseq, VALUE dispatch, const pm_node_t *node, LABEL *label, pm_scope_node_t *scope_node) { VALUE key = Qundef; switch (PM_NODE_TYPE(node)) { @@ -7590,7 +7669,7 @@ pm_compile_case_node(rb_iseq_t *iseq, const pm_case_node_t *cast, const pm_node_ const pm_when_node_t *clause = (const pm_when_node_t *) conditions->nodes[clause_index]; const pm_node_list_t *conditions = &clause->conditions; - int clause_lineno = pm_node_line_number(parser, (const pm_node_t *) clause); + int clause_lineno = pm_node_line_number_cached(parser, (const pm_node_t *) clause, scope_node); LABEL *label = NEW_LABEL(clause_lineno); PUSH_LABEL(body_seq, label); @@ -7623,7 +7702,7 @@ pm_compile_case_node(rb_iseq_t *iseq, const pm_case_node_t *cast, const pm_node_ PUSH_INSNL(cond_seq, cond_location, branchif, label); } else { - LABEL *next_label = NEW_LABEL(pm_node_line_number(parser, condition)); + LABEL *next_label = NEW_LABEL(pm_node_line_number_cached(parser, condition, scope_node)); pm_compile_branch_condition(iseq, cond_seq, condition, label, next_label, false, scope_node); PUSH_LABEL(cond_seq, next_label); } @@ -9705,7 +9784,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, pm_scope_node_t next_scope_node; pm_scope_node_init(node, &next_scope_node, scope_node); - int opening_lineno = pm_location_line_number(parser, &cast->opening_loc); + int opening_lineno = 
pm_location_line_number_cached(parser, &cast->opening_loc, scope_node); const rb_iseq_t *block = NEW_CHILD_ISEQ(&next_scope_node, make_name_for_block(iseq), ISEQ_TYPE_BLOCK, opening_lineno); pm_scope_node_destroy(&next_scope_node); @@ -10183,7 +10262,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, &rescue_scope_node, rb_str_concat(rb_str_new2("rescue in "), ISEQ_BODY(iseq)->location.label), ISEQ_TYPE_RESCUE, - pm_node_line_number(parser, cast->rescue_expression) + pm_node_line_number_cached(parser, cast->rescue_expression, scope_node) ); pm_scope_node_destroy(&rescue_scope_node); diff --git a/prism_compile.h b/prism_compile.h index ae1ec6bfd4f607..0b5fe685273a1e 100644 --- a/prism_compile.h +++ b/prism_compile.h @@ -61,6 +61,13 @@ typedef struct pm_scope_node { * the instructions pertaining to BEGIN{} nodes. */ struct iseq_link_anchor *pre_execution_anchor; + + /** + * Cached line hint for line offset list lookups. Since the compiler walks + * the AST roughly in source order, consecutive lookups tend to be for + * nearby byte offsets. This avoids repeated binary searches. + */ + size_t last_line; } pm_scope_node_t; void pm_scope_node_init(const pm_node_t *node, pm_scope_node_t *scope, pm_scope_node_t *previous); diff --git a/zjit.c b/zjit.c index 0c463334cde42a..68336858e958b1 100644 --- a/zjit.c +++ b/zjit.c @@ -19,6 +19,7 @@ #include "vm_insnhelper.h" #include "probes.h" #include "probes_helper.h" +#include "constant.h" #include "iseq.h" #include "ruby/debug.h" #include "internal/cont.h" diff --git a/zjit.rb b/zjit.rb index 82630612749f44..bdaf557426955c 100644 --- a/zjit.rb +++ b/zjit.rb @@ -44,6 +44,20 @@ def trace_exit_locations_enabled? Primitive.rb_zjit_trace_exit_locations_enabled_p end + # A directive for the compiler to fail to compile the call to this method. + # To show this to ZJIT, say `::RubyVM::ZJIT.induce_compile_failure!` verbatim. + # Other forms are too dynamic to detect during compilation. 
+ # + # Actually running this method does nothing, whether ZJIT sees the call or not. + def induce_compile_failure! = nil + + # A directive for the compiler to exit out of compiled code at the call site of this method. + # To show this to ZJIT, say `::RubyVM::ZJIT.induce_side_exit!` verbatim. + # Other forms are too dynamic to detect during compilation. + # + # Actually running this method does nothing, whether ZJIT sees the call or not. + def induce_side_exit! = nil + # If --zjit-trace-exits is enabled parse the hashes from # Primitive.rb_zjit_get_exit_locations into a format readable # by Stackprof. This will allow us to find the exact location of a diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs index 18436270d74ba2..7144e73737ba03 100644 --- a/zjit/bindgen/src/main.rs +++ b/zjit/bindgen/src/main.rs @@ -145,6 +145,7 @@ fn main() { .allowlist_function("rb_gc_location") .allowlist_function("rb_gc_writebarrier") .allowlist_function("rb_gc_writebarrier_remember") + .allowlist_function("rb_gc_register_mark_object") .allowlist_function("rb_zjit_writebarrier_check_immediate") // VALUE variables for Ruby class objects @@ -432,6 +433,7 @@ fn main() { .allowlist_function("rb_vm_base_ptr") .allowlist_function("rb_ec_stack_check") .allowlist_function("rb_vm_top_self") + .allowlist_function("rb_const_lookup") // We define these manually, don't import them .blocklist_type("VALUE") diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 38b99edb4bd3f6..ed59dd7d9f23c2 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -1544,6 +1544,10 @@ pub(crate) mod ids { name: RUBY_FL_FREEZE name: RUBY_ELTS_SHARED name: VM_FRAME_FLAG_MODIFIED_BLOCK_PARAM + name: RubyVM + name: ZJIT + name: induce_side_exit_bang content: b"induce_side_exit!" + name: induce_compile_failure_bang content: b"induce_compile_failure!" } /// Get an CRuby `ID` to an interned string, e.g. a particular method name. 
@@ -1551,7 +1555,7 @@ pub(crate) mod ids { ($id_name:ident) => {{ let id = $crate::cruby::ids::$id_name.load(std::sync::atomic::Ordering::Relaxed); debug_assert_ne!(0, id, "ids module should be initialized"); - ID(id) + $crate::cruby::ID(id) }} } pub(crate) use ID; diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index baf39100f48d71..76fdcd93f7a1be 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -1471,6 +1471,20 @@ pub const SHAPE_ID_FL_TOO_COMPLEX: shape_id_fl_type = 67108864; pub const SHAPE_ID_FL_NON_CANONICAL_MASK: shape_id_fl_type = 50331648; pub const SHAPE_ID_FLAGS_MASK: shape_id_fl_type = 133693440; pub type shape_id_fl_type = u32; +pub const CONST_DEPRECATED: rb_const_flag_t = 256; +pub const CONST_VISIBILITY_MASK: rb_const_flag_t = 255; +pub const CONST_PUBLIC: rb_const_flag_t = 0; +pub const CONST_PRIVATE: rb_const_flag_t = 1; +pub const CONST_VISIBILITY_MAX: rb_const_flag_t = 2; +pub type rb_const_flag_t = u32; +#[repr(C)] +pub struct rb_const_entry_struct { + pub flag: rb_const_flag_t, + pub line: ::std::os::raw::c_int, + pub value: VALUE, + pub file: VALUE, +} +pub type rb_const_entry_t = rb_const_entry_struct; #[repr(C)] pub struct rb_cvar_class_tbl_entry { pub index: u32, @@ -1912,6 +1926,7 @@ unsafe extern "C" { pub fn rb_gc_location(obj: VALUE) -> VALUE; pub fn rb_gc_enable() -> VALUE; pub fn rb_gc_disable() -> VALUE; + pub fn rb_gc_register_mark_object(object: VALUE); pub fn rb_gc_writebarrier(old: VALUE, young: VALUE); pub fn rb_class_get_superclass(klass: VALUE) -> VALUE; pub fn rb_funcallv( @@ -2057,6 +2072,7 @@ unsafe extern "C" { original_shape_id: shape_id_t, id: ID, ) -> shape_id_t; + pub fn rb_const_lookup(klass: VALUE, id: ID) -> *mut rb_const_entry_t; pub fn rb_ivar_get_at_no_ractor_check(obj: VALUE, index: attr_index_t) -> VALUE; pub fn rb_gvar_get(arg1: ID) -> VALUE; pub fn rb_gvar_set(arg1: ID, arg2: VALUE) -> VALUE; diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 
090bf579661eec..7fda53f3c9c371 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -7,10 +7,12 @@ #![allow(clippy::match_like_matches_macro)] use crate::{ backend::lir::C_ARG_OPNDS, - cast::IntoUsize, codegen::local_idx_to_ep_offset, cruby::*, invariants::{self, has_singleton_class_of}, payload::{get_or_create_iseq_payload, IseqPayload}, options::{debug, get_option, DumpHIR}, state::ZJITState, json::Json + cast::IntoUsize, codegen::local_idx_to_ep_offset, cruby::*, invariants::{self, has_singleton_class_of}, payload::{get_or_create_iseq_payload, IseqPayload}, options::{debug, get_option, DumpHIR}, state::ZJITState, json::Json, + state, }; use std::{ - cell::RefCell, collections::{BTreeSet, HashMap, HashSet, VecDeque}, ffi::{c_void, c_uint, c_int, CStr}, fmt::Display, mem::{align_of, size_of}, ptr, slice::Iter + cell::RefCell, collections::{BTreeSet, HashMap, HashSet, VecDeque}, ffi::{c_void, c_uint, c_int, CStr}, fmt::Display, mem::{align_of, size_of}, ptr, slice::Iter, + sync::atomic::Ordering, }; use crate::hir_type::{Type, types}; use crate::hir_effect::{Effect, abstract_heaps, effects}; @@ -529,6 +531,7 @@ pub enum SideExitReason { SplatKwNotNilOrHash, SplatKwPolymorphic, SplatKwNotProfiled, + DirectiveInduced, } #[derive(Debug, Clone, Copy)] @@ -6590,6 +6593,7 @@ pub enum ParseError { MalformedIseq(u32), // insn_idx into iseq_encoded Validation(ValidationError), NotAllowed, + DirectiveInduced, } /// Return the number of locals in the current ISEQ (includes parameters) @@ -7120,7 +7124,32 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } YARVINSN_opt_getconstant_path => { let ic = get_arg(pc, 0).as_ptr(); - state.stack_push(fun.push_insn(block, Insn::GetConstantPath { ic, state: exit_id })); + let get_const_path = fun.push_insn(block, Insn::GetConstantPath { ic, state: exit_id }); + state.stack_push(get_const_path); + + // Check for `::RubyVM::ZJIT` for directives + unsafe { + let mut current_segment = (*ic).segments; + let mut segments = [ID(0); 4 
/* expected segment length */]; + for segment in segments.iter_mut() { + *segment = current_segment.read(); + if *segment == ID(0) { + break; + } + current_segment = current_segment.add(1); + } + if [ID!(NULL), ID!(RubyVM), ID!(ZJIT), ID(0)] == segments { + debug_assert_ne!(ID!(NULL), ID(0)); + let ruby_vm_mod = rb_const_lookup(rb_cObject, ID!(RubyVM)); + if !ruby_vm_mod.is_null() && (*ruby_vm_mod).value == rb_cRubyVM { + let zjit_module = VALUE(state::ZJIT_MODULE.load(Ordering::Relaxed)); + let lookedup_module = rb_const_lookup(rb_cRubyVM, ID!(ZJIT)); + if !lookedup_module.is_null() && (*lookedup_module).value == zjit_module { + fun.insn_types[get_const_path.0] = Type::from_value(zjit_module); + } + } + } + } } YARVINSN_branchunless | YARVINSN_branchunless_without_ints => { if opcode == YARVINSN_branchunless { @@ -7581,6 +7610,28 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { break; // End the block } let argc = unsafe { vm_ci_argc((*cd).ci) }; + let mid = unsafe { rb_vm_ci_mid(call_info) }; + + // Check for calls to directives + if argc == 0 + && (mid == ID!(induce_side_exit_bang) || mid == ID!(induce_compile_failure_bang)) + && fun.type_of(state.stack_top()?) 
+ .ruby_object() + .is_some_and(|obj| obj == VALUE(state::ZJIT_MODULE.load(Ordering::Relaxed))) + { + + if mid == ID!(induce_side_exit_bang) + && state::zjit_module_method_match_serial(ID!(induce_side_exit_bang), &state::INDUCE_SIDE_EXIT_SERIAL) + { + fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::DirectiveInduced }); + break; // End the block + } + if mid == ID!(induce_compile_failure_bang) + && state::zjit_module_method_match_serial(ID!(induce_compile_failure_bang), &state::INDUCE_COMPILE_FAILURE_SERIAL) + { + return Err(ParseError::DirectiveInduced); + } + } { fn new_branch_block( @@ -7619,10 +7670,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { let entry_args = state.as_args(self_param); if let Some(summary) = fun.polymorphic_summary(&profiles, recv, exit_state.insn_idx) { let join_block = insn_idx_to_block.get(&insn_idx).copied().unwrap_or_else(|| fun.new_block(insn_idx)); - // TODO(max): Only iterate over unique classes, not unique (class, shape) pairs. + // Dedup by expected type so immediate/heap variants + // under the same Ruby class can still get separate branches. 
+ let mut seen_types = Vec::with_capacity(summary.buckets().len()); for &profiled_type in summary.buckets() { if profiled_type.is_empty() { break; } let expected = Type::from_profiled_type(profiled_type); + if seen_types.iter().any(|ty: &Type| ty.bit_equal(expected)) { + continue; + } + seen_types.push(expected); let has_type = fun.push_insn(block, Insn::HasType { val: recv, expected }); let iftrue_block = new_branch_block(&mut fun, cd, argc as usize, opcode, expected, branch_insn_idx, &exit_state, locals_count, stack_count, join_block); diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index b644936c0c4e36..5df49c5ae6ba8f 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -7406,25 +7406,17 @@ mod hir_opt_tests { bb3(v9:BasicObject, v10:BasicObject): v15:CBool = HasType v10, ObjectSubclass[class_exact:C] IfTrue v15, bb5(v9, v10, v10) - v24:CBool = HasType v10, ObjectSubclass[class_exact:C] - IfTrue v24, bb6(v9, v10, v10) - v33:BasicObject = Send v10, :foo # SendFallbackReason: SendWithoutBlock: polymorphic fallback - Jump bb4(v9, v10, v33) + v24:BasicObject = Send v10, :foo # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v24) bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): v20:ObjectSubclass[class_exact:C] = RefineType v18, ObjectSubclass[class_exact:C] PatchPoint NoSingletonClass(C@0x1008) PatchPoint MethodRedefined(C@0x1008, foo@0x1010, cme:0x1018) - v46:BasicObject = GetIvar v20, :@foo - Jump bb4(v16, v17, v46) - bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): - v29:ObjectSubclass[class_exact:C] = RefineType v27, ObjectSubclass[class_exact:C] - PatchPoint NoSingletonClass(C@0x1008) - PatchPoint MethodRedefined(C@0x1008, foo@0x1010, cme:0x1018) - v49:BasicObject = GetIvar v29, :@foo - Jump bb4(v25, v26, v49) - bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + v37:BasicObject = GetIvar v20, :@foo + Jump bb4(v16, v17, v37) + bb4(v26:BasicObject, v27:BasicObject, 
v28:BasicObject): CheckInterrupts - Return v37 + Return v28 "); } @@ -13834,6 +13826,210 @@ mod hir_opt_tests { "); } + #[test] + fn specialize_polymorphic_send_fixnum_and_bignum() { + // Fixnum and Bignum both have class Integer, but they should be + // treated as different types for polymorphic dispatch because + // Fixnum is an immediate and Bignum is a heap object. + set_call_threshold(4); + eval(" + def test x + x.to_s + end + + fixnum = 1 + bignum = 10**100 + test(fixnum) + test(bignum) + test(fixnum) + test(bignum) + "); + assert_snapshot!(hir_string("test"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :x@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :x@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, Fixnum + IfTrue v15, bb5(v9, v10, v10) + v24:CBool = HasType v10, Bignum + IfTrue v24, bb6(v9, v10, v10) + v33:BasicObject = Send v10, :to_s # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v33) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + v20:Fixnum = RefineType v18, Fixnum + PatchPoint MethodRedefined(Integer@0x1008, to_s@0x1010, cme:0x1018) + v46:StringExact = CCallVariadic v20, :Integer#to_s@0x1040 + Jump bb4(v16, v17, v46) + bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): + v29:Bignum = RefineType v27, Bignum + PatchPoint MethodRedefined(Integer@0x1008, to_s@0x1010, cme:0x1018) + v49:StringExact = CCallVariadic v29, :Integer#to_s@0x1040 + Jump bb4(v25, v26, v49) + bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + CheckInterrupts + Return v37 + "); + } + + #[test] + fn specialize_polymorphic_send_flonum_and_heap_float() { + set_call_threshold(4); + eval(" + def test x + x.to_s + end + + flonum = 1.5 + heap_float = 1.7976931348623157e+308 + test(flonum) + test(heap_float) + test(flonum) + 
test(heap_float) + "); + assert_snapshot!(hir_string("test"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :x@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :x@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, Flonum + IfTrue v15, bb5(v9, v10, v10) + v24:CBool = HasType v10, HeapFloat + IfTrue v24, bb6(v9, v10, v10) + v33:BasicObject = Send v10, :to_s # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v33) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + v20:Flonum = RefineType v18, Flonum + PatchPoint MethodRedefined(Float@0x1008, to_s@0x1010, cme:0x1018) + v46:BasicObject = CCallWithFrame v20, :Float#to_s@0x1040 + Jump bb4(v16, v17, v46) + bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): + v29:HeapFloat = RefineType v27, HeapFloat + PatchPoint MethodRedefined(Float@0x1008, to_s@0x1010, cme:0x1018) + v49:BasicObject = CCallWithFrame v29, :Float#to_s@0x1040 + Jump bb4(v25, v26, v49) + bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + CheckInterrupts + Return v37 + "); + } + + #[test] + fn specialize_polymorphic_send_static_and_dynamic_symbol() { + set_call_threshold(4); + eval(" + def test x + x.to_s + end + + static_sym = :foo + dynamic_sym = (\"zjit_dynamic_\" + Object.new.object_id.to_s).to_sym + test static_sym + test dynamic_sym + test static_sym + test dynamic_sym + "); + assert_snapshot!(hir_string("test"), @" + fn test@:3: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :x@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :x@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, StaticSymbol + IfTrue v15, bb5(v9, v10, v10) + v24:CBool = HasType v10, 
DynamicSymbol + IfTrue v24, bb6(v9, v10, v10) + v33:BasicObject = Send v10, :to_s # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v33) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + v20:StaticSymbol = RefineType v18, StaticSymbol + PatchPoint MethodRedefined(Symbol@0x1008, to_s@0x1010, cme:0x1018) + v48:StringExact = InvokeBuiltin leaf , v20 + Jump bb4(v16, v17, v48) + bb6(v25:BasicObject, v26:BasicObject, v27:BasicObject): + v29:DynamicSymbol = RefineType v27, DynamicSymbol + PatchPoint MethodRedefined(Symbol@0x1008, to_s@0x1010, cme:0x1018) + v49:StringExact = InvokeBuiltin leaf , v29 + Jump bb4(v25, v26, v49) + bb4(v35:BasicObject, v36:BasicObject, v37:BasicObject): + CheckInterrupts + Return v37 + "); + } + + #[test] + fn specialize_polymorphic_send_iseq_duplicate_class_profiles() { + set_call_threshold(4); + eval(" + class C + def foo = 3 + end + + O1 = C.new + O1.instance_variable_set(:@foo, 1) + O2 = C.new + O2.instance_variable_set(:@bar, 2) + + def test o + o.foo + end + + test O1; test O2; test O1; test O2 + "); + assert_snapshot!(hir_string("test"), @" + fn test@:12: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:CPtr = LoadSP + v3:BasicObject = LoadField v2, :o@0x1000 + Jump bb3(v1, v3) + bb2(): + EntryPoint JIT(0) + v6:BasicObject = LoadArg :self@0 + v7:BasicObject = LoadArg :o@1 + Jump bb3(v6, v7) + bb3(v9:BasicObject, v10:BasicObject): + v15:CBool = HasType v10, ObjectSubclass[class_exact:C] + IfTrue v15, bb5(v9, v10, v10) + v24:BasicObject = Send v10, :foo # SendFallbackReason: SendWithoutBlock: polymorphic fallback + Jump bb4(v9, v10, v24) + bb5(v16:BasicObject, v17:BasicObject, v18:BasicObject): + PatchPoint NoSingletonClass(C@0x1008) + PatchPoint MethodRedefined(C@0x1008, foo@0x1010, cme:0x1018) + v38:Fixnum[3] = Const Value(3) + Jump bb4(v16, v17, v38) + bb4(v26:BasicObject, v27:BasicObject, v28:BasicObject): + CheckInterrupts + Return v28 + "); + } + #[test] fn 
upgrade_self_type_to_heap_after_setivar() { eval(" diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index b26f5a989c1374..c6538a9eda1623 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -4791,7 +4791,132 @@ pub mod hir_build_tests { Jump bb8(v67, v78) "); } - } + + #[test] + fn test_induce_side_exit() { + eval(" + class NonTopLexicalScope + RubyVM = 0 + def test + RubyVM::ZJIT.induce_side_exit! # lexical scope dependant -- should not recognize + ::RubyVM::ZJIT.induce_side_exit! + end + end + "); + assert_snapshot!(hir_string_proc("NonTopLexicalScope.instance_method(:test)"), @" + fn test@:5: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_side_exit! # SendFallbackReason: Uncategorized(opt_send_without_block) + v16:BasicObject = GetConstantPath 0x1000 + SideExit DirectiveInduced + "); + } + + #[test] + fn test_induce_side_exit_sensitive_to_constant_state() { + eval(" + def test = ::RubyVM::ZJIT.induce_side_exit! + "); + assert!(hir_string("test").contains("SideExit DirectiveInduced")); + eval(" + class RubyVM + remove_const(:ZJIT) + end + "); + let hir_after_removal = hir_string("test"); + assert_eq!(false, hir_string("test").contains("SideExit DirectiveInduced"), "should not work when the constant lookup would fail"); + assert_snapshot!(hir_after_removal, @" + fn test@:2: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_side_exit! 
# SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_induce_side_exit_doesnt_work_when_method_after_undef() { + eval(" + class << RubyVM::ZJIT + undef :induce_side_exit! + end + def test = ::RubyVM::ZJIT.induce_side_exit! + "); + assert_eq!(false, hir_string("test").contains("SideExit DirectiveInduced"), "should not work after undef"); + } + + #[test] + fn test_induce_compile_failure_does_not_trigger_autoload() { + eval(" + class RubyVM + remove_const(:ZJIT) + autoload :ZJIT, 'a-file-that-does-not-exist-as-a-trap' + end + def test = ::RubyVM::ZJIT.induce_compile_failure! + "); + assert_snapshot!(hir_string("test"), @" + fn test@:6: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_compile_failure! # SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_induce_compile_failure_checks_full_const_path() { + eval("def test = ::RubyVM::ZJIT::TooDeep.induce_compile_failure!"); + assert_snapshot!(hir_string("test"), @" + fn test@:1: + bb1(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb3(v1) + bb2(): + EntryPoint JIT(0) + v4:BasicObject = LoadArg :self@0 + Jump bb3(v4) + bb3(v6:BasicObject): + v10:BasicObject = GetConstantPath 0x1000 + v12:BasicObject = Send v10, :induce_compile_failure! # SendFallbackReason: Uncategorized(opt_send_without_block) + CheckInterrupts + Return v12 + "); + } + + #[test] + fn test_induce_compile_failure() { + eval("def test = ::RubyVM::ZJIT.induce_compile_failure!"); + assert_compile_fails("test", ParseError::DirectiveInduced); + } +} /// Test successor and predecessor set computations. 
#[cfg(test)] diff --git a/zjit/src/state.rs index df8239699ff5e2..7fa1eb862dc163 100644 --- a/zjit/src/state.rs +++ b/zjit/src/state.rs @@ -3,11 +3,14 @@ use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline}; use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, src_loc, EcPtr, Qnil, Qtrue, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark, with_vm_lock, rust_str_to_id, rb_funcallv, rb_const_get, rb_cRubyVM}; use crate::cruby_methods; +use cruby::{ID, rb_callable_method_entry, get_def_method_serial, rb_gc_register_mark_object}; +use std::sync::atomic::Ordering; use crate::invariants::Invariants; use crate::asm::CodeBlock; use crate::options::{get_option, rb_zjit_prepare_options}; use crate::stats::{Counters, InsnCounters, SideExitLocations}; use crate::virtualmem::CodePtr; +use std::sync::atomic::AtomicUsize; use std::collections::HashMap; use std::ptr::null; @@ -301,6 +304,25 @@ impl ZJITState { } } +/// The `::RubyVM::ZJIT` module. +pub static ZJIT_MODULE: AtomicUsize = AtomicUsize::new(!0); +/// Serial of the canonical version of `RubyVM::ZJIT.induce_side_exit!` right after VM boot. +pub static INDUCE_SIDE_EXIT_SERIAL: AtomicUsize = AtomicUsize::new(!0); +/// Serial of the canonical version of `RubyVM::ZJIT.induce_compile_failure!` right after VM boot. +pub static INDUCE_COMPILE_FAILURE_SERIAL: AtomicUsize = AtomicUsize::new(!0); + +/// Check if a method, `method_id`, currently exists on `ZJIT.singleton_class` and has the `expected_serial`.
+pub fn zjit_module_method_match_serial(method_id: ID, expected_serial: &AtomicUsize) -> bool { + let zjit_module_singleton = VALUE(ZJIT_MODULE.load(Ordering::Relaxed)).class_of(); + let cme = unsafe { rb_callable_method_entry(zjit_module_singleton, method_id) }; + if cme.is_null() { + false + } else { + let serial = unsafe { get_def_method_serial((*cme).def) }; + serial == expected_serial.load(std::sync::atomic::Ordering::Relaxed) + } +} + /// Initialize IDs and annotate builtin C method entries. /// Must be called at boot before ruby_init_prelude() since the prelude /// could redefine core methods (e.g. Kernel.prepend via bundler). @@ -318,6 +340,25 @@ pub extern "C" fn rb_zjit_init_builtin_cmes() { let method_annotations = cruby_methods::init(); unsafe { ZJIT_STATE = Initialized(method_annotations); } + + // Boot time setup for compiler directives + unsafe { + let zjit_module = rb_const_get(rb_cRubyVM, rust_str_to_id("ZJIT")); + + let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_side_exit_bang)); + assert!(! cme.is_null(), "RubyVM::ZJIT.induce_side_exit! should exist on boot"); + let serial = get_def_method_serial((*cme).def) ; + INDUCE_SIDE_EXIT_SERIAL.store(serial, Ordering::Relaxed); + + let cme = rb_callable_method_entry(zjit_module.class_of(), ID!(induce_compile_failure_bang)); + assert!(! cme.is_null(), "RubyVM::ZJIT.induce_compile_failure! should exist on boot"); + let serial = get_def_method_serial((*cme).def) ; + INDUCE_COMPILE_FAILURE_SERIAL.store(serial, Ordering::Relaxed); + + // Root and pin the module since we'll be doing object identity comparisons. + ZJIT_MODULE.store(zjit_module.0, Ordering::Relaxed); + rb_gc_register_mark_object(zjit_module); + } } /// Initialize ZJIT at boot. This is called even if ZJIT is disabled. diff --git a/zjit/src/stats.rs b/zjit/src/stats.rs index db204f032f9340..47a434360da5ff 100644 --- a/zjit/src/stats.rs +++ b/zjit/src/stats.rs @@ -231,6 +231,7 @@ make_counters! 
{ exit_splatkw_not_nil_or_hash, exit_splatkw_polymorphic, exit_splatkw_not_profiled, + exit_directive_induced, } // Send fallback counters that are summed as dynamic_send_count @@ -335,6 +336,7 @@ make_counters! { compile_error_parse_stack_underflow, compile_error_parse_malformed_iseq, compile_error_parse_not_allowed, + compile_error_parse_directive_induced, compile_error_validation_block_has_no_terminator, compile_error_validation_terminator_not_at_end, compile_error_validation_mismatched_block_arity, @@ -521,6 +523,7 @@ pub fn exit_counter_for_compile_error(compile_error: &CompileError) -> Counter { StackUnderflow(_) => compile_error_parse_stack_underflow, MalformedIseq(_) => compile_error_parse_malformed_iseq, NotAllowed => compile_error_parse_not_allowed, + DirectiveInduced => compile_error_parse_directive_induced, Validation(validation) => match validation { BlockHasNoTerminator(_) => compile_error_validation_block_has_no_terminator, TerminatorNotAtEnd(_, _, _) => compile_error_validation_terminator_not_at_end, @@ -596,6 +599,7 @@ pub fn side_exit_counter(reason: crate::hir::SideExitReason) -> Counter { SplatKwNotNilOrHash => exit_splatkw_not_nil_or_hash, SplatKwPolymorphic => exit_splatkw_polymorphic, SplatKwNotProfiled => exit_splatkw_not_profiled, + DirectiveInduced => exit_directive_induced, PatchPoint(Invariant::BOPRedefined { .. }) => exit_patchpoint_bop_redefined, PatchPoint(Invariant::MethodRedefined { .. })