Skip to content

Commit 49db85c

Browse files
committed
fix: enhance parser to support ternary, splat args, and statement expressions
- Add ternary operator (? :) parsing in ExpressionParser - Support double splat (**opts) and single splat (*args) in method calls - Support keyword arguments (name: value) in method calls - Allow case/if/unless/begin as assignment right-hand side values - Improve generic type compatibility (Array[untyped] with Array[T]) Fixes type inference errors in keyword_args samples.
1 parent 9cc5177 commit 49db85c

File tree

3 files changed

+176
-54
lines changed

3 files changed

+176
-54
lines changed

lib/t_ruby/compiler.rb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,19 @@ def types_compatible?(inferred, declared)
319319
# Subtype relationships
320320
return true if subtype_of?(inferred, declared)
321321

322+
# Handle generic types (e.g., Array[untyped] is compatible with Array[String])
323+
if inferred.include?("[") && declared.include?("[")
324+
inferred_base = inferred.split("[").first
325+
declared_base = declared.split("[").first
326+
if inferred_base == declared_base
327+
# Extract type arguments
328+
inferred_args = inferred[/\[(.+)\]/, 1]
329+
declared_args = declared[/\[(.+)\]/, 1]
330+
# untyped type argument is compatible with any type argument
331+
return true if inferred_args == "untyped" || declared_args == "untyped"
332+
end
333+
end
334+
322335
# Handle union types in declared
323336
if declared.include?("|")
324337
declared_types = declared.split("|").map(&:strip)

lib/t_ruby/parser_combinator/token/expression_parser.rb

Lines changed: 122 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,33 @@ def parse_precedence(tokens, position, min_precedence)
8989
)
9090
end
9191

92+
# 삼항 연산자: condition ? then_branch : else_branch
93+
if pos < tokens.length && tokens[pos].type == :question
94+
pos += 1 # consume '?'
95+
96+
then_result = parse_expression(tokens, pos)
97+
return then_result if then_result.failure?
98+
99+
pos = then_result.position
100+
101+
unless tokens[pos]&.type == :colon
102+
return TokenParseResult.failure("Expected ':' in ternary operator", tokens, pos)
103+
end
104+
105+
pos += 1 # consume ':'
106+
107+
else_result = parse_expression(tokens, pos)
108+
return else_result if else_result.failure?
109+
110+
left = IR::Conditional.new(
111+
kind: :ternary,
112+
condition: left,
113+
then_branch: then_result.value,
114+
else_branch: else_result.value
115+
)
116+
pos = else_result.position
117+
end
118+
92119
TokenParseResult.success(left, tokens, pos)
93120
end
94121

@@ -101,19 +128,21 @@ def parse_unary(tokens, position)
101128
when :bang
102129
result = parse_unary(tokens, position + 1)
103130
return result if result.failure?
131+
104132
node = IR::UnaryOp.new(operator: :!, operand: result.value)
105133
TokenParseResult.success(node, tokens, result.position)
106134
when :minus
107135
result = parse_unary(tokens, position + 1)
108136
return result if result.failure?
137+
109138
# For negative number literals, we could fold them
110-
if result.value.is_a?(IR::Literal) && result.value.literal_type == :integer
111-
node = IR::Literal.new(value: -result.value.value, literal_type: :integer)
112-
elsif result.value.is_a?(IR::Literal) && result.value.literal_type == :float
113-
node = IR::Literal.new(value: -result.value.value, literal_type: :float)
114-
else
115-
node = IR::UnaryOp.new(operator: :-, operand: result.value)
116-
end
139+
node = if result.value.is_a?(IR::Literal) && result.value.literal_type == :integer
140+
IR::Literal.new(value: -result.value.value, literal_type: :integer)
141+
elsif result.value.is_a?(IR::Literal) && result.value.literal_type == :float
142+
IR::Literal.new(value: -result.value.value, literal_type: :float)
143+
else
144+
IR::UnaryOp.new(operator: :-, operand: result.value)
145+
end
117146
TokenParseResult.success(node, tokens, result.position)
118147
else
119148
parse_postfix(tokens, position)
@@ -149,6 +178,7 @@ def parse_postfix(tokens, position)
149178
if pos < tokens.length && tokens[pos].type == :lparen
150179
args_result = parse_arguments(tokens, pos)
151180
return args_result if args_result.failure?
181+
152182
args = args_result.value
153183
pos = args_result.position
154184
end
@@ -166,6 +196,7 @@ def parse_postfix(tokens, position)
166196

167197
pos = index_result.position
168198
return TokenParseResult.failure("Expected ']'", tokens, pos) unless tokens[pos]&.type == :rbracket
199+
169200
pos += 1
170201

171202
left = IR::MethodCall.new(
@@ -175,18 +206,17 @@ def parse_postfix(tokens, position)
175206
)
176207
when :lparen
177208
# Function call without explicit receiver (left is identifier -> method call)
178-
if left.is_a?(IR::VariableRef) && left.scope == :local
179-
args_result = parse_arguments(tokens, pos)
180-
return args_result if args_result.failure?
209+
break unless left.is_a?(IR::VariableRef) && left.scope == :local
210+
211+
args_result = parse_arguments(tokens, pos)
212+
return args_result if args_result.failure?
213+
214+
left = IR::MethodCall.new(
215+
method_name: left.name,
216+
arguments: args_result.value
217+
)
218+
pos = args_result.position
181219

182-
left = IR::MethodCall.new(
183-
method_name: left.name,
184-
arguments: args_result.value
185-
)
186-
pos = args_result.position
187-
else
188-
break
189-
end
190220
else
191221
break
192222
end
@@ -225,11 +255,11 @@ def parse_primary(tokens, position)
225255
node = IR::Literal.new(value: value, literal_type: :symbol)
226256
TokenParseResult.success(node, tokens, position + 1)
227257

228-
when :true
258+
when true
229259
node = IR::Literal.new(value: true, literal_type: :boolean)
230260
TokenParseResult.success(node, tokens, position + 1)
231261

232-
when :false
262+
when false
233263
node = IR::Literal.new(value: false, literal_type: :boolean)
234264
TokenParseResult.success(node, tokens, position + 1)
235265

@@ -264,6 +294,7 @@ def parse_primary(tokens, position)
264294

265295
pos = result.position
266296
return TokenParseResult.failure("Expected ')'", tokens, pos) unless tokens[pos]&.type == :rparen
297+
267298
TokenParseResult.success(result.value, tokens, pos + 1)
268299

269300
when :lbracket
@@ -281,6 +312,7 @@ def parse_primary(tokens, position)
281312

282313
def parse_arguments(tokens, position)
283314
return TokenParseResult.failure("Expected '('", tokens, position) unless tokens[position]&.type == :lparen
315+
284316
position += 1
285317

286318
args = []
@@ -291,26 +323,77 @@ def parse_arguments(tokens, position)
291323
end
292324

293325
# Parse first argument
294-
result = parse_expression(tokens, position)
326+
result = parse_argument(tokens, position)
295327
return result if result.failure?
328+
296329
args << result.value
297330
position = result.position
298331

299332
# Parse remaining arguments
300333
while tokens[position]&.type == :comma
301334
position += 1
302-
result = parse_expression(tokens, position)
335+
result = parse_argument(tokens, position)
303336
return result if result.failure?
337+
304338
args << result.value
305339
position = result.position
306340
end
307341

308342
return TokenParseResult.failure("Expected ')'", tokens, position) unless tokens[position]&.type == :rparen
343+
309344
TokenParseResult.success(args, tokens, position + 1)
310345
end
311346

347+
# Parse a single argument (handles splat, double splat, and keyword arguments)
348+
def parse_argument(tokens, position)
349+
# Double splat argument: **expr
350+
if tokens[position]&.type == :star_star
351+
position += 1
352+
expr_result = parse_expression(tokens, position)
353+
return expr_result if expr_result.failure?
354+
355+
# Wrap in a splat node (we'll use MethodCall with special name for now)
356+
node = IR::MethodCall.new(
357+
method_name: "**",
358+
arguments: [expr_result.value]
359+
)
360+
return TokenParseResult.success(node, tokens, expr_result.position)
361+
end
362+
363+
# Single splat argument: *expr
364+
if tokens[position]&.type == :star
365+
position += 1
366+
expr_result = parse_expression(tokens, position)
367+
return expr_result if expr_result.failure?
368+
369+
node = IR::MethodCall.new(
370+
method_name: "*",
371+
arguments: [expr_result.value]
372+
)
373+
return TokenParseResult.success(node, tokens, expr_result.position)
374+
end
375+
376+
# Keyword argument: name: value
377+
if tokens[position]&.type == :identifier && tokens[position + 1]&.type == :colon
378+
key_name = tokens[position].value
379+
position += 2 # skip identifier and colon
380+
381+
value_result = parse_expression(tokens, position)
382+
return value_result if value_result.failure?
383+
384+
# Create a hash pair for keyword argument
385+
key = IR::Literal.new(value: key_name.to_sym, literal_type: :symbol)
386+
node = IR::HashPair.new(key: key, value: value_result.value)
387+
return TokenParseResult.success(node, tokens, value_result.position)
388+
end
389+
390+
# Regular expression argument
391+
parse_expression(tokens, position)
392+
end
393+
312394
def parse_array_literal(tokens, position)
313395
return TokenParseResult.failure("Expected '['", tokens, position) unless tokens[position]&.type == :lbracket
396+
314397
position += 1
315398

316399
elements = []
@@ -324,6 +407,7 @@ def parse_array_literal(tokens, position)
324407
# Parse first element
325408
result = parse_expression(tokens, position)
326409
return result if result.failure?
410+
327411
elements << result.value
328412
position = result.position
329413

@@ -332,17 +416,20 @@ def parse_array_literal(tokens, position)
332416
position += 1
333417
result = parse_expression(tokens, position)
334418
return result if result.failure?
419+
335420
elements << result.value
336421
position = result.position
337422
end
338423

339424
return TokenParseResult.failure("Expected ']'", tokens, position) unless tokens[position]&.type == :rbracket
425+
340426
node = IR::ArrayLiteral.new(elements: elements)
341427
TokenParseResult.success(node, tokens, position + 1)
342428
end
343429

344430
def parse_hash_literal(tokens, position)
345431
return TokenParseResult.failure("Expected '{'", tokens, position) unless tokens[position]&.type == :lbrace
432+
346433
position += 1
347434

348435
pairs = []
@@ -356,6 +443,7 @@ def parse_hash_literal(tokens, position)
356443
# Parse first pair
357444
pair_result = parse_hash_pair(tokens, position)
358445
return pair_result if pair_result.failure?
446+
359447
pairs << pair_result.value
360448
position = pair_result.position
361449

@@ -364,11 +452,13 @@ def parse_hash_literal(tokens, position)
364452
position += 1
365453
pair_result = parse_hash_pair(tokens, position)
366454
return pair_result if pair_result.failure?
455+
367456
pairs << pair_result.value
368457
position = pair_result.position
369458
end
370459

371460
return TokenParseResult.failure("Expected '}'", tokens, position) unless tokens[position]&.type == :rbrace
461+
372462
node = IR::HashLiteral.new(pairs: pairs)
373463
TokenParseResult.success(node, tokens, position + 1)
374464
end
@@ -382,15 +472,15 @@ def parse_hash_pair(tokens, position)
382472
# Parse key expression
383473
key_result = parse_expression(tokens, position)
384474
return key_result if key_result.failure?
475+
385476
key = key_result.value
386477
position = key_result.position
387478

388479
# Expect => or :
389-
if tokens[position]&.type == :colon
390-
position += 1
391-
else
392-
return TokenParseResult.failure("Expected ':' or '=>' in hash pair", tokens, position)
393-
end
480+
return TokenParseResult.failure("Expected ':' or '=>' in hash pair", tokens, position) unless tokens[position]&.type == :colon
481+
482+
position += 1
483+
394484
end
395485

396486
# Parse value expression
@@ -419,15 +509,15 @@ def parse_interpolated_string(tokens, position)
419509
position += 1
420510
expr_result = parse_expression(tokens, position)
421511
return expr_result if expr_result.failure?
512+
422513
parts << expr_result.value
423514
position = expr_result.position
424515

425516
# Expect interpolation_end (})
426-
if tokens[position]&.type == :interpolation_end
427-
position += 1
428-
else
429-
return TokenParseResult.failure("Expected '}'", tokens, position)
430-
end
517+
return TokenParseResult.failure("Expected '}'", tokens, position) unless tokens[position]&.type == :interpolation_end
518+
519+
position += 1
520+
431521
when :string_end
432522
position += 1
433523
break

0 commit comments

Comments
 (0)