diff --git a/ext/rbs_extension/parser.c b/ext/rbs_extension/parser.c index 15400209e..0d019af00 100644 --- a/ext/rbs_extension/parser.c +++ b/ext/rbs_extension/parser.c @@ -966,13 +966,44 @@ static VALUE parse_simple(parserstate *state) { range rg; rg.start = state->current_token.range.start; VALUE types = rb_ary_new(); + VALUE rest_type = Qnil; + // parses type args with additional rest type arg if (state->next_token.type != pRBRACKET) { - parse_type_list(state, pRBRACKET, types); + while (true) { + rb_ary_push(types, parse_type(state)); + + if (state->next_token.type == pCOMMA) { + parser_advance(state); + if (state->next_token.type == pRBRACKET) { + break; + } else if (state->next_token.type == pSTAR) { + parser_advance(state); + rest_type = parse_type(state); + if (state->next_token.type != pRBRACKET) { + raise_syntax_error( + state, + state->next_token, + "tuple end expected" + ); + } + break; + } + } else { + if (state->next_token.type == pRBRACKET) { + break; + } + raise_syntax_error( + state, + state->next_token, + "comma delimited type list is expected" + ); + } + } } parser_advance_assert(state, pRBRACKET); rg.end = state->current_token.range.end; - return rbs_tuple(types, rbs_new_location(state->buffer, rg)); + return rbs_tuple2(types, rbs_new_location(state->buffer, rg), rest_type); } case pAREF_OPR: { return rbs_tuple(rb_ary_new(), rbs_new_location(state->buffer, state->current_token.range)); diff --git a/ext/rbs_extension/ruby_objs.c b/ext/rbs_extension/ruby_objs.c index 2a92d1c56..a7d646b07 100644 --- a/ext/rbs_extension/ruby_objs.c +++ b/ext/rbs_extension/ruby_objs.c @@ -132,6 +132,21 @@ VALUE rbs_tuple(VALUE types, VALUE location) { ); } +VALUE rbs_tuple2(VALUE types, VALUE location, VALUE rest_type) { + VALUE args = rb_hash_new(); + rb_hash_aset(args, ID2SYM(rb_intern("types")), types); + rb_hash_aset(args, ID2SYM(rb_intern("location")), location); + if (rest_type) { + rb_hash_aset(args, ID2SYM(rb_intern("rest_type")), rest_type); + } + + return CLASS_NEW_INSTANCE( + RBS_Types_Tuple, + 1, + &args + ); +} + VALUE rbs_optional(VALUE type, VALUE location) { VALUE args = rb_hash_new(); rb_hash_aset(args, ID2SYM(rb_intern("type")), type); diff --git a/ext/rbs_extension/ruby_objs.h b/ext/rbs_extension/ruby_objs.h index d9ee32262..220924f67 100644 --- a/ext/rbs_extension/ruby_objs.h +++ b/ext/rbs_extension/ruby_objs.h @@ -39,6 +39,7 @@ VALUE rbs_optional(VALUE type, VALUE location); VALUE rbs_proc(VALUE function, VALUE block, VALUE location, VALUE self_type); VALUE rbs_record(VALUE fields, VALUE location); VALUE rbs_tuple(VALUE types, VALUE location); +VALUE rbs_tuple2(VALUE types, VALUE location, VALUE rest_type); VALUE rbs_type_name(VALUE namespace, VALUE name); VALUE rbs_union(VALUE types, VALUE location); VALUE rbs_variable(VALUE name, VALUE location); diff --git a/lib/rbs/types.rb b/lib/rbs/types.rb index 9cb9a76a6..3596833b1 100644 --- a/lib/rbs/types.rb +++ b/lib/rbs/types.rb @@ -433,15 +433,17 @@ def map_type(&block) class Tuple attr_reader :types + attr_reader :rest_type attr_reader :location - def initialize(types:, location:) + def initialize(types:, location:, rest_type: nil) @types = types + @rest_type = rest_type @location = location end def ==(other) - other.is_a?(Tuple) && other.types == types + other.is_a?(Tuple) && other.types == types && other.rest_type == rest_type end alias eql? == @@ -455,15 +457,17 @@ def free_variables(set = Set.new) types.each do |type| type.free_variables set end + rest_type&.free_variables end end def to_json(state = _ = nil) - { class: :tuple, types: types, location: location }.to_json(state) + { class: :tuple, types: types, rest_type: rest_type, location: location }.to_json(state) end def sub(s) self.class.new(types: types.map {|ty| ty.sub(s) }, + rest_type: rest_type&.sub(s), location: location) end @@ -471,13 +475,14 @@ def to_s(level = 0) if types.empty? "[ ]" else - "[ #{types.join(", ")} ]" + "[ #{types.join(", ")}#{", *#{rest_type}" if rest_type} ]" end end def each_type(&block) if block types.each(&block) + block.call(rest_type) if rest_type else enum_for :each_type end @@ -486,6 +491,7 @@ def each_type(&block) def map_type_name(&block) Tuple.new( types: types.map {|type| type.map_type_name(&block) }, + rest_type: rest_type.map_type_name(&block), location: location ) end @@ -494,6 +500,7 @@ def map_type(&block) if block Tuple.new( types: types.map {|type| yield type }, + rest_type: (yield(rest_type) if rest_type), location: location ) else diff --git a/test/rbs/schema_test.rb b/test/rbs/schema_test.rb index 457c7a8b9..2d276e441 100644 --- a/test/rbs/schema_test.rb +++ b/test/rbs/schema_test.rb @@ -96,6 +96,7 @@ def test_type_schema refute_type parse_type("Foo"), "alias" assert_type parse_type("[Integer]"), "tuple" + assert_type parse_type("[Integer, *String]"), "tuple" refute_type parse_type("Foo"), "tuple" assert_type parse_type("{ id: Integer, name: String }"), "record" diff --git a/test/rbs/type_parsing_test.rb b/test/rbs/type_parsing_test.rb index 27ea589eb..813cf1b41 100644 --- a/test/rbs/type_parsing_test.rb +++ b/test/rbs/type_parsing_test.rb @@ -203,6 +203,15 @@ def test_tuple assert_equal [], type.types assert_equal "[]", type.location.source end + + Parser.parse_type("[untyped, *untyped]").yield_self do |type| + assert_instance_of Types::Tuple, type + assert_equal [ + Types::Bases::Any.new(location: nil), + ], type.types + assert_equal Types::Bases::Any.new(location: nil), type.rest_type + assert_equal "[untyped, *untyped]", type.location.source + end end def test_union_intersection