From 02a0e617e0474d78875a41b3484bef0765585160 Mon Sep 17 00:00:00 2001 From: Wind Date: Wed, 11 Mar 2026 16:26:35 +0800 Subject: [PATCH 01/12] split out AstNode --- src/ast_nodes.rs | 240 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 src/ast_nodes.rs diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs new file mode 100644 index 0000000..a72361b --- /dev/null +++ b/src/ast_nodes.rs @@ -0,0 +1,240 @@ +use nu_protocol::{ast::Expression, engine::Variable}; + +use crate::parser::PipelineId; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct NameNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct NameNode; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StringNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StringNode; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct VariableNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct VariableNode; + +#[derive(Debug, Clone)] +pub struct Block { + pub nodes: Vec, +} + +impl Block { + pub fn new(nodes: Vec) -> Block { + Block { nodes } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BlockId(pub usize); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ExpressionNode { + Int, + Float, + String(StringNodeId), + Name(NameNodeId), + Variable(VariableNodeId), + + // Booleans + True, + False, + + // Empty values + Null, + + VarRef, + + Closure { + params: Option, + block: BlockId, + }, + + Call { + parts: Vec, + }, + NamedValue { + name: NodeId, + value: NodeId, + }, + BinaryOp { + lhs: NodeId, + op: NodeId, + rhs: NodeId, + }, + Range { + lhs: NodeId, + rhs: NodeId, + }, + List(Vec), + Table { + header: NodeId, + rows: Vec, + }, + Record { + pairs: Vec<(NodeId, NodeId)>, + }, + MemberAccess { + target: NodeId, + field: NodeId, + }, + Block(BlockId), + // Pipeline is also an expression, and it contains a 
list of expressions. + Pipeline(PipelineId), + If { + condition: NodeId, + then_block: BlockId, + else_block: Option, + }, + Try { + try_block: BlockId, + catch_block: Option, + finally_block: Option, + }, + Match { + target: NodeId, + match_arms: Vec<(NodeId, NodeId)>, + }, + // Pipeline is also an expression, and it contains a list of expressions. + Pipeline(PipelineId), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ExpressionNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StatementNode { + // Definitions + Def { + name: NodeIndexer, + type_params: Option, + params: NodeId, + in_out_types: Option, + block: BlockId, + env: bool, + wrapped: bool, + }, + Extern { + name: NodeId, + params: NodeId, + }, + Alias { + new_name: NodeIndexer, + old_name: NodeIndexer, + }, + Let { + variable_name: VariableNodeId, + ty: Option, + initializer: NodeId, + is_mutable: bool, + }, + + While { + condition: NodeId, + block: BlockId, + }, + For { + variable: VariableNodeId, + range: NodeId, + block: BlockId, + }, + Loop { + block: BlockId, + }, + Return(Option), + Break, + Continue, + + Expr(ExpressionNodeId), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StatementNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct NodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NodeIndexer { + Name(NameNodeId), + String(StringNodeId), + Variable(VariableNodeId), + Expression(ExpressionNodeId), + Statement(StatementNodeId), + Block(BlockId), + General(NodeId), +} + +// TODO: All nodes with Vec<...> should be moved to their own ID (like BlockId) to allow Copy trait +#[derive(Debug, Clone, PartialEq)] +pub enum AstNode { + Type { + name: NameNodeId, + args: Option, + optional: bool, + }, + TypeArgs(Vec), + RecordType { + /// Contains [AstNode::Params] + fields: NodeId, + optional: bool, + }, + VarDecl, + + // Operators + Pow, + Multiply, + Divide, + 
FloorDiv, + Modulo, + Plus, + Minus, + Equal, + NotEqual, + LessThan, + GreaterThan, + LessThanOrEqual, + GreaterThanOrEqual, + RegexMatch, + NotRegexMatch, + In, + Append, + And, + Xor, + Or, + + // Assignments + Assignment, + AddAssignment, + SubtractAssignment, + MultiplyAssignment, + DivideAssignment, + AppendAssignment, + + Params(Vec), + Param { + name: NodeId, + ty: Option, + }, + InOutTypes(Vec), + /// Input/output type pair for a command + InOutType(NodeId, NodeId), + + /// Long flag ('--' + one or more letters) + FlagLong, + /// Short flag ('-' + single letter) + FlagShort, + /// Group of short flags ('-' + more than 1 letters) + FlagShortGroup, + + // ??? should statement belongs to AstNode? + Statement(StatementNodeId), + + Garbage, +} From 1e0a8dfd80fbed3ce6aae68496c53a157f9cd958 Mon Sep 17 00:00:00 2001 From: Wind Date: Wed, 11 Mar 2026 16:27:06 +0800 Subject: [PATCH 02/12] let's define them in compiler --- src/ast_nodes.rs | 144 +++++++++++++++++++++++++++++++++- src/compiler.rs | 197 ++++++++++++++++++++++++++++++++++++----------- src/lib.rs | 1 + 3 files changed, 297 insertions(+), 45 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index a72361b..9ec0791 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -1,3 +1,4 @@ +use super::compiler::{Compiler, Span}; use nu_protocol::{ast::Expression, engine::Variable}; use crate::parser::PipelineId; @@ -101,8 +102,6 @@ pub enum ExpressionNode { target: NodeId, match_arms: Vec<(NodeId, NodeId)>, }, - // Pipeline is also an expression, and it contains a list of expressions. 
- Pipeline(PipelineId), } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -238,3 +237,144 @@ pub enum AstNode { Garbage, } + +pub trait Tmp { + type Output; + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output; + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output; +} + +pub trait Tmp1 { + type Output; + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output; +} + +impl Tmp for NameNodeId { + type Output = NameNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.name_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.name_nodes.get_node_mut(self.0) + } +} + +impl Tmp1 for NameNode { + type Output = NameNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.name_nodes.push(span, self); + + let result = NameNodeId(compiler.name_nodes.len() - 1); + let indexer = NodeIndexer::Name(result); + compiler.indexer.push(indexer); + + result + } +} + +impl Tmp for StringNodeId { + type Output = StringNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.string_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.string_nodes.get_node_mut(self.0) + } +} + +impl Tmp1 for StringNode { + type Output = StringNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.string_nodes.push(span, self); + + let result = StringNodeId(compiler.string_nodes.len() - 1); + let indexer = NodeIndexer::String(result); + compiler.indexer.push(indexer); + + result + } +} + +impl Tmp for VariableNodeId { + type Output = VariableNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.variable_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + 
compiler.variable_nodes.get_node_mut(self.0) + } +} + +impl Tmp1 for VariableNode { + type Output = VariableNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.variable_nodes.push(span, self); + + let result = VariableNodeId(compiler.variable_nodes.len() - 1); + let indexer = NodeIndexer::Variable(result); + compiler.indexer.push(indexer); + + result + } +} + +impl Tmp for BlockId { + type Output = Block; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.blocks.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.blocks.get_node_mut(self.0) + } +} + +impl Tmp1 for Block { + type Output = BlockId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.blocks.push(span, self); + + let result = BlockId(compiler.blocks.len() - 1); + let indexer = NodeIndexer::Block(result); + compiler.indexer.push(indexer); + + result + } +} + +impl Tmp for StatementNodeId { + type Output = StatementNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.statement_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.statement_nodes.get_node_mut(self.0) + } +} + +impl Tmp1 for StatementNode { + type Output = StatementNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.statement_nodes.push(span, self); + + let result = StatementNodeId(compiler.statement_nodes.len() - 1); + let indexer = NodeIndexer::Statement(result); + compiler.indexer.push(indexer); + + result + } +} diff --git a/src/compiler.rs b/src/compiler.rs index 11136e0..4332ac2 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,5 +1,10 @@ +use crate::ast_nodes::{ + AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, + NodeIndexer, StatementNode, StatementNodeId, 
StringNode, StringNodeId, Tmp, Tmp1, VariableNode, + VariableNodeId, +}; use crate::errors::SourceError; -use crate::parser::{AstNode, Block, NodeId, Pipeline}; +use crate::parser::Pipeline; use crate::protocol::Command; use crate::resolver::{ DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable, @@ -8,8 +13,12 @@ use crate::typechecker::{TypeId, Types}; use std::collections::HashMap; pub struct RollbackPoint { - idx_span_start: usize, idx_nodes: usize, + idx_name_nodes: usize, + idx_string_nodes: usize, + idx_variable_nodes: usize, + idx_expression_nodes: usize, + idx_statment_nodes: usize, idx_errors: usize, idx_blocks: usize, token_pos: usize, @@ -39,15 +48,65 @@ impl Spanned { } } +#[derive(Clone, Debug)] +struct NodeSpans { + nodes: Vec, // indexed by relative nodeId + spans: Vec, +} + +impl NodeSpans { + pub fn new() -> Self { + Self { + nodes: vec![], + spans: vec![], + } + } + pub fn get_span(&self, i: usize) -> Span { + self.spans[i] + } + + pub fn get_node(&self, i: usize) -> &T { + &self.nodes[i] + } + + pub fn get_node_mut(&mut self, i: usize) -> &mut T { + &mut self.nodes[i] + } + + pub fn push(&mut self, span: Span, node: T) { + self.spans.push(span); + self.nodes.push(node); + } + + pub fn len(&self) -> usize { + self.nodes.len() + } + + pub fn truncate(&mut self, len: usize) { + self.nodes.truncate(len); + self.spans.truncate(len); + } + + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } +} + #[derive(Clone)] pub struct Compiler { - // Core information, indexed by NodeId: - pub spans: Vec, - pub ast_nodes: Vec, + // different types of nodes. 
+ pub name_nodes: NodeSpans, + pub string_nodes: NodeSpans, + pub variable_nodes: NodeSpans, + pub ast_nodes: NodeSpans, + pub expression_nodes: NodeSpans, + pub statement_nodes: NodeSpans, + pub indexer: Vec, + pub blocks: NodeSpans, // Blocks, indexed by BlockId + pub pipelines: Vec, // Pipelines, indexed by PipelineId + pub node_types: Vec, // node_lifetimes: Vec, - pub blocks: Vec, // Blocks, indexed by BlockId - pub pipelines: Vec, // Pipelines, indexed by PipelineId pub source: Vec, pub file_offsets: Vec<(String, usize, usize)>, // fname, start, end @@ -91,10 +150,15 @@ impl Default for Compiler { impl Compiler { pub fn new() -> Self { Self { - spans: vec![], - ast_nodes: vec![], + string_nodes: NodeSpans::new(), + variable_nodes: NodeSpans::new(), + ast_nodes: NodeSpans::new(), + name_nodes: NodeSpans::new(), + expression_nodes: NodeSpans::new(), + statement_nodes: NodeSpans::new(), node_types: vec![], - blocks: vec![], + indexer: vec![], + blocks: NodeSpans::new(), pipelines: vec![], source: vec![], file_offsets: vec![], @@ -128,20 +192,58 @@ impl Compiler { // TODO: This should say PARSER, not COMPILER let mut result = "==== COMPILER ====\n".to_string(); - for (idx, ast_node) in self.ast_nodes.iter().enumerate() { + for (idx, indexer) in self.indexer.iter().enumerate() { + let (node_dbg_string, span) = match indexer { + NodeIndexer::String(i) => ( + format!("{:?}", self.string_nodes.get_node(i.0)), + self.string_nodes.get_span(i.0), + ), + NodeIndexer::Name(i) => ( + format!("{:?}", self.name_nodes.get_node(i.0)), + self.name_nodes.get_span(i.0), + ), + NodeIndexer::Variable(i) => ( + format!("{:?}", self.variable_nodes.get_node(i.0)), + self.variable_nodes.get_span(i.0), + ), + NodeIndexer::Expression(i) => ( + format!("{:?}", self.expression_nodes.get_node(i.0)), + self.expression_nodes.get_span(i.0), + ), + NodeIndexer::Statement(i) => ( + format!("{:?}", self.statement_nodes.get_node(i.0)), + self.statement_nodes.get_span(i.0), + ), + 
NodeIndexer::General(i) => ( + format!("{:?}", self.ast_nodes.get_node(i.0)), + self.ast_nodes.get_span(i.0), + ), + NodeIndexer::Block(i) => ( + format!("{:?}", self.blocks.get_node(i.0)), + self.blocks.get_span(i.0), + ), + }; result.push_str(&format!( - "{}: {:?} ({} to {})", - idx, ast_node, self.spans[idx].start, self.spans[idx].end + "{}: {} ({} to {})", + idx, node_dbg_string, span.start, span.end )); if matches!( - ast_node, - AstNode::Name | AstNode::Variable | AstNode::Int | AstNode::Float | AstNode::String + indexer, + NodeIndexer::Name(_) | NodeIndexer::Variable(_) | NodeIndexer::String(_) ) { result.push_str(&format!( " \"{}\"", - String::from_utf8_lossy(self.get_span_contents(NodeId(idx))) + String::from_utf8_lossy(self.get_span_contents(*indexer)) )); + } else if let NodeIndexer::Expression(i) = indexer { + let node = self.expression_nodes.get_node(i.0); + if matches!(node, ExpressionNode::Int | ExpressionNode::Float) { + result.push_str(&format!( + " \"{}\"", + String::from_utf8_lossy(self.get_span_contents(*indexer)) + )); + } } result.push('\n'); @@ -151,8 +253,8 @@ impl Compiler { result.push_str("==== COMPILER ERRORS ====\n"); for error in &self.errors { result.push_str(&format!( - "{:?} (NodeId {}): {}\n", - error.severity, error.node_id.0, error.message + "{:?} (NodeId {:?}): {}\n", + error.severity, error.node_id, error.message )); } } @@ -164,12 +266,12 @@ impl Compiler { self.scope.extend(name_bindings.scope); self.scope_stack.extend(name_bindings.scope_stack); self.variables.extend(name_bindings.variables); - self.var_resolution.extend(name_bindings.var_resolution); + // self.var_resolution.extend(name_bindings.var_resolution); self.type_decls.extend(name_bindings.type_decls); - self.type_resolution.extend(name_bindings.type_resolution); + // self.type_resolution.extend(name_bindings.type_resolution); self.decls.extend(name_bindings.decls); - self.decl_nodes.extend(name_bindings.decl_nodes); - 
self.decl_resolution.extend(name_bindings.decl_resolution); + // self.decl_nodes.extend(name_bindings.decl_nodes); + // self.decl_resolution.extend(name_bindings.decl_resolution); self.errors.extend(name_bindings.errors); } @@ -191,24 +293,26 @@ impl Compiler { self.source.len() } - pub fn get_node(&self, node_id: NodeId) -> &AstNode { - &self.ast_nodes[node_id.0] + pub fn get_node(&self, node_id: T) -> &T::Output { + node_id.get_node(self) } - pub fn get_node_mut(&mut self, node_id: NodeId) -> &mut AstNode { - &mut self.ast_nodes[node_id.0] + pub fn get_node_mut(&mut self, node_id: T) -> &mut T::Output { + node_id.get_node_mut(self) } - pub fn push_node(&mut self, ast_node: AstNode) -> NodeId { - self.ast_nodes.push(ast_node); - - NodeId(self.ast_nodes.len() - 1) + pub fn push_node(&mut self, span: Span, ast_node: T) -> T::Output { + ast_node.push_node(span, self) } pub fn get_rollback_point(&self, token_pos: usize) -> RollbackPoint { RollbackPoint { - idx_span_start: self.spans.len(), idx_nodes: self.ast_nodes.len(), + idx_name_nodes: self.name_nodes.len(), + idx_string_nodes: self.string_nodes.len(), + idx_variable_nodes: self.variable_nodes.len(), + idx_expression_nodes: self.expression_nodes.len(), + idx_statment_nodes: self.statement_nodes.len(), idx_errors: self.errors.len(), idx_blocks: self.blocks.len(), token_pos, @@ -218,23 +322,30 @@ impl Compiler { pub fn apply_compiler_rollback(&mut self, rbp: RollbackPoint) -> usize { self.blocks.truncate(rbp.idx_blocks); self.ast_nodes.truncate(rbp.idx_nodes); + self.name_nodes.truncate(rbp.idx_name_nodes); + self.string_nodes.truncate(rbp.idx_string_nodes); + self.variable_nodes.truncate(rbp.idx_variable_nodes); self.errors.truncate(rbp.idx_errors); - self.spans.truncate(rbp.idx_span_start); rbp.token_pos } /// Get span of node - pub fn get_span(&self, node_id: NodeId) -> Span { - *self - .spans - .get(node_id.0) - .expect("internal error: missing span of node") + pub fn get_span(&self, node_indexer: NodeIndexer) 
-> Span { + match node_indexer { + NodeIndexer::String(i) => self.string_nodes.get_span(i.0), + NodeIndexer::Name(i) => self.name_nodes.get_span(i.0), + NodeIndexer::Variable(i) => self.variable_nodes.get_span(i.0), + NodeIndexer::General(i) => self.ast_nodes.get_span(i.0), + NodeIndexer::Expression(i) => self.expression_nodes.get_span(i.0), + NodeIndexer::Block(i) => self.blocks.get_span(i.0), + NodeIndexer::Statement(i) => self.statement_nodes.get_span(i.0), + } } /// Get the source contents of a span of a node - pub fn get_span_contents(&self, node_id: NodeId) -> &[u8] { - let span = self.get_span(node_id); + pub fn get_span_contents(&self, node_indexer: NodeIndexer) -> &[u8] { + let span = self.get_span(node_indexer); self.source .get(span.start..span.end) .expect("internal error: missing source of span") @@ -248,14 +359,14 @@ impl Compiler { } /// Get the source contents of a node - pub fn node_as_str(&self, node_id: NodeId) -> &str { - std::str::from_utf8(self.get_span_contents(node_id)) + pub fn node_as_str(&self, node_indexer: NodeIndexer) -> &str { + std::str::from_utf8(self.get_span_contents(node_indexer)) .expect("internal error: expected utf8 string") } /// Get the source contents of a node as i64 - pub fn node_as_i64(&self, node_id: NodeId) -> i64 { - self.node_as_str(node_id) + pub fn node_as_i64(&self, node_indexer: NodeIndexer) -> i64 { + self.node_as_str(node_indexer) .parse::() .expect("internal error: expected i64") } diff --git a/src/lib.rs b/src/lib.rs index 6d74421..5c9dd37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod ast_nodes; pub mod compiler; pub mod errors; pub mod ir_generator; From c763962cfb2aae79174f12185f85cb70aff7b9b6 Mon Sep 17 00:00:00 2001 From: Wind Date: Wed, 11 Mar 2026 17:01:13 +0800 Subject: [PATCH 03/12] introduce a new span_end method, also add Pipeline --- src/ast_nodes.rs | 114 +++++++++++++++++++++++- src/compiler.rs | 9 +- src/parser.rs | 222 ++--------------------------------------------- 3 files 
changed, 123 insertions(+), 222 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index 9ec0791..6bcfb4e 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -1,8 +1,6 @@ use super::compiler::{Compiler, Span}; use nu_protocol::{ast::Expression, engine::Variable}; -use crate::parser::PipelineId; - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct NameNodeId(pub usize); @@ -35,6 +33,34 @@ impl Block { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct BlockId(pub usize); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct PipelineId(pub usize); + +// Pipeline just contains a list of expressions +// +// It's not allowed if there is only one element in pipeline, in that +// case, it's just an expression. +// +// Making such restriction can reduce indirect access on expression, which +// can improve performance in parse time. +#[derive(Debug, Clone, PartialEq)] +pub struct Pipeline { + pub nodes: Vec, +} + +impl Pipeline { + pub fn new(nodes: Vec) -> Self { + debug_assert!( + nodes.len() > 1, + "a pipeline must contain at least 2 nodes, or else it's actually an expression" + ); + Self { nodes } + } + + pub fn get_expressions(&self) -> &Vec { + &self.nodes + } +} #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ExpressionNode { Int, @@ -167,6 +193,7 @@ pub enum NodeIndexer { Expression(ExpressionNodeId), Statement(StatementNodeId), Block(BlockId), + Pipeline(PipelineId), General(NodeId), } @@ -242,6 +269,7 @@ pub trait Tmp { type Output; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output; fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output; + fn get_span_end(&self, compiler: &Compiler) -> Span; } pub trait Tmp1 { @@ -259,6 +287,10 @@ impl Tmp for NameNodeId { fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { compiler.name_nodes.get_node_mut(self.0) } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + 
compiler.name_nodes.get_span(self.0) + } } impl Tmp1 for NameNode { @@ -285,6 +317,10 @@ impl Tmp for StringNodeId { fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { compiler.string_nodes.get_node_mut(self.0) } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.string_nodes.get_span(self.0) + } } impl Tmp1 for StringNode { @@ -311,6 +347,10 @@ impl Tmp for VariableNodeId { fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { compiler.variable_nodes.get_node_mut(self.0) } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.variable_nodes.get_span(self.0) + } } impl Tmp1 for VariableNode { @@ -337,6 +377,10 @@ impl Tmp for BlockId { fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { compiler.blocks.get_node_mut(self.0) } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.blocks.get_span(self.0) + } } impl Tmp1 for Block { @@ -363,6 +407,10 @@ impl Tmp for StatementNodeId { fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { compiler.statement_nodes.get_node_mut(self.0) } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.statement_nodes.get_span(self.0) + } } impl Tmp1 for StatementNode { @@ -378,3 +426,65 @@ impl Tmp1 for StatementNode { result } } + +impl Tmp for PipelineId { + type Output = Pipeline; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.pipelines.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.pipelines.get_node_mut(self.0) + } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.pipelines.get_span(self.0) + } +} + +impl Tmp1 for Pipeline { + type Output = PipelineId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.pipelines.push(span, self); + + let result = PipelineId(compiler.pipelines.len() - 1); + let 
indexer = NodeIndexer::Pipeline(result); + compiler.indexer.push(indexer); + + result + } +} + +impl Tmp for ExpressionNodeId { + type Output = ExpressionNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.expression_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.expression_nodes.get_node_mut(self.0) + } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.expression_nodes.get_span(self.0) + } +} + +impl Tmp for NodeId { + type Output = AstNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.ast_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.ast_nodes.get_node_mut(self.0) + } + + fn get_span_end(&self, compiler: &Compiler) -> Span { + compiler.ast_nodes.get_span(self.0) + } +} diff --git a/src/compiler.rs b/src/compiler.rs index 4332ac2..ef93bab 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,10 +1,9 @@ use crate::ast_nodes::{ AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, - NodeIndexer, StatementNode, StatementNodeId, StringNode, StringNodeId, Tmp, Tmp1, VariableNode, - VariableNodeId, + NodeIndexer, Pipeline, StatementNode, StatementNodeId, StringNode, StringNodeId, Tmp, Tmp1, + VariableNode, VariableNodeId, }; use crate::errors::SourceError; -use crate::parser::Pipeline; use crate::protocol::Command; use crate::resolver::{ DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable, @@ -101,9 +100,9 @@ pub struct Compiler { pub ast_nodes: NodeSpans, pub expression_nodes: NodeSpans, pub statement_nodes: NodeSpans, + pub blocks: NodeSpans, // Blocks, indexed by BlockId + pub pipelines: NodeSpans, // Pipelines, indexed by PipelineId pub indexer: Vec, - pub blocks: NodeSpans, // Blocks, indexed by BlockId - pub pipelines: Vec, // Pipelines, indexed by PipelineId 
pub node_types: Vec, // node_lifetimes: Vec, diff --git a/src/parser.rs b/src/parser.rs index 0bece02..fa6719c 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,3 +1,8 @@ +use crate::ast_nodes::{ + AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, + NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, StringNode, StringNodeId, + Tmp, Tmp1, VariableNode, VariableNodeId, +}; use crate::compiler::{Compiler, RollbackPoint, Span}; use crate::errors::{Severity, SourceError}; use crate::lexer::{Token, Tokens}; @@ -9,52 +14,6 @@ pub struct Parser { tokens: Tokens, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct NodeId(pub usize); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct BlockId(pub usize); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct PipelineId(pub usize); - -#[derive(Debug, Clone)] -pub struct Block { - pub nodes: Vec, -} - -impl Block { - pub fn new(nodes: Vec) -> Block { - Block { nodes } - } -} - -// Pipeline just contains a list of expressions -// -// It's not allowed if there is only one element in pipeline, in that -// case, it's just an expression. -// -// Making such restriction can reduce indirect access on expression, which -// can improve performance in parse time. 
-#[derive(Debug, Clone, PartialEq)] -pub struct Pipeline { - pub nodes: Vec, -} - -impl Pipeline { - pub fn new(nodes: Vec) -> Self { - debug_assert!( - nodes.len() > 1, - "a pipeline must contain at least 2 nodes, or else it's actually an expression" - ); - Self { nodes } - } - - pub fn get_expressions(&self) -> &Vec { - &self.nodes - } -} - #[derive(Debug, Clone, PartialEq)] pub enum BlockContext { /// This block is a whole block of code not wrapped in curlies (e.g., a file) @@ -96,173 +55,6 @@ impl AssignmentOrExpression { } } -// TODO: All nodes with Vec<...> should be moved to their own ID (like BlockId) to allow Copy trait -#[derive(Debug, PartialEq, Clone)] -pub enum AstNode { - Int, - Float, - String, - Name, - Type { - name: NodeId, - args: Option, - optional: bool, - }, - TypeArgs(Vec), - RecordType { - /// Contains [AstNode::Params] - fields: NodeId, - optional: bool, - }, - Variable, - - // Booleans - True, - False, - - // Empty values - Null, - - // Operators - Pow, - Multiply, - Divide, - FloorDiv, - Modulo, - Plus, - Minus, - Equal, - NotEqual, - LessThan, - GreaterThan, - LessThanOrEqual, - GreaterThanOrEqual, - RegexMatch, - NotRegexMatch, - In, - Append, - And, - Xor, - Or, - - // Assignments - Assignment, - AddAssignment, - SubtractAssignment, - MultiplyAssignment, - DivideAssignment, - AppendAssignment, - - // Statements - Let { - variable_name: NodeId, - ty: Option, - initializer: NodeId, - is_mutable: bool, - }, - While { - condition: NodeId, - block: NodeId, - }, - For { - variable: NodeId, - range: NodeId, - block: NodeId, - }, - Loop { - block: NodeId, - }, - Return(Option), - Break, - Continue, - - // Definitions - Def { - name: NodeId, - type_params: Option, - params: NodeId, - in_out_types: Option, - block: NodeId, - env: bool, - wrapped: bool, - }, - Extern { - name: NodeId, - params: NodeId, - }, - Params(Vec), - Param { - name: NodeId, - ty: Option, - }, - InOutTypes(Vec), - /// Input/output type pair for a command - InOutType(NodeId, 
NodeId), - Closure { - params: Option, - block: NodeId, - }, - Alias { - new_name: NodeId, - old_name: NodeId, - }, - - /// Long flag ('--' + one or more letters) - FlagLong, - /// Short flag ('-' + single letter) - FlagShort, - /// Group of short flags ('-' + more than 1 letters) - FlagShortGroup, - - // Expressions - Call { - parts: Vec, - }, - NamedValue { - name: NodeId, - value: NodeId, - }, - BinaryOp { - lhs: NodeId, - op: NodeId, - rhs: NodeId, - }, - Range { - lhs: NodeId, - rhs: NodeId, - }, - List(Vec), - Table { - header: NodeId, - rows: Vec, - }, - Record { - pairs: Vec<(NodeId, NodeId)>, - }, - MemberAccess { - target: NodeId, - field: NodeId, - }, - Block(BlockId), - Pipeline(PipelineId), - If { - condition: NodeId, - then_block: NodeId, - else_block: Option, - }, - Try { - try_block: NodeId, - catch_block: Option, - finally_block: Option, - }, - Match { - target: NodeId, - match_arms: Vec<(NodeId, NodeId)>, - }, - Statement(NodeId), - Garbage, -} - pub const ASSIGNMENT_PRECEDENCE: usize = 10; impl AstNode { @@ -304,8 +96,8 @@ impl Parser { self.tokens.peek_span().start } - fn get_span_end(&self, node_id: NodeId) -> usize { - self.compiler.spans[node_id.0].end + fn get_span_end(&self, node_id: T) -> usize { + node_id.get_span_end(&self.compiler).end } pub fn parse(mut self) -> Compiler { From 3a7d3c29fb4a37f77d234b09913469dd49b0ff2a Mon Sep 17 00:00:00 2001 From: Wind Date: Wed, 11 Mar 2026 17:32:21 +0800 Subject: [PATCH 04/12] parser change --- src/ast_nodes.rs | 85 ++++--- src/compiler.rs | 2 + src/errors.rs | 4 +- src/parser.rs | 593 +++++++++++++++++++++++++++++++---------------- 4 files changed, 454 insertions(+), 230 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index 6bcfb4e..a534817 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -21,11 +21,11 @@ pub struct VariableNode; #[derive(Debug, Clone)] pub struct Block { - pub nodes: Vec, + pub nodes: Vec, } impl Block { - pub fn new(nodes: Vec) -> Block { + pub fn new(nodes: 
Vec) -> Block { Block { nodes } } } @@ -84,7 +84,8 @@ pub enum ExpressionNode { }, Call { - parts: Vec, + head: Vec, + parts: Vec, }, NamedValue { name: NodeId, @@ -99,13 +100,13 @@ pub enum ExpressionNode { lhs: NodeId, rhs: NodeId, }, - List(Vec), + List(Vec), Table { - header: NodeId, - rows: Vec, + header: ExpressionNodeId, + rows: Vec, }, Record { - pairs: Vec<(NodeId, NodeId)>, + pairs: Vec<(ExpressionNodeId, ExpressionNodeId)>, }, MemberAccess { target: NodeId, @@ -115,9 +116,9 @@ pub enum ExpressionNode { // Pipeline is also an expression, and it contains a list of expressions. Pipeline(PipelineId), If { - condition: NodeId, + condition: ExpressionNodeId, then_block: BlockId, - else_block: Option, + else_block: Option, // it can be a block, or another if expression (else if) }, Try { try_block: BlockId, @@ -125,8 +126,8 @@ pub enum ExpressionNode { finally_block: Option, }, Match { - target: NodeId, - match_arms: Vec<(NodeId, NodeId)>, + target: ExpressionNodeId, + match_arms: Vec<(ExpressionNodeId, ExpressionNodeId)>, }, } @@ -137,7 +138,7 @@ pub struct ExpressionNodeId(pub usize); pub enum StatementNode { // Definitions Def { - name: NodeIndexer, + name: NodeIndexer, // can be string or name type_params: Option, params: NodeId, in_out_types: Option, @@ -146,7 +147,7 @@ pub enum StatementNode { wrapped: bool, }, Extern { - name: NodeId, + name: NodeIndexer, // can be string or name params: NodeId, }, Alias { @@ -156,17 +157,17 @@ pub enum StatementNode { Let { variable_name: VariableNodeId, ty: Option, - initializer: NodeId, + initializer: ExpressionNodeId, is_mutable: bool, }, While { - condition: NodeId, + condition: ExpressionNodeId, block: BlockId, }, For { variable: VariableNodeId, - range: NodeId, + range: ExpressionNodeId, block: BlockId, }, Loop { @@ -243,7 +244,7 @@ pub enum AstNode { DivideAssignment, AppendAssignment, - Params(Vec), + Params(Vec), Param { name: NodeId, ty: Option, @@ -269,7 +270,14 @@ pub trait Tmp { type Output; fn 
get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output; fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output; - fn get_span_end(&self, compiler: &Compiler) -> Span; + fn get_span(&self, compiler: &Compiler) -> Span; + fn get_span_contents<'a>(&self, compiler: &'a Compiler) -> &'a [u8] { + let span = self.get_span(compiler); + compiler + .source + .get(span.start..span.end) + .expect("internal error: missing source of span") + } } pub trait Tmp1 { @@ -288,7 +296,7 @@ impl Tmp for NameNodeId { compiler.name_nodes.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.name_nodes.get_span(self.0) } } @@ -318,7 +326,7 @@ impl Tmp for StringNodeId { compiler.string_nodes.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.string_nodes.get_span(self.0) } } @@ -348,7 +356,7 @@ impl Tmp for VariableNodeId { compiler.variable_nodes.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.variable_nodes.get_span(self.0) } } @@ -378,7 +386,7 @@ impl Tmp for BlockId { compiler.blocks.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.blocks.get_span(self.0) } } @@ -408,7 +416,7 @@ impl Tmp for StatementNodeId { compiler.statement_nodes.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.statement_nodes.get_span(self.0) } } @@ -438,7 +446,7 @@ impl Tmp for PipelineId { compiler.pipelines.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.pipelines.get_span(self.0) } } @@ -456,6 +464,19 @@ impl Tmp1 for Pipeline { result } } +impl 
Tmp1 for ExpressionNode { + type Output = ExpressionNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.expression_nodes.push(span, self); + + let result = ExpressionNodeId(compiler.expression_nodes.len() - 1); + let indexer = NodeIndexer::Expression(result); + compiler.indexer.push(indexer); + + result + } +} impl Tmp for ExpressionNodeId { type Output = ExpressionNode; @@ -468,11 +489,23 @@ impl Tmp for ExpressionNodeId { compiler.expression_nodes.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.expression_nodes.get_span(self.0) } } +impl Tmp1 for AstNode { + type Output = NodeId; + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.ast_nodes.push(span, self); + + let result = NodeId(compiler.ast_nodes.len() - 1); + let indexer = NodeIndexer::General(result); + compiler.indexer.push(indexer); + + result + } +} impl Tmp for NodeId { type Output = AstNode; @@ -484,7 +517,7 @@ impl Tmp for NodeId { compiler.ast_nodes.get_node_mut(self.0) } - fn get_span_end(&self, compiler: &Compiler) -> Span { + fn get_span(&self, compiler: &Compiler) -> Span { compiler.ast_nodes.get_span(self.0) } } diff --git a/src/compiler.rs b/src/compiler.rs index ef93bab..4686025 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -330,6 +330,7 @@ impl Compiler { } /// Get span of node + /// TODO: no need this. pub fn get_span(&self, node_indexer: NodeIndexer) -> Span { match node_indexer { NodeIndexer::String(i) => self.string_nodes.get_span(i.0), @@ -343,6 +344,7 @@ impl Compiler { } /// Get the source contents of a span of a node + /// TODO: no need this. 
pub fn get_span_contents(&self, node_indexer: NodeIndexer) -> &[u8] { let span = self.get_span(node_indexer); self.source diff --git a/src/errors.rs b/src/errors.rs index 6f0d562..b395c59 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,4 +1,4 @@ -use crate::parser::NodeId; +use crate::{ast_nodes::NodeIndexer, parser::NodeId}; #[derive(Debug, Clone, Copy)] pub enum Severity { @@ -9,6 +9,6 @@ pub enum Severity { #[derive(Debug, Clone)] pub struct SourceError { pub message: String, - pub node_id: NodeId, + pub node_id: NodeIndexer, pub severity: Severity, } diff --git a/src/parser.rs b/src/parser.rs index fa6719c..7aa467d 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -97,7 +97,7 @@ impl Parser { } fn get_span_end(&self, node_id: T) -> usize { - node_id.get_span_end(&self.compiler).end + node_id.get_span(&self.compiler).end } pub fn parse(mut self) -> Compiler { @@ -107,9 +107,9 @@ impl Parser { self.compiler } - pub fn expression(&mut self) -> NodeId { + pub fn expression(&mut self) -> ExpressionNodeId { let _span = span!(); - self.math_expression(false).get_node_id() + self.math_expression(false) } fn pipeline(&mut self, first_element: NodeId, span_start: usize) -> NodeId { @@ -146,7 +146,7 @@ impl Parser { self.pipeline(first_id, span_start) } - pub fn pipeline_or_expression(&mut self) -> NodeId { + pub fn pipeline_or_expression(&mut self) -> ExpressionNodeId { let _span = span!(); let span_start = self.position(); let first_id = self.expression(); @@ -157,7 +157,7 @@ impl Parser { self.pipeline(first_id, span_start) } - fn math_expression(&mut self, allow_assignment: bool) -> AssignmentOrExpression { + fn math_expression(&mut self, allow_assignment: bool) -> ExpressionNode { let _span = span!(); let mut expr_stack = Vec::<(NodeId, NodeId)>::new(); @@ -271,7 +271,7 @@ impl Parser { AssignmentOrExpression::Expression(leftmost) } - pub fn simple_expression(&mut self, bareword_context: BarewordContext) -> NodeId { + pub fn simple_expression(&mut self, 
bareword_context: BarewordContext) -> ExpressionNodeId { let _span = span!(); // skip comments and newlines @@ -376,28 +376,36 @@ impl Parser { } } - pub fn advance_node(&mut self, node: AstNode, span: Span) -> NodeId { + pub fn advance_node(&mut self, node: T, span: Span) -> T::Output { self.tokens.advance(); - self.create_node(node, span.start, span.end) + node.push_node(span, &mut self.compiler) } - pub fn variable(&mut self) -> NodeId { + pub fn variable(&mut self) -> Option { if self.is_dollar() { let span_start = self.position(); self.tokens.advance(); if let (Token::Bareword, name_span) = self.tokens.peek() { self.tokens.advance(); - self.create_node(AstNode::Variable, span_start, name_span.end) + Some(VariableNode.push_node( + Span { + start: span_start, + end: name_span.end, + }, + &mut self.compiler, + )) } else { - self.error("variable name must be a bareword") + self.error("variable name must be a bareword"); + None } } else { - self.error("expected variable starting with '$'") + self.error("expected variable starting with '$'"); + None } } - pub fn variable_decl(&mut self) -> NodeId { + pub fn variable_decl(&mut self) -> Option { let _span = span!(); let span_start = self.position(); @@ -408,15 +416,29 @@ impl Parser { if let (Token::Bareword, name_span) = self.tokens.peek() { self.tokens.advance(); - self.create_node(AstNode::Variable, span_start, name_span.end) + Some(VariableNode.push_node( + Span { + start: span_start, + end: name_span.end, + }, + &mut self.compiler, + )) } else { - self.error("variable assignment name must be a bareword") + self.error("variable assignment name must be a bareword"); + None } } - pub fn call(&mut self) -> NodeId { + pub fn advance_unchecked(&mut self, node: T) -> T::Output { + let span = self.tokens.peek_span(); + self.tokens.advance(); + node.push_node(span, &mut self.compiler) + } + + pub fn call(&mut self) -> ExpressionNodeId { let _span = span!(); - let mut parts = vec![self.call_name()]; + let mut head = 
vec![self.call_name()]; + let mut parts = vec![]; let mut is_head = true; let span_start = self.position(); @@ -426,7 +448,7 @@ impl Parser { } if self.is_name() && is_head { - parts.push(self.name()); + head.push(self.advance_unchecked(NameNode)); continue; } @@ -439,10 +461,16 @@ impl Parser { let span_end = self.position(); - self.create_node(AstNode::Call { parts }, span_start, span_end) + ExpressionNode::Call { head, parts }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn list_or_table(&mut self) -> NodeId { + pub fn list_or_table(&mut self) -> ExpressionNodeId { let _span = span!(); let span_start = self.position(); let mut is_table = false; @@ -462,15 +490,18 @@ impl Parser { } else if self.is_semicolon() { if items.len() != 1 { self.error("semicolon to create table should immediately follow headers"); - } else if !matches!(self.compiler.get_node(items[0]), AstNode::List(_)) { - self.error_on_node("tables require a list for their headers", items[0]) + } else if !matches!(self.compiler.get_node(items[0]), ExpressionNode::List(_)) { + self.error_on_node( + "tables require a list for their headers", + NodeIndexer::Expression(items[0]), + ) } self.tokens.advance(); is_table = true; } else if self.is_simple_expression() { items.push(self.simple_expression(BarewordContext::String)); } else { - items.push(self.error("expected list item")); + self.error("expected list item"); if self.is_eof() { // prevent forever looping if there is no token to put the error on break; @@ -480,20 +511,29 @@ impl Parser { if is_table { let header = items.remove(0); - self.create_node( - AstNode::Table { - header, - rows: items, + ExpressionNode::Table { + header, + rows: items, + } + .push_node( + Span { + start: span_start, + end: span_end, }, - span_start, - span_end, + &mut self.compiler, ) } else { - self.create_node(AstNode::List(items), span_start, span_end) + ExpressionNode::List(items).push_node( + Span { + start: span_start, 
+ end: span_end, + }, + &mut self.compiler, + ) } } - pub fn record_or_closure(&mut self) -> NodeId { + pub fn record_or_closure(&mut self) -> ExpressionNodeId { let _span = span!(); let span_start = self.position(); let mut span_end = self.position(); // TODO: make sure we only initialize it expectedly @@ -513,7 +553,13 @@ impl Parser { self.rcurly(); span_end = self.position(); - return self.create_node(AstNode::Closure { params, block }, span_start, span_end); + return ExpressionNode::Closure { params, block }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); } let rollback_point = self.get_rollback_point(); @@ -552,56 +598,71 @@ impl Parser { span_end = self.position(); - self.create_node( - AstNode::Closure { - params: None, - block, + ExpressionNode::Closure { + params: None, + block, + } + .push_node( + Span { + start: span_start, + end: span_end, }, - span_start, - span_end, + &mut self.compiler, ) } else { - self.create_node(AstNode::Record { pairs: items }, span_start, span_end) + ExpressionNode::Record { pairs: items }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } } - pub fn operator(&mut self) -> NodeId { + pub fn operator(&mut self) -> Option { let (token, span) = self.tokens.peek(); match token { - Token::Plus => self.advance_node(AstNode::Plus, span), - Token::PlusPlus => self.advance_node(AstNode::Append, span), - Token::Dash => self.advance_node(AstNode::Minus, span), - Token::Asterisk => self.advance_node(AstNode::Multiply, span), - Token::ForwardSlash => self.advance_node(AstNode::Divide, span), - Token::ForwardSlashForwardSlash => self.advance_node(AstNode::FloorDiv, span), - Token::LessThan => self.advance_node(AstNode::LessThan, span), - Token::LessThanEqual => self.advance_node(AstNode::LessThanOrEqual, span), - Token::GreaterThan => self.advance_node(AstNode::GreaterThan, span), - Token::GreaterThanEqual => self.advance_node(AstNode::GreaterThanOrEqual, 
span), - Token::EqualsEquals => self.advance_node(AstNode::Equal, span), - Token::ExclamationEquals => self.advance_node(AstNode::NotEqual, span), - Token::EqualsTilde => self.advance_node(AstNode::RegexMatch, span), - Token::ExclamationTilde => self.advance_node(AstNode::NotRegexMatch, span), - Token::AsteriskAsterisk => self.advance_node(AstNode::Pow, span), - Token::Equals => self.advance_node(AstNode::Assignment, span), - Token::PlusEquals => self.advance_node(AstNode::AddAssignment, span), - Token::DashEquals => self.advance_node(AstNode::SubtractAssignment, span), - Token::AsteriskEquals => self.advance_node(AstNode::MultiplyAssignment, span), - Token::ForwardSlashEquals => self.advance_node(AstNode::DivideAssignment, span), - Token::PlusPlusEquals => self.advance_node(AstNode::AppendAssignment, span), + Token::Plus => Some(self.advance_node(AstNode::Plus, span)), + Token::PlusPlus => Some(self.advance_node(AstNode::Append, span)), + Token::Dash => Some(self.advance_node(AstNode::Minus, span)), + Token::Asterisk => Some(self.advance_node(AstNode::Multiply, span)), + Token::ForwardSlash => Some(self.advance_node(AstNode::Divide, span)), + Token::ForwardSlashForwardSlash => Some(self.advance_node(AstNode::FloorDiv, span)), + Token::LessThan => Some(self.advance_node(AstNode::LessThan, span)), + Token::LessThanEqual => Some(self.advance_node(AstNode::LessThanOrEqual, span)), + Token::GreaterThan => Some(self.advance_node(AstNode::GreaterThan, span)), + Token::GreaterThanEqual => Some(self.advance_node(AstNode::GreaterThanOrEqual, span)), + Token::EqualsEquals => Some(self.advance_node(AstNode::Equal, span)), + Token::ExclamationEquals => Some(self.advance_node(AstNode::NotEqual, span)), + Token::EqualsTilde => Some(self.advance_node(AstNode::RegexMatch, span)), + Token::ExclamationTilde => Some(self.advance_node(AstNode::NotRegexMatch, span)), + Token::AsteriskAsterisk => Some(self.advance_node(AstNode::Pow, span)), + Token::Equals => 
Some(self.advance_node(AstNode::Assignment, span)), + Token::PlusEquals => Some(self.advance_node(AstNode::AddAssignment, span)), + Token::DashEquals => Some(self.advance_node(AstNode::SubtractAssignment, span)), + Token::AsteriskEquals => Some(self.advance_node(AstNode::MultiplyAssignment, span)), + Token::ForwardSlashEquals => Some(self.advance_node(AstNode::DivideAssignment, span)), + Token::PlusPlusEquals => Some(self.advance_node(AstNode::AppendAssignment, span)), Token::Bareword => match self.compiler.get_span_contents_manual(span.start, span.end) { - b"mod" => self.advance_node(AstNode::Modulo, span), - b"in" => self.advance_node(AstNode::In, span), - b"and" => self.advance_node(AstNode::And, span), - b"xor" => self.advance_node(AstNode::Xor, span), - b"or" => self.advance_node(AstNode::Or, span), - op => self.error(format!( - "Unknown operator: '{}'", - String::from_utf8_lossy(op) - )), + b"mod" => Some(self.advance_node(AstNode::Modulo, span)), + b"in" => Some(self.advance_node(AstNode::In, span)), + b"and" => Some(self.advance_node(AstNode::And, span)), + b"xor" => Some(self.advance_node(AstNode::Xor, span)), + b"or" => Some(self.advance_node(AstNode::Or, span)), + op => { + self.error(format!( + "Unknown operator: '{}'", + String::from_utf8_lossy(op) + )); + None + } }, - _ => self.error("expected: operator"), + _ => { + self.error("expected: operator"); + None + } } } @@ -616,22 +677,28 @@ impl Parser { ) } - pub fn string(&mut self) -> NodeId { + pub fn string(&mut self) -> Option { match self.tokens.peek() { - (Token::DoubleQuotedString, span) => self.advance_node(AstNode::String, span), - (Token::SingleQuotedString, span) => self.advance_node(AstNode::String, span), - _ => self.error("expected: string"), + (Token::DoubleQuotedString, span) => Some(self.advance_node(StringNode, span)), + (Token::SingleQuotedString, span) => Some(self.advance_node(StringNode, span)), + _ => { + self.error("expected: string"); + None + } } } - pub fn name(&mut self) -> 
NodeId { + pub fn name(&mut self) -> Option { match self.tokens.peek() { - (Token::Bareword, span) => self.advance_node(AstNode::Name, span), - _ => self.error("expected: name"), + (Token::Bareword, span) => Some(self.advance_node(NameNode, span)), + _ => { + self.error("expected: name"); + None + } } } - pub fn call_name(&mut self) -> NodeId { + pub fn call_name(&mut self) -> NameNodeId { let (mut token, mut span) = self.tokens.peek(); loop { @@ -651,14 +718,14 @@ impl Parser { span.end = next_span.end; } - self.create_node(AstNode::Name, span.start, span.end) + NameNode.push_node(span, &mut self.compiler) } pub fn has_tokens(&mut self) -> bool { self.tokens.peek_token() != Token::Eof } - pub fn match_expression(&mut self) -> NodeId { + pub fn match_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -669,7 +736,8 @@ impl Parser { let mut match_arms = vec![]; if !self.is_lcurly() { - return self.error("expected left curly brace '{'"); + self.error("expected left curly brace '{'"); + return None; } self.lcurly(); @@ -683,7 +751,8 @@ impl Parser { let pattern = self.simple_expression(BarewordContext::String); if !self.is_thick_arrow() { - return self.error("expected thick arrow (=>) between match cases"); + self.error("expected thick arrow (=>) between match cases"); + return None; } self.tokens.advance(); @@ -697,14 +766,21 @@ impl Parser { } else if self.is_newline() { self.tokens.advance(); } else { - return self.error("expected match arm in match"); + self.error("expected match arm in match"); + return None; } } - self.create_node(AstNode::Match { target, match_arms }, span_start, span_end) + Some(ExpressionNode::Match { target, match_arms }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn if_expression(&mut self) -> NodeId { + pub fn if_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -722,31 +798,41 
@@ impl Parser { self.skip_newlines(); let block = if self.is_keyword(b"if") { - self.if_expression() + let exp = self.if_expression()?; + span_end = self.get_span_end(exp); + NodeIndexer::Expression(self.if_expression()?) } else if self.is_keyword(b"match") { - self.match_expression() + let match_exp = self.match_expression()?; + span_end = self.get_span_end(match_exp); + NodeIndexer::Expression(match_exp) } else { - self.block(BlockContext::Curlies) + let exp = self.block(BlockContext::Curlies); + span_end = self.get_span_end(exp); + NodeIndexer::Block(exp) }; - span_end = self.get_span_end(block); Some(block) } else { span_end = self.get_span_end(then_block); None }; - self.create_node( - AstNode::If { + Some( + ExpressionNode::If { condition, then_block, else_block, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } - pub fn try_expression(&mut self) -> NodeId { + pub fn try_expression(&mut self) -> ExpressionNodeId { let _span = span!(); let span_start = self.position(); @@ -781,14 +867,17 @@ impl Parser { None }; - self.create_node( - AstNode::Try { - try_block, - catch_block, - finally_block, + ExpressionNode::Try { + try_block, + catch_block, + finally_block, + } + .push_node( + Span { + start: span_start, + end: span_end, }, - span_start, - span_end, + &mut self.compiler, ) } @@ -887,13 +976,21 @@ impl Parser { continue; } - param_list.push(self.name()); + if let Some(name) = self.name() { + param_list.push(name) + } } let span_end = self.position() + 1; self.greater_than(); - self.create_node(AstNode::Params(param_list), span_start, span_end) + AstNode::Params(param_list).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } pub fn type_args(&mut self) -> NodeId { @@ -927,11 +1024,11 @@ impl Parser { self.create_node(AstNode::TypeArgs(arg_list), span_start, span_end) } - pub fn typename(&mut self) -> NodeId { + pub fn typename(&mut self) 
-> Option { let _span = span!(); if let (Token::Bareword, span) = self.tokens.peek() { - let name = self.name(); - let name_text = self.compiler.get_span_contents(name); + let name = self.name()?; + let name_text = name.get_span_contents(&self.compiler); if name_text == b"record" { let fields = self.signature_params(ParamsContext::Angles); @@ -943,11 +1040,13 @@ impl Parser { false }; let span_end = self.position(); - return self.create_node( - AstNode::RecordType { fields, optional }, - span.start, - span_end, - ); + return Some(AstNode::RecordType { fields, optional }.push_node( + Span { + start: span.start, + end: span_end, + }, + &mut self.compiler, + )); } let mut args = None; @@ -963,33 +1062,46 @@ impl Parser { } else { false }; - self.create_node( + + Some( AstNode::Type { name, args, optional, - }, - span.start, - span.end, // FIXME: this uses the end of the name as its end + } + .push_node( + Span { + start: span.start, + end: span.end, + }, + &mut self.compiler, + ), ) } else { - self.error("expect name") + self.error("expect name"); + None } } - pub fn in_out_type(&mut self) -> NodeId { + pub fn in_out_type(&mut self) -> Option { let _span = span!(); let span_start = self.position(); - let in_ty = self.typename(); + let in_ty = self.typename()?; self.thin_arrow(); - let out_ty = self.typename(); + let out_ty = self.typename()?; let span_end = self.position(); - self.create_node(AstNode::InOutType(in_ty, out_ty), span_start, span_end) + Some(AstNode::InOutType(in_ty, out_ty).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn in_out_types(&mut self) -> NodeId { + pub fn in_out_types(&mut self) -> Option { let _span = span!(); self.colon(); @@ -1009,21 +1121,33 @@ impl Parser { continue; } - output.push(self.in_out_type()); + output.push(self.in_out_type()?); } self.rsquare(); let span_end = self.position(); - self.create_node(AstNode::InOutTypes(output), span_start, span_end) + 
Some(AstNode::InOutTypes(output).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } else { - let ty = self.in_out_type(); - let span = self.compiler.get_span(ty); - self.create_node(AstNode::InOutTypes(vec![ty]), span.start, span.end) + let ty = self.in_out_type()?; + let span = ty.get_span(&self.compiler); + Some(AstNode::InOutType(ty, ty).push_node( + Span { + start: span.start, + end: span.end, + }, + &mut self.compiler, + )) } } - pub fn def_statement(&mut self) -> NodeId { + pub fn def_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); @@ -1040,29 +1164,38 @@ impl Parser { let flag_name = self.compiler.get_span_contents_manual(span.start, span.end); if flag_name == b"env" { if has_env_flag { - return self.error("duplicated --env flag"); + self.error("duplicated --env flag"); + return None; } has_env_flag = true; } else if flag_name == b"wrapped" { if has_wrapped_flag { - return self.error("duplicated --wrapped flag"); + self.error("duplicated --wrapped flag"); + return None; } has_wrapped_flag = true } else { - return self.error("expect --env or --wrapped"); + self.error("expect --env or --wrapped"); + return None; } self.tokens.advance(); } - _ => return self.error("incomplete flag name"), + _ => { + self.error("incomplete flag name"); + return None; + } } } let name = match self.tokens.peek() { - (Token::Bareword, span) => self.advance_node(AstNode::Name, span), + (Token::Bareword, span) => NodeIndexer::Name(self.advance_node(NameNode, span)), (Token::DoubleQuotedString | Token::SingleQuotedString, span) => { - self.advance_node(AstNode::String, span) + NodeIndexer::String(self.advance_node(StringNode, span)) + } + _ => { + self.error("expected def name"); + return None; } - _ => return self.error("expected def name"), }; let type_params = if self.is_less_than() { @@ -1073,7 +1206,7 @@ impl Parser { let params = self.signature_params(ParamsContext::Squares); let in_out_types = if 
self.is_colon() { - Some(self.in_out_types()) + Some(self.in_out_types()?) } else { None }; @@ -1081,8 +1214,8 @@ impl Parser { let span_end = self.get_span_end(block); - self.create_node( - AstNode::Def { + Some( + StatementNode::Def { name, type_params, params, @@ -1090,47 +1223,61 @@ impl Parser { block, env: has_env_flag, wrapped: has_wrapped_flag, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } - pub fn extern_statement(&mut self) -> NodeId { + pub fn extern_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"extern"); let name = match self.tokens.peek() { - (Token::Bareword, span) => self.advance_node(AstNode::Name, span), + (Token::Bareword, span) => NodeIndexer::Name(self.advance_node(NameNode, span)), (Token::DoubleQuotedString | Token::SingleQuotedString, span) => { - self.advance_node(AstNode::String, span) + NodeIndexer::String(self.advance_node(StringNode, span)) + } + _ => { + self.error("expected def name"); + return None; } - _ => return self.error("expected def name"), }; let params = self.signature_params(ParamsContext::Squares); let span_end = self.position(); - self.create_node(AstNode::Extern { name, params }, span_start, span_end) + Some(StatementNode::Extern { name, params }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } // TODO: Deduplicate code between let/mut/const assignments - pub fn let_statement(&mut self) -> NodeId { + pub fn let_statement(&mut self) -> Option { let _span = span!(); let is_mutable = false; let span_start = self.position(); self.keyword(b"let"); - let variable_name = self.variable_decl(); + let variable_name = self.variable_decl()?; let ty = if self.is_colon() { // We have a type self.colon(); - Some(self.typename()) + Some(self.typename()?) 
} else { None }; @@ -1141,33 +1288,38 @@ impl Parser { let span_end = self.get_span_end(initializer); - self.create_node( - AstNode::Let { + Some( + StatementNode::Let { variable_name, ty, initializer, is_mutable, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } // TODO: Deduplicate code between let/mut/const assignments - pub fn mut_statement(&mut self) -> NodeId { + pub fn mut_statement(&mut self) -> Option { let _span = span!(); let is_mutable = true; let span_start = self.position(); self.keyword(b"mut"); - let variable_name = self.variable_decl(); + let variable_name = self.variable_decl()?; let ty = if self.is_colon() { // We have a type self.colon(); - Some(self.typename()) + Some(self.typename()?) } else { None }; @@ -1178,15 +1330,20 @@ impl Parser { let span_end = self.get_span_end(initializer); - self.create_node( - AstNode::Let { + Some( + StatementNode::Let { variable_name, ty, initializer, is_mutable, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } @@ -1202,7 +1359,7 @@ impl Parser { } } - pub fn block(&mut self, context: BlockContext) -> NodeId { + pub fn block(&mut self, context: BlockContext) -> BlockId { let _span = span!(); let span_start = self.position(); @@ -1272,7 +1429,7 @@ impl Parser { ) } - pub fn while_statement(&mut self) -> NodeId { + pub fn while_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = self.position(); self.keyword(b"while"); @@ -1287,43 +1444,60 @@ impl Parser { let block = self.block(BlockContext::Curlies); let span_end = self.get_span_end(block); - self.create_node(AstNode::While { condition, block }, span_start, span_end) + StatementNode::While { condition, block }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn for_statement(&mut self) -> NodeId { + pub fn for_statement(&mut 
self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"for"); - let variable = self.variable_decl(); + let variable = self.variable_decl()?; self.keyword(b"in"); let range = self.simple_expression(BarewordContext::String); let block = self.block(BlockContext::Curlies); let span_end = self.get_span_end(block); - self.create_node( - AstNode::For { + Some( + StatementNode::For { variable, range, block, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } - pub fn loop_statement(&mut self) -> NodeId { + pub fn loop_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = self.position(); self.keyword(b"loop"); let block = self.block(BlockContext::Curlies); let span_end = self.get_span_end(block); - self.create_node(AstNode::Loop { block }, span_start, span_end) + StatementNode::Loop { block }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn return_statement(&mut self) -> NodeId { + pub fn return_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = self.position(); let span_end; @@ -1339,44 +1513,69 @@ impl Parser { None }; - self.create_node(AstNode::Return(ret_val), span_start, span_end) + StatementNode::Return(ret_val).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn continue_statement(&mut self) -> NodeId { + pub fn continue_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = self.position(); self.keyword(b"continue"); let span_end = span_start + b"continue".len(); - self.create_node(AstNode::Continue, span_start, span_end) + StatementNode::Continue.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn break_statement(&mut self) -> NodeId { + pub fn break_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = 
self.position(); self.keyword(b"break"); let span_end = span_start + b"break".len(); - self.create_node(AstNode::Break, span_start, span_end) + StatementNode::Break.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn alias_statement(&mut self) -> NodeId { + pub fn alias_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"alias"); let new_name = if self.is_string() { - self.string() + NodeIndexer::String(self.string()?) } else { - self.name() + NodeIndexer::Name(self.name()?) }; self.equals(); - let old_name = if self.is_string() { - self.string() + let (old_name, span_end) = if self.is_string() { + let s = self.string()?; + (NodeIndexer::String(s), self.get_span_end(s)) } else { - self.name() + let s = self.name()?; + (NodeIndexer::Name(s), self.get_span_end(s)) }; - let span_end = self.get_span_end(old_name); - self.create_node(AstNode::Alias { new_name, old_name }, span_start, span_end) + Some(StatementNode::Alias { new_name, old_name }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } pub fn is_operator(&mut self) -> bool { @@ -1558,7 +1757,7 @@ impl Parser { || self.is_name() } - pub fn error_on_node(&mut self, message: impl Into, node_id: NodeId) { + pub fn error_on_node(&mut self, message: impl Into, node_id: NodeIndexer) { self.compiler.errors.push(SourceError { message: message.into(), node_id, @@ -1566,29 +1765,19 @@ impl Parser { }); } - pub fn error(&mut self, message: impl Into) -> NodeId { + pub fn error(&mut self, message: impl Into) { let (token, span) = self.tokens.peek(); if token != Token::Eof { self.tokens.advance(); } - let node_id = self.create_node(AstNode::Garbage, span.start, span.end); + let node_id = NodeIndexer::General(AstNode::Garbage.push_node(span, &mut self.compiler)); self.compiler.errors.push(SourceError { message: message.into(), node_id, severity: Severity::Error, }); - - node_id - } - - pub 
fn create_node(&mut self, ast_node: AstNode, span_start: usize, span_end: usize) -> NodeId { - self.compiler.spans.push(Span { - start: span_start, - end: span_end, - }); - self.compiler.push_node(ast_node) } pub fn lparen(&mut self) { From 90e37eefcca0c1f22587c0dbbcb18ead974c773c Mon Sep 17 00:00:00 2001 From: Wind Date: Thu, 12 Mar 2026 18:17:49 +0800 Subject: [PATCH 05/12] finish parser change --- src/ast_nodes.rs | 27 ++- src/parser.rs | 510 +++++++++++++++++++++++++++-------------------- 2 files changed, 315 insertions(+), 222 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index a534817..b10110c 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -19,13 +19,19 @@ pub struct VariableNodeId(pub usize); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct VariableNode; +#[derive(Debug, Clone)] +pub enum StatementOrExpression { + Statement(StatementNodeId), + Expression(ExpressionNodeId), +} + #[derive(Debug, Clone)] pub struct Block { - pub nodes: Vec, + pub nodes: Vec, } impl Block { - pub fn new(nodes: Vec) -> Block { + pub fn new(nodes: Vec) -> Block { Block { nodes } } } @@ -92,13 +98,13 @@ pub enum ExpressionNode { value: NodeId, }, BinaryOp { - lhs: NodeId, + lhs: ExpressionNodeId, op: NodeId, - rhs: NodeId, + rhs: ExpressionNodeId, }, Range { - lhs: NodeId, - rhs: NodeId, + lhs: ExpressionNodeId, + rhs: ExpressionNodeId, }, List(Vec), Table { @@ -109,8 +115,8 @@ pub enum ExpressionNode { pairs: Vec<(ExpressionNodeId, ExpressionNodeId)>, }, MemberAccess { - target: NodeId, - field: NodeId, + target: ExpressionNodeId, + field: ExpressionNodeId, }, Block(BlockId), // Pipeline is also an expression, and it contains a list of expressions. 
@@ -244,9 +250,10 @@ pub enum AstNode { DivideAssignment, AppendAssignment, - Params(Vec), + TypeParams(Vec), + Params(Vec), Param { - name: NodeId, + name: NameNodeId, ty: Option, }, InOutTypes(Vec), diff --git a/src/parser.rs b/src/parser.rs index 7aa467d..b92cc25 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,7 +1,7 @@ use crate::ast_nodes::{ AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, - NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, StringNode, StringNodeId, - Tmp, Tmp1, VariableNode, VariableNodeId, + NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, StatementOrExpression, + StringNode, StringNodeId, Tmp, Tmp1, VariableNode, VariableNodeId, }; use crate::compiler::{Compiler, RollbackPoint, Span}; use crate::errors::{Severity, SourceError}; @@ -102,17 +102,21 @@ impl Parser { pub fn parse(mut self) -> Compiler { let _span = span!(); - self.block(BlockContext::Bare); + let _ = self.block(BlockContext::Bare); self.compiler } - pub fn expression(&mut self) -> ExpressionNodeId { + pub fn expression(&mut self) -> Option { let _span = span!(); self.math_expression(false) } - fn pipeline(&mut self, first_element: NodeId, span_start: usize) -> NodeId { + fn pipeline( + &mut self, + first_element: ExpressionNodeId, + span_start: usize, + ) -> Option { let mut expressions = vec![first_element]; while self.is_pipe() { self.pipe(); @@ -120,46 +124,67 @@ impl Parser { if self.is_newline() { self.tokens.advance() } - expressions.push(self.expression()); + expressions.push(self.expression()?); } - self.compiler.pipelines.push(Pipeline::new(expressions)); let span_end = self.position(); - self.create_node( - AstNode::Pipeline(PipelineId(self.compiler.pipelines.len() - 1)), - span_start, - span_end, - ) + let pipeline_id = Pipeline::new(expressions).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); + // pipeline itself is an expression, so we push an 
expression node for it. + // It may make more overhead but it simpifies this `pipeline` interface. + Some(ExpressionNode::Pipeline(pipeline_id).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn pipeline_or_expression_or_assignment(&mut self) -> NodeId { + + pub fn pipeline_or_expression_or_assignment(&mut self) -> Option { // get the first expression let _span = span!(); let span_start = self.position(); - let first = self.math_expression(true); - let first_id = first.get_node_id(); - if let AssignmentOrExpression::Assignment(_) = &first { - return first_id; + let first = self.math_expression(true)?; + if let ExpressionNode::BinaryOp { op, .. } = self.compiler.get_node(first) { + if matches!( + self.compiler.get_node(*op), + AstNode::Assignment + | AstNode::AddAssignment + | AstNode::SubtractAssignment + | AstNode::MultiplyAssignment + | AstNode::DivideAssignment + | AstNode::AppendAssignment + ) { + return Some(first); + } } // pipeline with one element is an expression actually if !self.is_pipe() { - return first_id; + return Some(first); } - self.pipeline(first_id, span_start) + self.pipeline(first, span_start) } - pub fn pipeline_or_expression(&mut self) -> ExpressionNodeId { + // Can be a pipeline or expression + pub fn pipeline_or_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); - let first_id = self.expression(); + let first_id = self.expression()?; // pipeline with one element is an expression actually. 
if !self.is_pipe() { - return first_id; + return Some(first_id); } self.pipeline(first_id, span_start) } - fn math_expression(&mut self, allow_assignment: bool) -> ExpressionNode { + fn math_expression(&mut self, allow_assignment: bool) -> Option { let _span = span!(); - let mut expr_stack = Vec::<(NodeId, NodeId)>::new(); + let mut expr_stack = Vec::<(NodeId, ExpressionNodeId)>::new(); let mut last_prec = 1000000; @@ -167,63 +192,69 @@ impl Parser { // Check for special forms if self.is_keyword(b"if") { - return AssignmentOrExpression::Expression(self.if_expression()); + return self.if_expression(); } else if self.is_keyword(b"match") { - return AssignmentOrExpression::Expression(self.match_expression()); + return self.match_expression(); } else if self.is_keyword(b"try") { - return AssignmentOrExpression::Expression(self.try_expression()); + return self.try_expression(); } // TODO // } else if self.is_keyword(b"where") { // } // Otherwise assume a math expression - let mut leftmost = self.simple_expression(BarewordContext::Call); + let mut leftmost = self.simple_expression(BarewordContext::Call)?; if self.is_equals() { if !allow_assignment { self.error("assignment found in expression"); } - let op = self.operator(); + let op = self.operator()?; - let rhs = self.pipeline_or_expression(); + let rhs = self.pipeline_or_expression()?; let span_end = self.get_span_end(rhs); - return AssignmentOrExpression::Assignment(self.create_node( - AstNode::BinaryOp { + return Some( + ExpressionNode::BinaryOp { lhs: leftmost, op, rhs, - }, - span_start, - span_end, - )); + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), + ); } while self.has_tokens() { if self.is_operator() { let missing_space_before_op = !self.is_horizontal_space(); - let op = self.operator(); + let op = self.operator()?; let missing_space_after_op = !self.is_horizontal_space(); if missing_space_before_op { - self.error_on_node("missing space before operator", op); + 
self.error_on_node("missing space before operator", NodeIndexer::General(op)); } if missing_space_after_op { - self.error_on_node("missing space after operator", op); + self.error_on_node("missing space after operator", NodeIndexer::General(op)); } let op_prec = self.operator_precedence(op); if op_prec == ASSIGNMENT_PRECEDENCE && !allow_assignment { - self.error_on_node("assignment found in expression", op); + self.error_on_node("assignment found in expression", NodeIndexer::General(op)); } let rhs = if self.is_simple_expression() { - self.simple_expression(BarewordContext::Call) + self.simple_expression(BarewordContext::Call)? } else { - self.error("incomplete math expression") + self.error("incomplete math expression"); + return None; }; while op_prec <= last_prec { @@ -241,10 +272,12 @@ impl Parser { let lhs = expr_stack.last_mut().map_or(&mut leftmost, |l| &mut l.1); let (span_start, span_end) = self.spanning(*lhs, rhs); - *lhs = self.create_node( - AstNode::BinaryOp { lhs: *lhs, op, rhs }, - span_start, - span_end, + *lhs = ExpressionNode::BinaryOp { lhs: *lhs, op, rhs }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, ); } @@ -261,17 +294,22 @@ impl Parser { let (span_start, span_end) = self.spanning(*lhs, rhs); - *lhs = self.create_node( - AstNode::BinaryOp { lhs: *lhs, op, rhs }, - span_start, - span_end, + *lhs = ExpressionNode::BinaryOp { lhs: *lhs, op, rhs }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, ); } - AssignmentOrExpression::Expression(leftmost) + Some(leftmost) } - pub fn simple_expression(&mut self, bareword_context: BarewordContext) -> ExpressionNodeId { + pub fn simple_expression( + &mut self, + bareword_context: BarewordContext, + ) -> Option { let _span = span!(); // skip comments and newlines @@ -288,38 +326,51 @@ impl Parser { Token::LParen => { self.tokens.advance(); if self.tokens.peek_token() == Token::RParen { - self.error("use null instead of ()") + 
self.error("use null instead of ()"); + return None; } else { - let output = self.expression(); + let output = self.expression()?; self.rparen(); output } } Token::LSquare => self.list_or_table(), - Token::Int => self.advance_node(AstNode::Int, span), - Token::Float => self.advance_node(AstNode::Float, span), - Token::DoubleQuotedString => self.advance_node(AstNode::String, span), - Token::SingleQuotedString => self.advance_node(AstNode::String, span), - Token::Dollar => self.variable(), + Token::Int => self.advance_node(ExpressionNode::Int, span), + Token::Float => self.advance_node(ExpressionNode::Float, span), + Token::DoubleQuotedString => { + let string_node_id = self.advance_node(StringNode, span); + self.advance_node(ExpressionNode::String(string_node_id), span) + } + Token::SingleQuotedString => { + let string_node_id = self.advance_node(StringNode, span); + self.advance_node(ExpressionNode::String(string_node_id), span) + } + Token::Dollar => { + let var_id = self.variable()?; + self.advance_node(ExpressionNode::Variable(var_id), span) + } Token::Bareword => match self.compiler.get_span_contents_manual(span.start, span.end) { - b"true" => self.advance_node(AstNode::True, span), - b"false" => self.advance_node(AstNode::False, span), - b"null" => self.advance_node(AstNode::Null, span), + b"true" => self.advance_node(ExpressionNode::True, span), + b"false" => self.advance_node(ExpressionNode::False, span), + b"null" => self.advance_node(ExpressionNode::Null, span), _ => match bareword_context { BarewordContext::String => { - let node_id = self.name(); - self.compiler.ast_nodes[node_id.0] = AstNode::String; - node_id + // it's a string, so just make a string. 
+ let string_node_id = self.advance_node(StringNode, span); + self.advance_node(ExpressionNode::String(string_node_id), span) } BarewordContext::Call => self.call(), }, }, - _ => self.error("incomplete expression"), + _ => { + self.error("incomplete expression"); + return None; + } }; loop { if self.is_horizontal_space() { - return expr; + return Some(expr); } else if self.is_dotdot() { // Range self.tokens.advance(); @@ -329,13 +380,18 @@ impl Parser { // // TODO: tweak the garbage location. self.error("incomplete range"); - return expr; + return Some(expr); } else { - let rhs = self.simple_expression(BarewordContext::String); + let rhs = self.simple_expression(BarewordContext::String)?; let span_end = self.get_span_end(rhs); - expr = - self.create_node(AstNode::Range { lhs: expr, rhs }, span_start, span_end); + ExpressionNode::Range { lhs: expr, rhs }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); } } else if self.is_dot() { // Member access @@ -343,35 +399,35 @@ impl Parser { if self.is_horizontal_space() { self.error("missing path name"); - return expr; + return Some(expr); } - let name = self.name(); + let name = self.name()?; let field_or_call = if self.is_lparen() { - self.variable() + let var_id = self.variable()?; + self.advance_node( + ExpressionNode::Variable(var_id), + name.get_span(&self.compiler), + ) } else { - name + self.advance_node(ExpressionNode::Name(name), name.get_span(&self.compiler)) }; let span_end = self.get_span_end(field_or_call); - match self.compiler.get_node_mut(field_or_call) { - AstNode::Variable | AstNode::Name => { - expr = self.create_node( - AstNode::MemberAccess { - target: expr, - field: field_or_call, - }, - span_start, - span_end, - ); - } - _ => { - self.error("expected field"); - } + expr = ExpressionNode::MemberAccess { + target: expr, + field: field_or_call, } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); } else { - return expr; + 
return Some(expr); } } } @@ -435,7 +491,7 @@ impl Parser { node.push_node(span, &mut self.compiler) } - pub fn call(&mut self) -> ExpressionNodeId { + pub fn call(&mut self) -> Option { let _span = span!(); let mut head = vec![self.call_name()]; let mut parts = vec![]; @@ -455,22 +511,22 @@ impl Parser { // TODO: Add flags is_head = false; - let arg_id = self.simple_expression(BarewordContext::String); + let arg_id = self.simple_expression(BarewordContext::String)?; parts.push(arg_id); } let span_end = self.position(); - ExpressionNode::Call { head, parts }.push_node( + Some(ExpressionNode::Call { head, parts }.push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ) + )) } - pub fn list_or_table(&mut self) -> ExpressionNodeId { + pub fn list_or_table(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let mut is_table = false; @@ -499,7 +555,7 @@ impl Parser { self.tokens.advance(); is_table = true; } else if self.is_simple_expression() { - items.push(self.simple_expression(BarewordContext::String)); + items.push(self.simple_expression(BarewordContext::String)?); } else { self.error("expected list item"); if self.is_eof() { @@ -511,29 +567,31 @@ impl Parser { if is_table { let header = items.remove(0); - ExpressionNode::Table { - header, - rows: items, - } - .push_node( - Span { - start: span_start, - end: span_end, - }, - &mut self.compiler, + Some( + ExpressionNode::Table { + header, + rows: items, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } else { - ExpressionNode::List(items).push_node( + Some(ExpressionNode::List(items).push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ) + )) } } - pub fn record_or_closure(&mut self) -> ExpressionNodeId { + pub fn record_or_closure(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let mut span_end = self.position(); // TODO: make sure we only initialize it expectedly 
@@ -548,18 +606,18 @@ impl Parser { // Explicit closure case if self.is_pipe() { - let params = Some(self.signature_params(ParamsContext::Pipes)); - let block = self.block(BlockContext::Closure); + let params = self.signature_params(ParamsContext::Pipes); + let block = self.block(BlockContext::Closure)?; self.rcurly(); span_end = self.position(); - return ExpressionNode::Closure { params, block }.push_node( + return Some(ExpressionNode::Closure { params, block }.push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ); + )); } let rollback_point = self.get_rollback_point(); @@ -570,7 +628,7 @@ impl Parser { span_end = self.position(); break; } - let key = self.simple_expression(BarewordContext::String); + let key = self.simple_expression(BarewordContext::String)?; self.skip_newlines(); if first_pass && !self.is_colon() { is_closure = true; @@ -578,7 +636,7 @@ impl Parser { } self.colon(); self.skip_newlines(); - let val = self.simple_expression(BarewordContext::String); + let val = self.simple_expression(BarewordContext::String)?; items.push((key, val)); first_pass = false; @@ -593,30 +651,32 @@ impl Parser { if is_closure { self.apply_rollback(rollback_point); - let block = self.block(BlockContext::Closure); + let block = self.block(BlockContext::Closure)?; self.rcurly(); span_end = self.position(); - ExpressionNode::Closure { - params: None, - block, - } - .push_node( - Span { - start: span_start, - end: span_end, - }, - &mut self.compiler, + Some( + ExpressionNode::Closure { + params: None, + block, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } else { - ExpressionNode::Record { pairs: items }.push_node( + Some(ExpressionNode::Record { pairs: items }.push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ) + )) } } @@ -670,10 +730,10 @@ impl Parser { self.compiler.get_node(operator).precedence() } - pub fn spanning(&mut self, from: NodeId, to: NodeId) -> (usize, 
usize) { + pub fn spanning(&mut self, from: T, to: T) -> (usize, usize) { ( - self.compiler.spans[from.0].start, - self.compiler.spans[to.0].end, + from.get_span(&self.compiler).start, + to.get_span(&self.compiler).end, ) } @@ -731,7 +791,7 @@ impl Parser { let span_end; self.keyword(b"match"); - let target = self.simple_expression(BarewordContext::String); + let target = self.simple_expression(BarewordContext::String)?; let mut match_arms = vec![]; @@ -748,7 +808,7 @@ impl Parser { self.rcurly(); break; } else if self.is_simple_expression() { - let pattern = self.simple_expression(BarewordContext::String); + let pattern = self.simple_expression(BarewordContext::String)?; if !self.is_thick_arrow() { self.error("expected thick arrow (=>) between match cases"); @@ -756,7 +816,7 @@ impl Parser { } self.tokens.advance(); - let pattern_result = self.simple_expression(BarewordContext::String); + let pattern_result = self.simple_expression(BarewordContext::String)?; if self.is_comma() { self.tokens.advance(); @@ -787,10 +847,10 @@ impl Parser { self.keyword(b"if"); - let condition = self.expression(); + let condition = self.expression()?; self.skip_newlines(); - let then_block = self.block(BlockContext::Curlies); + let then_block = self.block(BlockContext::Curlies)?; self.skip_newlines(); let else_block = if self.is_keyword(b"else") { @@ -806,7 +866,7 @@ impl Parser { span_end = self.get_span_end(match_exp); NodeIndexer::Expression(match_exp) } else { - let exp = self.block(BlockContext::Curlies); + let exp = self.block(BlockContext::Curlies)?; span_end = self.get_span_end(exp); NodeIndexer::Block(exp) }; @@ -832,13 +892,13 @@ impl Parser { ) } - pub fn try_expression(&mut self) -> ExpressionNodeId { + pub fn try_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"try"); - let try_block = self.block(BlockContext::Curlies); + let try_block = self.block(BlockContext::Curlies)?; let mut span_end = 
self.get_span_end(try_block); self.skip_newlines(); @@ -847,7 +907,7 @@ impl Parser { self.tokens.advance(); self.skip_newlines(); - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; span_end = self.get_span_end(block); Some(block) @@ -860,30 +920,32 @@ impl Parser { self.tokens.advance(); self.skip_newlines(); - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; span_end = self.get_span_end(block); Some(block) } else { None }; - ExpressionNode::Try { - try_block, - catch_block, - finally_block, - } - .push_node( - Span { - start: span_start, - end: span_end, - }, - &mut self.compiler, + Some( + ExpressionNode::Try { + try_block, + catch_block, + finally_block, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } // directly ripped from `type_params` just changed delimiters // FIXME: simplify if appropriate - pub fn signature_params(&mut self, params_context: ParamsContext) -> NodeId { + pub fn signature_params(&mut self, params_context: ParamsContext) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -920,26 +982,31 @@ impl Parser { continue; } - let name = self.name(); + let name = self.name()?; let ty = if self.is_colon() { // We have a type self.colon(); - Some(self.typename()) + Some(self.typename()?) 
} else { None }; - let name_span = self.compiler.spans[name.0]; + let name_span = name.get_span(&self.compiler); let param_span_end = if let Some(ty_id) = ty { - self.compiler.spans[ty_id.0].end + ty_id.get_span(&self.compiler).end } else { name_span.end }; - let param = - self.create_node(AstNode::Param { name, ty }, name_span.start, param_span_end); + let param = AstNode::Param { name, ty }.push_node( + Span { + start: name_span.start, + end: param_span_end, + }, + &mut self.compiler, + ); // output.push(self.name()); output.push(param); @@ -956,7 +1023,13 @@ impl Parser { output }; - self.create_node(AstNode::Params(param_list), span_start, span_end) + Some(AstNode::Params(param_list).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } pub fn type_params(&mut self) -> NodeId { @@ -984,7 +1057,7 @@ impl Parser { let span_end = self.position() + 1; self.greater_than(); - AstNode::Params(param_list).push_node( + AstNode::TypeParams(param_list).push_node( Span { start: span_start, end: span_end, @@ -993,7 +1066,7 @@ impl Parser { ) } - pub fn type_args(&mut self) -> NodeId { + pub fn type_args(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -1012,7 +1085,7 @@ impl Parser { continue; } - output.push(self.typename()); + output.push(self.typename()?); } span_end = self.position() + 1; @@ -1021,7 +1094,13 @@ impl Parser { output }; - self.create_node(AstNode::TypeArgs(arg_list), span_start, span_end) + Some(AstNode::TypeArgs(arg_list).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } pub fn typename(&mut self) -> Option { @@ -1031,7 +1110,7 @@ impl Parser { let name_text = name.get_span_contents(&self.compiler); if name_text == b"record" { - let fields = self.signature_params(ParamsContext::Angles); + let fields = self.signature_params(ParamsContext::Angles)?; let optional = if self.is_question_mark() { // We have an optional type 
self.tokens.advance(); @@ -1052,7 +1131,7 @@ impl Parser { let mut args = None; if self.is_less_than() { // We have generics - args = Some(self.type_args()); + args = Some(self.type_args()?); } let optional = if self.is_question_mark() { @@ -1204,13 +1283,13 @@ impl Parser { None }; - let params = self.signature_params(ParamsContext::Squares); + let params = self.signature_params(ParamsContext::Squares)?; let in_out_types = if self.is_colon() { Some(self.in_out_types()?) } else { None }; - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); @@ -1251,7 +1330,7 @@ impl Parser { } }; - let params = self.signature_params(ParamsContext::Squares); + let params = self.signature_params(ParamsContext::Squares)?; let span_end = self.position(); Some(StatementNode::Extern { name, params }.push_node( @@ -1284,7 +1363,7 @@ impl Parser { self.equals(); - let initializer = self.pipeline_or_expression(); + let initializer = self.pipeline_or_expression()?; let span_end = self.get_span_end(initializer); @@ -1326,7 +1405,7 @@ impl Parser { self.equals(); - let initializer = self.pipeline_or_expression(); + let initializer = self.pipeline_or_expression()?; let span_end = self.get_span_end(initializer); @@ -1359,7 +1438,7 @@ impl Parser { } } - pub fn block(&mut self, context: BlockContext) -> BlockId { + pub fn block(&mut self, context: BlockContext) -> Option { let _span = span!(); let span_start = self.position(); @@ -1368,6 +1447,9 @@ impl Parser { self.lcurly(); } + // NOTE: Here we make early returns on statements, so there is only one error pop up. + // We can also consider parsing as much as possible and collect all errors, but that might + // cause a lot of cascading errors, so let's see how this goes first. 
while self.has_tokens() { if self.is_rcurly() && context == BlockContext::Curlies { self.rcurly(); @@ -1379,57 +1461,61 @@ impl Parser { self.tokens.advance(); continue; } else if self.is_keyword(b"def") { - code_body.push(self.def_statement()); + code_body.push(StatementOrExpression::Statement(self.def_statement()?)); } else if self.is_keyword(b"let") { - code_body.push(self.let_statement()); + code_body.push(StatementOrExpression::Statement(self.let_statement()?)); } else if self.is_keyword(b"mut") { - code_body.push(self.mut_statement()); + code_body.push(StatementOrExpression::Statement(self.mut_statement()?)); } else if self.is_keyword(b"while") { - code_body.push(self.while_statement()); + code_body.push(StatementOrExpression::Statement(self.while_statement()?)); } else if self.is_keyword(b"for") { - code_body.push(self.for_statement()); + code_body.push(StatementOrExpression::Statement(self.for_statement()?)); } else if self.is_keyword(b"loop") { - code_body.push(self.loop_statement()); + code_body.push(StatementOrExpression::Statement(self.loop_statement()?)); } else if self.is_keyword(b"return") { - code_body.push(self.return_statement()); + code_body.push(StatementOrExpression::Statement(self.return_statement()?)); } else if self.is_keyword(b"continue") { - code_body.push(self.continue_statement()); + code_body.push(StatementOrExpression::Statement(self.continue_statement())); } else if self.is_keyword(b"break") { - code_body.push(self.break_statement()); + code_body.push(StatementOrExpression::Statement(self.break_statement())); } else if self.is_keyword(b"alias") { - code_body.push(self.alias_statement()); + code_body.push(StatementOrExpression::Statement(self.alias_statement()?)); } else if self.is_keyword(b"extern") { - code_body.push(self.extern_statement()); + code_body.push(StatementOrExpression::Statement(self.extern_statement()?)); } else { let exp_span_start = self.position(); - let pipeline = self.pipeline_or_expression_or_assignment(); + let 
pipeline = self.pipeline_or_expression_or_assignment()?; let exp_span_end = self.get_span_end(pipeline); if self.is_semicolon() { // This is a statement, not an expression self.tokens.advance(); - code_body.push(self.create_node( - AstNode::Statement(pipeline), - exp_span_start, - exp_span_end, - )) + code_body.push(StatementOrExpression::Statement( + StatementNode::Expr(pipeline).push_node( + Span { + start: exp_span_start, + end: exp_span_end, + }, + &mut self.compiler, + ), + )); } else { - code_body.push(pipeline); + code_body.push(StatementOrExpression::Expression(pipeline)); } } } - self.compiler.blocks.push(Block::new(code_body)); let span_end = self.position(); - - self.create_node( - AstNode::Block(BlockId(self.compiler.blocks.len() - 1)), - span_start, - span_end, - ) + Some(Block { nodes: code_body }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn while_statement(&mut self) -> StatementNodeId { + pub fn while_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"while"); @@ -1440,17 +1526,17 @@ impl Parser { self.tokens.advance(); } - let condition = self.expression(); - let block = self.block(BlockContext::Curlies); + let condition = self.expression()?; + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); - StatementNode::While { condition, block }.push_node( + Some(StatementNode::While { condition, block }.push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ) + )) } pub fn for_statement(&mut self) -> Option { @@ -1461,8 +1547,8 @@ impl Parser { let variable = self.variable_decl()?; self.keyword(b"in"); - let range = self.simple_expression(BarewordContext::String); - let block = self.block(BlockContext::Curlies); + let range = self.simple_expression(BarewordContext::String)?; + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); Some( @@ -1481,23 +1567,23 
@@ impl Parser { ) } - pub fn loop_statement(&mut self) -> StatementNodeId { + pub fn loop_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"loop"); - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); - StatementNode::Loop { block }.push_node( + Some(StatementNode::Loop { block }.push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ) + )) } - pub fn return_statement(&mut self) -> StatementNodeId { + pub fn return_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -1505,7 +1591,7 @@ impl Parser { self.keyword(b"return"); let ret_val = if self.is_expression() { - let expr = self.expression(); + let expr = self.expression()?; span_end = self.get_span_end(expr); Some(expr) } else { @@ -1513,13 +1599,13 @@ impl Parser { None }; - StatementNode::Return(ret_val).push_node( + Some(StatementNode::Return(ret_val).push_node( Span { start: span_start, end: span_end, }, &mut self.compiler, - ) + )) } pub fn continue_statement(&mut self) -> StatementNodeId { From ceb8a658f1a32ffd37f0cc5fdec5a4b4c237ba03 Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Fri, 13 Mar 2026 22:26:14 +0800 Subject: [PATCH 06/12] make some minor fix --- src/compiler.rs | 9 +++++++-- src/parser.rs | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index 4686025..15f36fb 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -48,7 +48,7 @@ impl Spanned { } #[derive(Clone, Debug)] -struct NodeSpans { +pub struct NodeSpans { nodes: Vec, // indexed by relative nodeId spans: Vec, } @@ -155,10 +155,10 @@ impl Compiler { name_nodes: NodeSpans::new(), expression_nodes: NodeSpans::new(), statement_nodes: NodeSpans::new(), + pipelines: NodeSpans::new(), node_types: vec![], indexer: vec![], blocks: NodeSpans::new(), - pipelines: vec![], source: 
vec![], file_offsets: vec![], @@ -221,6 +221,10 @@ impl Compiler { format!("{:?}", self.blocks.get_node(i.0)), self.blocks.get_span(i.0), ), + NodeIndexer::Pipeline(i) => ( + format!("{:?}", self.pipelines.get_node(i.0)), + self.pipelines.get_span(i.0), + ), }; result.push_str(&format!( "{}: {} ({} to {})", @@ -340,6 +344,7 @@ impl Compiler { NodeIndexer::Expression(i) => self.expression_nodes.get_span(i.0), NodeIndexer::Block(i) => self.blocks.get_span(i.0), NodeIndexer::Statement(i) => self.statement_nodes.get_span(i.0), + NodeIndexer::Pipeline(i) => self.pipelines.get_span(i.0), } } diff --git a/src/parser.rs b/src/parser.rs index b92cc25..11f175c 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -359,7 +359,7 @@ impl Parser { let string_node_id = self.advance_node(StringNode, span); self.advance_node(ExpressionNode::String(string_node_id), span) } - BarewordContext::Call => self.call(), + BarewordContext::Call => self.call()?, }, }, _ => { From 92c0d22c36ae1b423a4cf395a381aafbe2582a2f Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Fri, 13 Mar 2026 22:26:31 +0800 Subject: [PATCH 07/12] introduce a new NameOrString --- src/ast_nodes.rs | 53 +++++++++++++++++++++++++++++++++++++++++++----- src/parser.rs | 26 ++++++++++++------------ 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index b10110c..3b12780 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -19,12 +19,22 @@ pub struct VariableNodeId(pub usize); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct VariableNode; -#[derive(Debug, Clone)] +// A helper enum for block compoments. Compiler doesn't save +// this as an individual id. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum StatementOrExpression { Statement(StatementNodeId), Expression(ExpressionNodeId), } +// A helper enum for block compoments. Compiler doesn't save +// this as an individual id. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NameOrString { + Name(NameNodeId), + String(StringNodeId), +} + #[derive(Debug, Clone)] pub struct Block { pub nodes: Vec, @@ -144,7 +154,7 @@ pub struct ExpressionNodeId(pub usize); pub enum StatementNode { // Definitions Def { - name: NodeIndexer, // can be string or name + name: NameOrString, type_params: Option, params: NodeId, in_out_types: Option, @@ -153,12 +163,12 @@ pub enum StatementNode { wrapped: bool, }, Extern { - name: NodeIndexer, // can be string or name + name: NameOrString, params: NodeId, }, Alias { - new_name: NodeIndexer, - old_name: NodeIndexer, + new_name: NameOrString, + old_name: NameOrString, }, Let { variable_name: VariableNodeId, @@ -285,6 +295,7 @@ pub trait Tmp { .get(span.start..span.end) .expect("internal error: missing source of span") } + fn into_indexer(self) -> NodeIndexer; } pub trait Tmp1 { @@ -306,6 +317,10 @@ impl Tmp for NameNodeId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.name_nodes.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Name(self) + } } impl Tmp1 for NameNode { @@ -336,6 +351,10 @@ impl Tmp for StringNodeId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.string_nodes.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::String(self) + } } impl Tmp1 for StringNode { @@ -366,6 +385,10 @@ impl Tmp for VariableNodeId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.variable_nodes.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Variable(self) + } } impl Tmp1 for VariableNode { @@ -396,6 +419,10 @@ impl Tmp for BlockId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.blocks.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Block(self) + } } impl Tmp1 for Block { @@ -426,6 +453,10 @@ impl Tmp for StatementNodeId { fn get_span(&self, compiler: &Compiler) -> Span { 
compiler.statement_nodes.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Statement(self) + } } impl Tmp1 for StatementNode { @@ -456,6 +487,10 @@ impl Tmp for PipelineId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.pipelines.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Pipeline(self) + } } impl Tmp1 for Pipeline { @@ -499,6 +534,10 @@ impl Tmp for ExpressionNodeId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.expression_nodes.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Expression(self) + } } impl Tmp1 for AstNode { type Output = NodeId; @@ -527,4 +566,8 @@ impl Tmp for NodeId { fn get_span(&self, compiler: &Compiler) -> Span { compiler.ast_nodes.get_span(self.0) } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::General(self) + } } diff --git a/src/parser.rs b/src/parser.rs index 11f175c..4498a42 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,7 +1,7 @@ use crate::ast_nodes::{ - AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, - NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, StatementOrExpression, - StringNode, StringNodeId, Tmp, Tmp1, VariableNode, VariableNodeId, + AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NameOrString, + NodeId, NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, + StatementOrExpression, StringNode, StringNodeId, Tmp, Tmp1, VariableNode, VariableNodeId, }; use crate::compiler::{Compiler, RollbackPoint, Span}; use crate::errors::{Severity, SourceError}; @@ -322,7 +322,7 @@ impl Parser { let (token, span) = self.tokens.peek(); let mut expr = match token { - Token::LCurly => self.record_or_closure(), + Token::LCurly => self.record_or_closure()?, Token::LParen => { self.tokens.advance(); if self.tokens.peek_token() == Token::RParen { @@ -334,7 +334,7 @@ impl Parser { output } } - Token::LSquare => 
self.list_or_table(), + Token::LSquare => self.list_or_table()?, Token::Int => self.advance_node(ExpressionNode::Int, span), Token::Float => self.advance_node(ExpressionNode::Float, span), Token::DoubleQuotedString => { @@ -1267,9 +1267,9 @@ impl Parser { } let name = match self.tokens.peek() { - (Token::Bareword, span) => NodeIndexer::Name(self.advance_node(NameNode, span)), + (Token::Bareword, span) => NameOrString::Name(self.advance_node(NameNode, span)), (Token::DoubleQuotedString | Token::SingleQuotedString, span) => { - NodeIndexer::String(self.advance_node(StringNode, span)) + NameOrString::String(self.advance_node(StringNode, span)) } _ => { self.error("expected def name"); @@ -1320,9 +1320,9 @@ impl Parser { self.keyword(b"extern"); let name = match self.tokens.peek() { - (Token::Bareword, span) => NodeIndexer::Name(self.advance_node(NameNode, span)), + (Token::Bareword, span) => NameOrString::Name(self.advance_node(NameNode, span)), (Token::DoubleQuotedString | Token::SingleQuotedString, span) => { - NodeIndexer::String(self.advance_node(StringNode, span)) + NameOrString::String(self.advance_node(StringNode, span)) } _ => { self.error("expected def name"); @@ -1643,17 +1643,17 @@ impl Parser { let span_start = self.position(); self.keyword(b"alias"); let new_name = if self.is_string() { - NodeIndexer::String(self.string()?) + NameOrString::String(self.string()?) } else { - NodeIndexer::Name(self.name()?) + NameOrString::Name(self.name()?) 
}; self.equals(); let (old_name, span_end) = if self.is_string() { let s = self.string()?; - (NodeIndexer::String(s), self.get_span_end(s)) + (NameOrString::String(s), self.get_span_end(s)) } else { let s = self.name()?; - (NodeIndexer::Name(s), self.get_span_end(s)) + (NameOrString::Name(s), self.get_span_end(s)) }; Some(StatementNode::Alias { new_name, old_name }.push_node( Span { From 896ec16107230fb944481a070ad252a179d93fe1 Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Mon, 16 Mar 2026 15:20:37 +0800 Subject: [PATCH 08/12] some typechecker change --- src/compiler.rs | 2 +- src/typechecker.rs | 352 ++++++++++++++++++++++++++++++++------------- 2 files changed, 257 insertions(+), 97 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index 15f36fb..8a37941 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -117,7 +117,7 @@ pub struct Compiler { /// Variables, indexed by VarId pub variables: Vec, /// Mapping of variable's name node -> Variable - pub var_resolution: HashMap, + pub var_resolution: HashMap, /// Type declarations, indexed by TypeDeclId pub type_decls: Vec, /// Mapping of type decl's name node -> TypeDecl diff --git a/src/typechecker.rs b/src/typechecker.rs index 9364d81..f2a3402 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1,9 +1,11 @@ //! See typechecking.md in the contributing/ folder for more information on //! 
how the typechecker works +use crate::ast_nodes::{ + AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, Tmp, Tmp1, VariableNode, VariableNodeId +}; use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; -use crate::parser::{AstNode, NodeId}; use crate::resolver::{TypeDecl, TypeDeclId}; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; @@ -77,7 +79,14 @@ pub enum Type { pub struct Types { pub types: Vec, + pub name_node_types: Vec, + pub string_node_types: Vec, + pub variable_node_types: Vec, pub node_types: Vec, + pub expression_node_types: Vec, + pub statement_node_types: Vec, + pub block_node_types: Vec, + pub pipeline_node_types: Vec, pub errors: Vec, } @@ -113,10 +122,17 @@ pub struct Typechecker<'a> { types: Vec, /// Types of nodes. Each type in this vector matches a node in compiler.ast_nodes at the same position. + pub name_node_types: Vec, + pub string_node_types: Vec, + pub variable_node_types: Vec, pub node_types: Vec, + pub expression_node_types: Vec, + pub statement_node_types: Vec, + pub block_node_types: Vec, + pub pipeline_node_types: Vec, /// Record fields used for `RecordType`. Each value in this vector matches with the index in RecordTypeId. /// The individual field lists are stored sorted by field name. - pub record_types: Vec>, + pub record_types: Vec>, /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId. pub oneof_types: Vec>, /// Types used for `AllOf`. Each value in this vector matches with the index in AllOfId. 
@@ -155,7 +171,14 @@ impl<'a> Typechecker<'a> { Type::Top, Type::Bottom, ], + name_node_types: vec![UNKNOWN_TYPE; compiler.name_nodes.len()], + string_node_types: vec![UNKNOWN_TYPE; compiler.string_nodes.len()], + variable_node_types: vec![UNKNOWN_TYPE; compiler.variable_nodes.len()], node_types: vec![UNKNOWN_TYPE; compiler.ast_nodes.len()], + expression_node_types: vec![UNKNOWN_TYPE; compiler.expression_nodes.len()], + statement_node_types: vec![UNKNOWN_TYPE; compiler.statement_nodes.len()], + block_node_types: vec![UNKNOWN_TYPE; compiler.blocks.len()], + pipeline_node_types: vec![UNKNOWN_TYPE; compiler.pipelines.len()], record_types: Vec::new(), oneof_types: Vec::new(), allof_types: Vec::new(), @@ -175,7 +198,14 @@ impl<'a> Typechecker<'a> { pub fn to_types(self) -> Types { Types { types: self.types, + name_node_types: self.name_node_types, + string_node_types: self.string_node_types, + variable_node_types: self.variable_node_types, node_types: self.node_types, + expression_node_types: self.expression_node_types, + statement_node_types: self.statement_node_types, + block_node_types: self.block_node_types, + pipeline_node_types: self.pipeline_node_types, errors: self.errors, } } @@ -213,11 +243,14 @@ impl<'a> Typechecker<'a> { /// Typecheck AST nodes, starting from the last node pub fn typecheck(&mut self) { - if !self.compiler.ast_nodes.is_empty() { - let last = self.compiler.ast_nodes.len() - 1; - let last_node_id = NodeId(last); - self.typecheck_node(last_node_id); - + if !self.compiler.indexer.is_empty() { + let length = self.compiler.indexer.len(); + let last_indexer = self.compiler.indexer[length - 1]; + match last_indexer { + NodeIndexer::General(node_id) => self.typecheck_node(node_id), + NodeIndexer::Block(block_id) => self.typecheck_block(block_id, TOP_TYPE), + _ => return; + } for i in 0..self.type_vars.len() { let var = &self.type_vars[i]; let bound = var.lower_bound; @@ -232,6 +265,25 @@ impl<'a> Typechecker<'a> { } } } + // if 
!self.compiler.ast_nodes.is_empty() { + // let last = self.compiler.ast_nodes.len() - 1; + // let last_node_id = NodeId(last); + // self.typecheck_node(last_node_id); + // + // for i in 0..self.type_vars.len() { + // let var = &self.type_vars[i]; + // let bound = var.lower_bound; + // let cleaned = self.eliminate_type_vars(bound, TypeVarId(0), true); + // self.types[bound.0] = self.types[cleaned.0]; + // } + // + // for i in 0..self.types.len() { + // if let Type::Var(var_id) = &self.types[i] { + // let bound = self.type_vars[var_id.0].lower_bound; + // self.types[i] = self.types[bound.0]; + // } + // } + // } } /// Get type ID of a node @@ -240,13 +292,14 @@ impl<'a> Typechecker<'a> { } /// Get type of node - pub fn type_of(&self, node_id: NodeId) -> Type { - let type_id = self.type_id_of(node_id); + pub fn type_of(&self, node_id: T) -> Type { + let type_id = node_id.type_id_of(self); self.types[type_id.0] } fn typecheck_node(&mut self, node_id: NodeId) { - match self.compiler.ast_nodes[node_id.0] { + let node = node_id.get_node(self.compiler); + match node { AstNode::Params(ref params) => { for param in params { self.typecheck_node(*param); @@ -256,12 +309,12 @@ impl<'a> Typechecker<'a> { } AstNode::Param { name, ty } => { if let Some(ty) = ty { - let ty_id = self.typecheck_type(ty); + let ty_id = self.typecheck_type(*ty); let var_id = self .compiler .var_resolution - .get(&name) + .get(name) .expect("missing resolved variable"); self.variable_types[var_id.0] = ty_id; self.set_node_type_id(node_id, ty_id); @@ -276,33 +329,29 @@ impl<'a> Typechecker<'a> { // Type argument lists are not supposed to be evaluated self.set_node_type_id(node_id, FORBIDDEN_TYPE); } - AstNode::Block(_) => { - self.typecheck_block(node_id, TOP_TYPE); - } + // NOTE: what about AstNode::TypeParams? 
_ => self.error( format!( "unsupported/unexpected ast node '{:?}' in typechecker", - self.compiler.ast_nodes[node_id.0] + node ), - node_id, + NodeIndexer::General(node_id) ), } } - fn typecheck_block(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { - let AstNode::Block(block_id) = self.compiler.ast_nodes[node_id.0] else { - panic!( - "Expected block to typecheck, got '{:?}'", - self.compiler.ast_nodes[node_id.0] - ); - }; - let block = &self.compiler.blocks[block_id.0]; + fn typecheck_block(&mut self, node_id: BlockId, expected: TypeId) -> TypeId { + let block = node_id.get_node(self.compiler); for (i, inner_node_id) in block.nodes.iter().enumerate() { - if i == block.nodes.len() - 1 && self.is_expr(*inner_node_id) { - self.typecheck_expr(*inner_node_id, expected); + let expected_type = if i == block.nodes.len() - 1 { + expected } else { - self.typecheck_stmt(*inner_node_id); + TOP_TYPE + }; + match inner_node_id { + StatementOrExpression::Statement(stmt_id) => self.typecheck_stmt(*stmt_id), + StatementOrExpression::Expression(expr_id) => self.typecheck_expr(*expr_id, expected_type), } } @@ -311,30 +360,31 @@ impl<'a> Typechecker<'a> { let block_type = block .nodes .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); - self.set_node_type_id(node_id, block_type); + .map_or(NONE_TYPE, |node_id| node_id.type_id_of(self)); + node_id.set_node_type_id(self, block_type); block_type } - fn typecheck_stmt(&mut self, node_id: NodeId) { - match self.compiler.ast_nodes[node_id.0] { - AstNode::Let { + fn typecheck_stmt(&mut self, node_id:StatementNodeId) { + let node = node_id.get_node(self.compiler); + match node { + StatementNode::Let { variable_name, ty, initializer, is_mutable: _, } => self.typecheck_let(variable_name, ty, initializer, node_id), - AstNode::Def { + StatementNode::Def { name, params, in_out_types, block, .. 
} => self.typecheck_def(name, params, in_out_types, block, node_id), - AstNode::Alias { new_name, old_name } => { + StatementNode::Alias { new_name, old_name } => { self.typecheck_alias(new_name, old_name, node_id) } - AstNode::For { + StatementNode::For { variable, range, block, @@ -365,16 +415,16 @@ impl<'a> Typechecker<'a> { self.set_node_type_id(node_id, NONE_TYPE); } } - AstNode::While { condition, block } => { + StatementNode::While { condition, block } => { self.typecheck_expr(condition, BOOL_TYPE); self.typecheck_node(block); self.set_node_type_id(node_id, NONE_TYPE); } - AstNode::Loop { block } => { + StatementNode::Loop { block } => { self.typecheck_node(block); self.set_node_type_id(node_id, NONE_TYPE); } - AstNode::Break | AstNode::Continue => { + StatementNode::Break | StatementNode::Continue => { // TODO make sure we're in a loop self.set_node_type_id(node_id, NONE_TYPE); } @@ -391,14 +441,15 @@ impl<'a> Typechecker<'a> { } } - fn typecheck_expr(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { - let ty_id = match self.compiler.ast_nodes[node_id.0] { - AstNode::Null => NOTHING_TYPE, - AstNode::Int => INT_TYPE, - AstNode::Float => FLOAT_TYPE, - AstNode::True | AstNode::False => BOOL_TYPE, - AstNode::String => STRING_TYPE, - AstNode::List(ref items) => { + fn typecheck_expr(&mut self, node_id: ExpressionNodeId, expected: TypeId) -> TypeId { + let node = node_id.get_node(self.compiler); + let ty_id = match node { + ExpressionNode::Null => NOTHING_TYPE, + ExpressionNode::Int => INT_TYPE, + ExpressionNode::Float => FLOAT_TYPE, + ExpressionNode::True | ExpressionNode::False => BOOL_TYPE, + ExpressionNode::String(_) => STRING_TYPE, + ExpressionNode::List(ref items) => { // TODO infer a union type instead if let Some(first_id) = items.first() { let expected_elem = self.extract_elem_type(expected); @@ -422,7 +473,7 @@ impl<'a> Typechecker<'a> { } if all_same { - self.push_type(Type::List(self.type_id_of(*first_id))) + 
self.push_type(Type::List(first_id.type_id_of(self))) } else if all_numbers { self.push_type(Type::List(NUMBER_TYPE)) } else { @@ -432,19 +483,19 @@ impl<'a> Typechecker<'a> { LIST_ANY_TYPE } } - AstNode::Record { ref pairs } => { + ExpressionNode::Record { ref pairs } => { // TODO take expected type into account let mut field_types = pairs .iter() .map(|(name, value)| (*name, self.typecheck_expr(*value, TOP_TYPE))) .collect::>(); - field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); + field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(NodeIndexer::Expression(*name))); self.record_types.push(field_types); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } - AstNode::Pipeline(pipeline_id) => { - let pipeline = &self.compiler.pipelines[pipeline_id.0]; + ExpressionNode::Pipeline(pipeline_id) => { + let pipeline = self.compiler.pipelines.get_node(pipeline_id.0); let expressions = pipeline.get_expressions(); for inner in expressions { self.typecheck_expr(*inner, TOP_TYPE); @@ -454,43 +505,50 @@ impl<'a> Typechecker<'a> { // by themselves aren't supposed to be typed expressions .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)) + .map_or(NONE_TYPE, |node_id| node_id.type_id_of(self)) } - AstNode::Closure { params, block } => { + ExpressionNode::Closure { params, block } => { // TODO: input/output types if let Some(params_node_id) = params { - self.typecheck_node(params_node_id); + self.typecheck_node(*params_node_id); } - self.typecheck_node(block); + self.typecheck_block(*block, expected); CLOSURE_TYPE } - AstNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs), - AstNode::Variable => { + ExpressionNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(*lhs, *op, *rhs), + ExpressionNode::Variable(variable_node_id) => { let var_id = self .compiler .var_resolution - .get(&node_id) + .get(variable_node_id) .expect("missing resolved variable"); 
self.variable_types[var_id.0] } - AstNode::If { + ExpressionNode::If { condition, then_block, else_block, } => { - self.typecheck_expr(condition, BOOL_TYPE); + self.typecheck_expr(*condition, BOOL_TYPE); - let then_type_id = self.typecheck_block(then_block, expected); + let then_type_id = self.typecheck_block(*then_block, expected); if let Some(else_blk) = else_block { let else_type_id = - if let AstNode::Block(_) = self.compiler.ast_nodes[else_blk.0] { - self.typecheck_block(else_blk, expected) - } else { - self.typecheck_expr(else_blk, expected) - }; + match else_blk { + NodeIndexer::Expression(else_expr_id) => { + self.typecheck_expr(*else_expr_id, expected) + }, + NodeIndexer::Block(else_block_id) => { + self.typecheck_block(*else_block_id, expected) + } + _ => { + self.error("Else block of an if expression must be either a block or an expression", *else_blk); + return ERROR_TYPE; + } + }; let mut types = HashSet::new(); types.insert(then_type_id); types.insert(else_type_id); @@ -500,8 +558,8 @@ impl<'a> Typechecker<'a> { NONE_TYPE } } - AstNode::Call { ref parts } => self.typecheck_call(parts, node_id), - AstNode::Match { + ExpressionNode::Call { head, parts } => self.typecheck_call(head, parts, node_id), + ExpressionNode::Match { ref target, ref match_arms, } => { @@ -524,7 +582,7 @@ impl<'a> Typechecker<'a> { ERROR_TYPE } }; - self.set_node_type_id(node_id, ty_id); + node_id.set_node_type_id(self, ty_id); if !self.constrain_subtype(ty_id, expected) { self.error( @@ -564,36 +622,36 @@ impl<'a> Typechecker<'a> { fn typecheck_match( &mut self, - target: &NodeId, - match_arms: &Vec<(NodeId, NodeId)>, + target: &ExpressionNodeId, + match_arms: &Vec<(ExpressionNodeId,ExpressionNodeId)>, expected: TypeId, ) -> HashSet { self.typecheck_expr(*target, TOP_TYPE); let mut output_types = HashSet::new(); // typecheck each node - let target_id = self.type_id_of(*target); + let target_id = target.type_id_of(self); for (match_node, result_node) in match_arms { - 
self.typecheck_node(*match_node); + self.typecheck_expr(*match_node,expected); self.typecheck_expr(*result_node, expected); - let match_id = self.type_id_of(*match_node); + let match_id = match_node.type_id_of(self); match (self.type_of(*target), self.type_of(*match_node)) { // First is of type Any which will always match (Type::Any, _) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // Same as above but for second (_, Type::Any) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // the second is one of the possible types of the first (Type::OneOf(id), _) if self.oneof_types[id.0].contains(&match_id) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // the first is one of the possible types of the second (_, Type::OneOf(id)) if self.oneof_types[id.0].contains(&target_id) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // the both the target and the one matched against are // oneof then we need to check if they have any type in common @@ -603,28 +661,29 @@ impl<'a> Typechecker<'a> { .count() != 0 { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } else { - self.error("The target to be matched against and the possible types of the matched arm are completely disjoint", *match_node); + self.error("The target to be matched against and the possible types of the matched arm are completely disjoint", NodeIndexer::Expression(*match_node)); } } // Check if the two types can be matched (target_id, match_id) if 
self.is_type_compatible(target_id, match_id) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } _ => { - self.error("The types do not match", *match_node); + self.error("The types do not match", NodeIndexer::Expression(*match_node)) } } } output_types } - fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId) -> TypeId { + fn typecheck_binary_op(&mut self, lhs:ExpressionNodeId, op: NodeId, rhs:ExpressionNodeId) -> TypeId { self.set_node_type_id(op, FORBIDDEN_TYPE); // TODO: better error messages for type mismatches, the previous messages were better - match self.compiler.ast_nodes[op.0] { + let node = op.get_node(self.compiler); + match node { AstNode::Equal | AstNode::NotEqual => { let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); let rhs_ty = self.typecheck_expr(rhs, TOP_TYPE); @@ -714,7 +773,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(lhs_ty, STRING_TYPE) { self.error( format!("Expected string, got {}", self.type_to_string(lhs_ty)), - lhs, + NodeIndexer::Expression(lhs), ); } STRING_TYPE @@ -722,7 +781,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(lhs_ty, NUMBER_TYPE) { self.error( format!("Expected number, got {}", self.type_to_string(lhs_ty)), - lhs, + NodeIndexer::Expression(lhs) ); } self.numeric_op_type(lhs_ty, rhs_ty) @@ -848,7 +907,7 @@ impl<'a> Typechecker<'a> { ); } - fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) -> TypeId { + fn typecheck_call(&mut self, head: &[NodeId], parts: &[NodeId], node_id: NodeId) -> TypeId { if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); let decl_node_id = self.compiler.decl_nodes[decl_id.0]; @@ -935,10 +994,10 @@ impl<'a> Typechecker<'a> { fn typecheck_let( &mut self, - variable_name: NodeId, + variable_name: VariableNodeId, ty: Option, - initializer: 
NodeId, - node_id: NodeId, + initializer: ExpressionNodeId, + node_id: StatementNodeId, ) { let type_id = if let Some(ty) = ty { let ty_id = self.typecheck_type(ty); @@ -1603,7 +1662,7 @@ impl<'a> Typechecker<'a> { } } - fn error(&mut self, msg: impl Into, node_id: NodeId) { + fn error(&mut self, msg: impl Into, node_id:NodeIndexer) { self.errors.push(SourceError { message: msg.into(), node_id, @@ -1611,15 +1670,15 @@ impl<'a> Typechecker<'a> { }) } - fn binary_op_err(&mut self, op_msg: &str, lhs: NodeId, op: NodeId, rhs: NodeId) { + fn binary_op_err(&mut self, op_msg: &str, lhs:ExpressionNodeId, op: NodeId, rhs:ExpressionNodeId) { self.error( format!( "type mismatch: unsupported {} between {} and {}", op_msg, - self.type_to_string(self.type_id_of(lhs)), - self.type_to_string(self.type_id_of(rhs)), + self.type_to_string(lhs.type_id_of(self)), + self.type_to_string(rhs.type_id_of(self)), ), - op, + NodeIndexer::General(op), ); self.set_node_type_id(op, ERROR_TYPE); } @@ -1892,3 +1951,104 @@ impl<'a> Typechecker<'a> { self.create_oneof(inters) } } + +trait NodeTypeSetter { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId); + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId; +} + +impl NodeTypeSetter for NameNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.name_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.name_node_types[self.0] + } +} + +impl NodeTypeSetter for StringNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.string_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.string_node_types[self.0] + } +} + +impl NodeTypeSetter for VariableNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.variable_node_types[self.0] = type_id; + } + + fn 
type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.variable_node_types[self.0] + } +} + +impl NodeTypeSetter for NodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.node_types[self.0] + } +} + +impl NodeTypeSetter for ExpressionNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.expression_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.expression_node_types[self.0] + } +} + +impl NodeTypeSetter for StatementNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.statement_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.statement_node_types[self.0] + } +} + +impl NodeTypeSetter for BlockId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.block_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.block_node_types[self.0] + } +} + +impl NodeTypeSetter for PipelineId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.pipeline_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.pipeline_node_types[self.0] + } +} + +impl NodeTypeSetter for StatementOrExpression { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + match self { + StatementOrExpression::Statement(stmt_id) => stmt_id.set_node_type_id(typechecker, type_id), + StatementOrExpression::Expression(expr_id) => expr_id.set_node_type_id(typechecker, type_id), + } + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + match self { + 
StatementOrExpression::Statement(stmt_id) => stmt_id.type_id_of(typechecker), + StatementOrExpression::Expression(expr_id) => expr_id.type_id_of(typechecker), + } + } +} From 7c4a05b03a8cbc42b9339ef6c42965c2937fa0b0 Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Mon, 16 Mar 2026 15:51:13 +0800 Subject: [PATCH 09/12] rename --- src/ast_nodes.rs | 72 ++++++++++++++++++++++----------------------- src/compiler.rs | 36 +++++++++++------------ src/ir_generator.rs | 2 +- src/parser.rs | 19 ++++++------ src/resolver.rs | 4 +-- src/typechecker.rs | 8 ++--- 6 files changed, 71 insertions(+), 70 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index 3b12780..c9bed0d 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -36,13 +36,13 @@ pub enum NameOrString { } #[derive(Debug, Clone)] -pub struct Block { +pub struct BlockNode { pub nodes: Vec, } -impl Block { - pub fn new(nodes: Vec) -> Block { - Block { nodes } +impl BlockNode { + pub fn new(nodes: Vec) -> BlockNode { + BlockNode { nodes } } } @@ -60,11 +60,11 @@ pub struct PipelineId(pub usize); // Making such restriction can reduce indirect access on expression, which // can improve performance in parse time. 
#[derive(Debug, Clone, PartialEq)] -pub struct Pipeline { +pub struct PipelineNode { pub nodes: Vec, } -impl Pipeline { +impl PipelineNode { pub fn new(nodes: Vec) -> Self { debug_assert!( nodes.len() > 1, @@ -283,7 +283,7 @@ pub enum AstNode { Garbage, } -pub trait Tmp { +pub trait NodeIdGetter { type Output; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output; fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output; @@ -298,12 +298,12 @@ pub trait Tmp { fn into_indexer(self) -> NodeIndexer; } -pub trait Tmp1 { +pub trait NodePusher { type Output; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output; } -impl Tmp for NameNodeId { +impl NodeIdGetter for NameNodeId { type Output = NameNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { @@ -323,7 +323,7 @@ impl Tmp for NameNodeId { } } -impl Tmp1 for NameNode { +impl NodePusher for NameNode { type Output = NameNodeId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { @@ -337,7 +337,7 @@ impl Tmp1 for NameNode { } } -impl Tmp for StringNodeId { +impl NodeIdGetter for StringNodeId { type Output = StringNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { @@ -357,7 +357,7 @@ impl Tmp for StringNodeId { } } -impl Tmp1 for StringNode { +impl NodePusher for StringNode { type Output = StringNodeId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { @@ -371,7 +371,7 @@ impl Tmp1 for StringNode { } } -impl Tmp for VariableNodeId { +impl NodeIdGetter for VariableNodeId { type Output = VariableNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { @@ -391,7 +391,7 @@ impl Tmp for VariableNodeId { } } -impl Tmp1 for VariableNode { +impl NodePusher for VariableNode { type Output = VariableNodeId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { @@ -405,19 +405,19 @@ impl Tmp1 for VariableNode { } } -impl Tmp for BlockId { - type Output = 
Block; +impl NodeIdGetter for BlockId { + type Output = BlockNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { - compiler.blocks.get_node(self.0) + compiler.block_nodes.get_node(self.0) } fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { - compiler.blocks.get_node_mut(self.0) + compiler.block_nodes.get_node_mut(self.0) } fn get_span(&self, compiler: &Compiler) -> Span { - compiler.blocks.get_span(self.0) + compiler.block_nodes.get_span(self.0) } fn into_indexer(self) -> NodeIndexer { @@ -425,13 +425,13 @@ impl Tmp for BlockId { } } -impl Tmp1 for Block { +impl NodePusher for BlockNode { type Output = BlockId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { - compiler.blocks.push(span, self); + compiler.block_nodes.push(span, self); - let result = BlockId(compiler.blocks.len() - 1); + let result = BlockId(compiler.block_nodes.len() - 1); let indexer = NodeIndexer::Block(result); compiler.indexer.push(indexer); @@ -439,7 +439,7 @@ impl Tmp1 for Block { } } -impl Tmp for StatementNodeId { +impl NodeIdGetter for StatementNodeId { type Output = StatementNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { @@ -459,7 +459,7 @@ impl Tmp for StatementNodeId { } } -impl Tmp1 for StatementNode { +impl NodePusher for StatementNode { type Output = StatementNodeId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { @@ -473,19 +473,19 @@ impl Tmp1 for StatementNode { } } -impl Tmp for PipelineId { - type Output = Pipeline; +impl NodeIdGetter for PipelineId { + type Output = PipelineNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { - compiler.pipelines.get_node(self.0) + compiler.pipeline_nodes.get_node(self.0) } fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { - compiler.pipelines.get_node_mut(self.0) + compiler.pipeline_nodes.get_node_mut(self.0) } fn get_span(&self, compiler: &Compiler) -> Span { 
- compiler.pipelines.get_span(self.0) + compiler.pipeline_nodes.get_span(self.0) } fn into_indexer(self) -> NodeIndexer { @@ -493,20 +493,20 @@ impl Tmp for PipelineId { } } -impl Tmp1 for Pipeline { +impl NodePusher for PipelineNode { type Output = PipelineId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { - compiler.pipelines.push(span, self); + compiler.pipeline_nodes.push(span, self); - let result = PipelineId(compiler.pipelines.len() - 1); + let result = PipelineId(compiler.pipeline_nodes.len() - 1); let indexer = NodeIndexer::Pipeline(result); compiler.indexer.push(indexer); result } } -impl Tmp1 for ExpressionNode { +impl NodePusher for ExpressionNode { type Output = ExpressionNodeId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { @@ -520,7 +520,7 @@ impl Tmp1 for ExpressionNode { } } -impl Tmp for ExpressionNodeId { +impl NodeIdGetter for ExpressionNodeId { type Output = ExpressionNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { @@ -539,7 +539,7 @@ impl Tmp for ExpressionNodeId { NodeIndexer::Expression(self) } } -impl Tmp1 for AstNode { +impl NodePusher for AstNode { type Output = NodeId; fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { @@ -552,7 +552,7 @@ impl Tmp1 for AstNode { result } } -impl Tmp for NodeId { +impl NodeIdGetter for NodeId { type Output = AstNode; fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { diff --git a/src/compiler.rs b/src/compiler.rs index 8a37941..72bcab1 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,7 +1,7 @@ use crate::ast_nodes::{ - AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, - NodeIndexer, Pipeline, StatementNode, StatementNodeId, StringNode, StringNodeId, Tmp, Tmp1, - VariableNode, VariableNodeId, + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, + NodeIdGetter, NodeIndexer, NodePusher, PipelineNode, 
StatementNode, StatementNodeId, + StringNode, StringNodeId, VariableNode, VariableNodeId, }; use crate::errors::SourceError; use crate::protocol::Command; @@ -100,8 +100,8 @@ pub struct Compiler { pub ast_nodes: NodeSpans, pub expression_nodes: NodeSpans, pub statement_nodes: NodeSpans, - pub blocks: NodeSpans, // Blocks, indexed by BlockId - pub pipelines: NodeSpans, // Pipelines, indexed by PipelineId + pub block_nodes: NodeSpans, // Blocks, indexed by BlockId + pub pipeline_nodes: NodeSpans, // Pipelines, indexed by PipelineId pub indexer: Vec, pub node_types: Vec, @@ -155,10 +155,10 @@ impl Compiler { name_nodes: NodeSpans::new(), expression_nodes: NodeSpans::new(), statement_nodes: NodeSpans::new(), - pipelines: NodeSpans::new(), + pipeline_nodes: NodeSpans::new(), node_types: vec![], indexer: vec![], - blocks: NodeSpans::new(), + block_nodes: NodeSpans::new(), source: vec![], file_offsets: vec![], @@ -218,12 +218,12 @@ impl Compiler { self.ast_nodes.get_span(i.0), ), NodeIndexer::Block(i) => ( - format!("{:?}", self.blocks.get_node(i.0)), - self.blocks.get_span(i.0), + format!("{:?}", self.block_nodes.get_node(i.0)), + self.block_nodes.get_span(i.0), ), NodeIndexer::Pipeline(i) => ( - format!("{:?}", self.pipelines.get_node(i.0)), - self.pipelines.get_span(i.0), + format!("{:?}", self.pipeline_nodes.get_node(i.0)), + self.pipeline_nodes.get_span(i.0), ), }; result.push_str(&format!( @@ -296,15 +296,15 @@ impl Compiler { self.source.len() } - pub fn get_node(&self, node_id: T) -> &T::Output { + pub fn get_node(&self, node_id: T) -> &T::Output { node_id.get_node(self) } - pub fn get_node_mut(&mut self, node_id: T) -> &mut T::Output { + pub fn get_node_mut(&mut self, node_id: T) -> &mut T::Output { node_id.get_node_mut(self) } - pub fn push_node(&mut self, span: Span, ast_node: T) -> T::Output { + pub fn push_node(&mut self, span: Span, ast_node: T) -> T::Output { ast_node.push_node(span, self) } @@ -317,13 +317,13 @@ impl Compiler { idx_expression_nodes: 
self.expression_nodes.len(), idx_statment_nodes: self.statement_nodes.len(), idx_errors: self.errors.len(), - idx_blocks: self.blocks.len(), + idx_blocks: self.block_nodes.len(), token_pos, } } pub fn apply_compiler_rollback(&mut self, rbp: RollbackPoint) -> usize { - self.blocks.truncate(rbp.idx_blocks); + self.block_nodes.truncate(rbp.idx_blocks); self.ast_nodes.truncate(rbp.idx_nodes); self.name_nodes.truncate(rbp.idx_name_nodes); self.string_nodes.truncate(rbp.idx_string_nodes); @@ -342,9 +342,9 @@ impl Compiler { NodeIndexer::Variable(i) => self.variable_nodes.get_span(i.0), NodeIndexer::General(i) => self.ast_nodes.get_span(i.0), NodeIndexer::Expression(i) => self.expression_nodes.get_span(i.0), - NodeIndexer::Block(i) => self.blocks.get_span(i.0), + NodeIndexer::Block(i) => self.block_nodes.get_span(i.0), NodeIndexer::Statement(i) => self.statement_nodes.get_span(i.0), - NodeIndexer::Pipeline(i) => self.pipelines.get_span(i.0), + NodeIndexer::Pipeline(i) => self.pipeline_nodes.get_span(i.0), } } diff --git a/src/ir_generator.rs b/src/ir_generator.rs index 74a61f4..93663cd 100644 --- a/src/ir_generator.rs +++ b/src/ir_generator.rs @@ -110,7 +110,7 @@ impl<'a> IrGenerator<'a> { Some(next_reg) } AstNode::Block(block_id) => { - let block = &self.compiler.blocks[block_id.0]; + let block = &self.compiler.block_nodes[block_id.0]; let mut last = None; for id in &block.nodes { last = self.generate_node(*id); diff --git a/src/parser.rs b/src/parser.rs index 4498a42..91db8f9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,7 +1,8 @@ use crate::ast_nodes::{ - AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NameOrString, - NodeId, NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, - StatementOrExpression, StringNode, StringNodeId, Tmp, Tmp1, VariableNode, VariableNodeId, + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, + NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, 
PipelineId, PipelineNode, + StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, VariableNode, + VariableNodeId, }; use crate::compiler::{Compiler, RollbackPoint, Span}; use crate::errors::{Severity, SourceError}; @@ -96,7 +97,7 @@ impl Parser { self.tokens.peek_span().start } - fn get_span_end(&self, node_id: T) -> usize { + fn get_span_end(&self, node_id: T) -> usize { node_id.get_span(&self.compiler).end } @@ -127,7 +128,7 @@ impl Parser { expressions.push(self.expression()?); } let span_end = self.position(); - let pipeline_id = Pipeline::new(expressions).push_node( + let pipeline_id = PipelineNode::new(expressions).push_node( Span { start: span_start, end: span_end, @@ -432,7 +433,7 @@ impl Parser { } } - pub fn advance_node(&mut self, node: T, span: Span) -> T::Output { + pub fn advance_node(&mut self, node: T, span: Span) -> T::Output { self.tokens.advance(); node.push_node(span, &mut self.compiler) } @@ -485,7 +486,7 @@ impl Parser { } } - pub fn advance_unchecked(&mut self, node: T) -> T::Output { + pub fn advance_unchecked(&mut self, node: T) -> T::Output { let span = self.tokens.peek_span(); self.tokens.advance(); node.push_node(span, &mut self.compiler) @@ -730,7 +731,7 @@ impl Parser { self.compiler.get_node(operator).precedence() } - pub fn spanning(&mut self, from: T, to: T) -> (usize, usize) { + pub fn spanning(&mut self, from: T, to: T) -> (usize, usize) { ( from.get_span(&self.compiler).start, to.get_span(&self.compiler).end, @@ -1506,7 +1507,7 @@ impl Parser { } let span_end = self.position(); - Some(Block { nodes: code_body }.push_node( + Some(BlockNode { nodes: code_body }.push_node( Span { start: span_start, end: span_end, diff --git a/src/resolver.rs b/src/resolver.rs index 5c281fb..c551ba7 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -438,7 +438,7 @@ impl<'a> Resolver<'a> { } pub fn resolve_pipeline(&mut self, pipeline_id: PipelineId) { - let pipeline = &self.compiler.pipelines[pipeline_id.0]; + let 
pipeline = &self.compiler.pipeline_nodes[pipeline_id.0]; for exp in pipeline.get_expressions() { self.resolve_node(*exp) @@ -533,7 +533,7 @@ impl<'a> Resolver<'a> { ) { let block = self .compiler - .blocks + .block_nodes .get(block_id.0) .expect("internal error: missing block"); diff --git a/src/typechecker.rs b/src/typechecker.rs index f2a3402..5978d39 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -2,7 +2,7 @@ //! how the typechecker works use crate::ast_nodes::{ - AstNode, Block, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, NodeIndexer, Pipeline, PipelineId, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, Tmp, Tmp1, VariableNode, VariableNodeId + AstNode, BlockNode, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, NodeIndexer, PipelineNode, PipelineId, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, NodeIdGetter, NodePusher, VariableNode, VariableNodeId }; use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; @@ -177,8 +177,8 @@ impl<'a> Typechecker<'a> { node_types: vec![UNKNOWN_TYPE; compiler.ast_nodes.len()], expression_node_types: vec![UNKNOWN_TYPE; compiler.expression_nodes.len()], statement_node_types: vec![UNKNOWN_TYPE; compiler.statement_nodes.len()], - block_node_types: vec![UNKNOWN_TYPE; compiler.blocks.len()], - pipeline_node_types: vec![UNKNOWN_TYPE; compiler.pipelines.len()], + block_node_types: vec![UNKNOWN_TYPE; compiler.block_nodes.len()], + pipeline_node_types: vec![UNKNOWN_TYPE; compiler.pipeline_nodes.len()], record_types: Vec::new(), oneof_types: Vec::new(), allof_types: Vec::new(), @@ -495,7 +495,7 @@ impl<'a> Typechecker<'a> { self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } ExpressionNode::Pipeline(pipeline_id) => { - let pipeline = self.compiler.pipelines.get_node(pipeline_id.0); + let pipeline = self.compiler.pipeline_nodes.get_node(pipeline_id.0); let expressions 
= pipeline.get_expressions(); for inner in expressions { self.typecheck_expr(*inner, TOP_TYPE); From a82e1032f27b93d9ca965e33392ce64ed6500935 Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Mon, 16 Mar 2026 18:02:18 +0800 Subject: [PATCH 10/12] some minor adjust for type changes --- src/compiler.rs | 8 +-- src/typechecker.rs | 174 +++++++++++++++++++-------------------------- 2 files changed, 78 insertions(+), 104 deletions(-) diff --git a/src/compiler.rs b/src/compiler.rs index 72bcab1..b03ffc9 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,7 +1,7 @@ use crate::ast_nodes::{ - AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, - NodeIdGetter, NodeIndexer, NodePusher, PipelineNode, StatementNode, StatementNodeId, - StringNode, StringNodeId, VariableNode, VariableNodeId, + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, + NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineNode, StatementNode, + StatementNodeId, StringNode, StringNodeId, VariableNode, VariableNodeId, }; use crate::errors::SourceError; use crate::protocol::Command; @@ -127,7 +127,7 @@ pub struct Compiler { /// Declaration NodeIds, indexed by DeclId pub decl_nodes: Vec, /// Mapping of decl's name node -> Command - pub decl_resolution: HashMap, + pub decl_resolution: HashMap, // Definitions: // indexed by FunId diff --git a/src/typechecker.rs b/src/typechecker.rs index 5978d39..9f457a5 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -2,7 +2,7 @@ //!
how the typechecker works use crate::ast_nodes::{ - AstNode, BlockNode, BlockId, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NodeId, NodeIndexer, PipelineNode, PipelineId, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, NodeIdGetter, NodePusher, VariableNode, VariableNodeId + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineId, PipelineNode, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, VariableNode, VariableNodeId }; use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; @@ -248,7 +248,7 @@ impl<'a> Typechecker<'a> { let last_indexer = self.compiler.indexer[length - 1]; match last_indexer { NodeIndexer::General(node_id) => self.typecheck_node(node_id), - NodeIndexer::Block(block_id) => self.typecheck_block(block_id, TOP_TYPE), + NodeIndexer::Block(block_id) => {self.typecheck_block(block_id, TOP_TYPE);} _ => return; } for i in 0..self.type_vars.len() { @@ -292,7 +292,7 @@ impl<'a> Typechecker<'a> { } /// Get type of node - pub fn type_of(&self, node_id: T) -> Type { + pub fn type_of(&self, node_id: &T) -> Type { let type_id = node_id.type_id_of(self); self.types[type_id.0] } @@ -340,7 +340,7 @@ impl<'a> Typechecker<'a> { } } - fn typecheck_block(&mut self, node_id: BlockId, expected: TypeId) -> TypeId { + fn typecheck_block(&mut self, node_id: &BlockId, expected: TypeId) -> TypeId { let block = node_id.get_node(self.compiler); for (i, inner_node_id) in block.nodes.iter().enumerate() { @@ -399,49 +399,46 @@ impl<'a> Typechecker<'a> { .expect("missing resolved variable"); if let Type::List(type_id) = self.type_of(range) { self.variable_types[var_id.0] = type_id; - self.set_node_type_id(variable, type_id); + variable.set_node_type_id(self, type_id); } else { self.variable_types[var_id.0] = ANY_TYPE; - self.set_node_type_id(variable, ERROR_TYPE); - self.error("For 
loop range is not a list", range); + variable.set_node_type_id(self, ERROR_TYPE); + self.error("For loop range is not a list", range.into_indexer()); } - self.typecheck_node(block); - if self.type_id_of(block) != NONE_TYPE { - self.error("Blocks in looping constructs cannot return values", block); + self.typecheck_block(block, TOP_TYPE); + if block.type_id_of(self) != NONE_TYPE { + self.error("Blocks in looping constructs cannot return values", block.into_indexer()); } - if self.type_id_of(node_id) != ERROR_TYPE { - self.set_node_type_id(node_id, NONE_TYPE); + if node_id.type_id_of(self) != ERROR_TYPE { + node_id.set_node_type_id(self, NONE_TYPE); } } StatementNode::While { condition, block } => { self.typecheck_expr(condition, BOOL_TYPE); - self.typecheck_node(block); - self.set_node_type_id(node_id, NONE_TYPE); + self.typecheck_block(block, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); } StatementNode::Loop { block } => { - self.typecheck_node(block); - self.set_node_type_id(node_id, NONE_TYPE); + self.typecheck_block(block, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); } StatementNode::Break | StatementNode::Continue => { // TODO make sure we're in a loop - self.set_node_type_id(node_id, NONE_TYPE); - } - _ if self.is_expr(node_id) => { - self.typecheck_expr(node_id, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); } _ => self.error( format!( - "Expected statement to typecheck, got '{:?}'", - self.compiler.ast_nodes[node_id.0] + "unsupported statement node '{:?}' in typechecker", + node ), - node_id, + node_id.into_indexer() ), } } - fn typecheck_expr(&mut self, node_id: ExpressionNodeId, expected: TypeId) -> TypeId { + fn typecheck_expr(&mut self, node_id: &ExpressionNodeId, expected: TypeId) -> TypeId { let node = node_id.get_node(self.compiler); let ty_id = match node { ExpressionNode::Null => NOTHING_TYPE, @@ -453,15 +450,15 @@ impl<'a> Typechecker<'a> { // TODO infer a union type instead if let Some(first_id) = items.first() { let 
expected_elem = self.extract_elem_type(expected); - self.typecheck_expr(*first_id, expected_elem.unwrap_or(TOP_TYPE)); - let first_type = self.type_of(*first_id); + self.typecheck_expr(first_id, expected_elem.unwrap_or(TOP_TYPE)); + let first_type = self.type_of(first_id); let mut all_numbers = self.is_type_compatible(first_type, Type::Number); let mut all_same = true; for item_id in items.iter().skip(1) { - self.typecheck_expr(*item_id, TOP_TYPE); - let item_type = self.type_of(*item_id); + self.typecheck_expr(item_id, TOP_TYPE); + let item_type = self.type_of(item_id); if all_numbers && !self.is_type_compatible(item_type, Type::Number) { all_numbers = false; @@ -487,9 +484,9 @@ impl<'a> Typechecker<'a> { // TODO take expected type into account let mut field_types = pairs .iter() - .map(|(name, value)| (*name, self.typecheck_expr(*value, TOP_TYPE))) + .map(|(name, value)| (*name, self.typecheck_expr(value, TOP_TYPE))) .collect::>(); - field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(NodeIndexer::Expression(*name))); + field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(name.into_indexer())); self.record_types.push(field_types); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) @@ -498,7 +495,7 @@ impl<'a> Typechecker<'a> { let pipeline = self.compiler.pipeline_nodes.get_node(pipeline_id.0); let expressions = pipeline.get_expressions(); for inner in expressions { - self.typecheck_expr(*inner, TOP_TYPE); + self.typecheck_expr(inner, TOP_TYPE); } // pipeline type is the type of the last expression, since blocks @@ -513,7 +510,7 @@ impl<'a> Typechecker<'a> { self.typecheck_node(*params_node_id); } - self.typecheck_block(*block, expected); + self.typecheck_block(block, expected); CLOSURE_TYPE } ExpressionNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(*lhs, *op, *rhs), @@ -531,18 +528,18 @@ impl<'a> Typechecker<'a> { then_block, else_block, } => { - self.typecheck_expr(*condition, 
BOOL_TYPE); + self.typecheck_expr(condition, BOOL_TYPE); - let then_type_id = self.typecheck_block(*then_block, expected); + let then_type_id = self.typecheck_block(then_block, expected); if let Some(else_blk) = else_block { let else_type_id = match else_blk { NodeIndexer::Expression(else_expr_id) => { - self.typecheck_expr(*else_expr_id, expected) + self.typecheck_expr(else_expr_id, expected) }, NodeIndexer::Block(else_block_id) => { - self.typecheck_block(*else_block_id, expected) + self.typecheck_block(else_block_id, expected) } _ => { self.error("Else block of an if expression must be either a block or an expression", *else_blk); @@ -575,9 +572,9 @@ impl<'a> Typechecker<'a> { self.error( format!( "Expected an expression to typecheck, got '{:?}'", - self.compiler.ast_nodes[node_id.0] + node ), - node_id, + node_id.into_indexer(), ); ERROR_TYPE } @@ -591,52 +588,30 @@ impl<'a> Typechecker<'a> { self.type_to_string(expected), self.type_to_string(ty_id) ), - node_id, + node_id.into_indexer() ); } ty_id } - fn is_expr(&mut self, node_id: NodeId) -> bool { - matches!( - self.compiler.ast_nodes[node_id.0], - AstNode::Null - | AstNode::Int - | AstNode::Float - | AstNode::True - | AstNode::False - | AstNode::String - | AstNode::Variable - | AstNode::List(_) - | AstNode::Record { .. } - | AstNode::Table { .. } - | AstNode::Pipeline(_) - | AstNode::Closure { .. } - | AstNode::BinaryOp { .. } - | AstNode::If { .. } - | AstNode::Call { .. } - | AstNode::Match { .. 
} - ) - } - fn typecheck_match( &mut self, target: &ExpressionNodeId, match_arms: &Vec<(ExpressionNodeId,ExpressionNodeId)>, expected: TypeId, ) -> HashSet { - self.typecheck_expr(*target, TOP_TYPE); + self.typecheck_expr(target, TOP_TYPE); let mut output_types = HashSet::new(); // typecheck each node let target_id = target.type_id_of(self); for (match_node, result_node) in match_arms { - self.typecheck_expr(*match_node,expected); - self.typecheck_expr(*result_node, expected); + self.typecheck_expr(match_node,expected); + self.typecheck_expr(result_node, expected); let match_id = match_node.type_id_of(self); - match (self.type_of(*target), self.type_of(*match_node)) { + match (self.type_of(target), self.type_of(match_node)) { // First is of type Any which will always match (Type::Any, _) => { self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); @@ -671,14 +646,14 @@ impl<'a> Typechecker<'a> { self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } _ => { - self.error("The types do not match", NodeIndexer::Expression(*match_node)) + self.error("The types do not match", match_node.into_indexer()) } } } output_types } - fn typecheck_binary_op(&mut self, lhs:ExpressionNodeId, op: NodeId, rhs:ExpressionNodeId) -> TypeId { + fn typecheck_binary_op(&mut self, lhs:&ExpressionNodeId, op: &NodeId, rhs:&ExpressionNodeId) -> TypeId { self.set_node_type_id(op, FORBIDDEN_TYPE); // TODO: better error messages for type mismatches, the previous messages were better @@ -838,11 +813,11 @@ impl<'a> Typechecker<'a> { fn typecheck_def( &mut self, - name: NodeId, - params: NodeId, - in_out_types: Option, - block: NodeId, - node_id: NodeId, + name: &NameOrString, + params: &NodeId, + in_out_types: &Option, + block:&BlockId, + node_id: StatementNodeId, ) { let in_out_types = in_out_types .map(|ty| { @@ -878,7 +853,7 @@ impl<'a> Typechecker<'a> { if in_out_types.is_empty() { self.decl_types[decl_id.0] = vec![InOutType { in_type: ANY_TYPE, - 
out_type: self.type_id_of(block), + out_type: block.get_id_of(self.compiler), }]; } else { // TODO check that block output type matches expected type @@ -886,7 +861,7 @@ impl<'a> Typechecker<'a> { } } - fn typecheck_alias(&mut self, new_name: NodeId, old_name: NodeId, node_id: NodeId) { + fn typecheck_alias(&mut self, new_name:&NameOrString, old_name:&NameOrString, node_id:StatementNodeId) { self.set_node_type_id(node_id, NONE_TYPE); // set input/output types for the command @@ -907,7 +882,7 @@ impl<'a> Typechecker<'a> { ); } - fn typecheck_call(&mut self, head: &[NodeId], parts: &[NodeId], node_id: NodeId) -> TypeId { + fn typecheck_call(&mut self, head: &[NameNodeId], parts: &[ExpressionNodeId], node_id: &ExpressionNodeId) -> TypeId { if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); let decl_node_id = self.compiler.decl_nodes[decl_id.0]; @@ -980,12 +955,11 @@ impl<'a> Typechecker<'a> { self.create_oneof(out_types) } else { // external call - for part in &parts[1..] 
{ - if matches!(self.compiler.ast_nodes[part.0], AstNode::Name) { - self.set_node_type_id(*part, STRING_TYPE); - } else { - self.typecheck_expr(*part, TOP_TYPE); - } + for h in head { + h.set_node_type_id(self, STRING_TYPE); + } + for part in parts { + self.typecheck_expr(part, TOP_TYPE); } BYTE_STREAM_TYPE @@ -994,9 +968,9 @@ impl<'a> Typechecker<'a> { fn typecheck_let( &mut self, - variable_name: VariableNodeId, - ty: Option, - initializer: ExpressionNodeId, + variable_name: &VariableNodeId, + ty: &Option, + initializer: &ExpressionNodeId, node_id: StatementNodeId, ) { let type_id = if let Some(ty) = ty { @@ -1291,8 +1265,8 @@ impl<'a> Typechecker<'a> { while i < sub_fields.len() && j < supe_fields.len() { let (sub_name, sub_ty) = sub_fields[i]; let (supe_name, supe_ty) = supe_fields[j]; - let sub_text = self.compiler.get_span_contents(sub_name); - let supe_text = self.compiler.get_span_contents(supe_name); + let sub_text = sub_name.get_span_contents(self.compiler); + let supe_text = supe_name.get_span_contents(self.compiler); match sub_text.cmp(supe_text) { Ordering::Less => { i += 1; @@ -1393,8 +1367,8 @@ impl<'a> Typechecker<'a> { while i < sub_fields.len() && j < supe_fields.len() { let (sub_name, sub_ty) = sub_fields[i]; let (supe_name, supe_ty) = supe_fields[j]; - let sub_text = self.compiler.get_span_contents(sub_name); - let supe_text = self.compiler.get_span_contents(supe_name); + let sub_text = sub_name.get_span_contents(self.compiler); + let supe_text = supe_name.get_span_contents(self.compiler); match sub_text.cmp(supe_text) { Ordering::Less => { i += 1; @@ -1556,7 +1530,7 @@ impl<'a> Typechecker<'a> { let mut fmt = "record<".to_string(); let types = &self.record_types[id.0]; for (name, ty) in types { - fmt += &String::from_utf8_lossy(self.compiler.get_span_contents(*name)); + fmt += &String::from_utf8_lossy(name.get_span_contents(&self.compiler)); fmt += ": "; fmt += &self.type_to_string(*ty); fmt += ", "; @@ -1637,8 +1611,8 @@ impl<'a> 
Typechecker<'a> { while l < lhs_fields.len() && r < rhs_fields.len() { let (lhs_name, lhs_ty) = lhs_fields[l]; let (rhs_name, rhs_ty) = rhs_fields[r]; - let lhs_text = self.compiler.get_span_contents(lhs_name); - let rhs_text = self.compiler.get_span_contents(rhs_name); + let lhs_text = lhs_name.get_span_contents(&self.compiler); + let rhs_text = rhs_name.get_span_contents(&self.compiler); match lhs_text.cmp(rhs_text) { Ordering::Less => { l += 1; @@ -1670,7 +1644,7 @@ impl<'a> Typechecker<'a> { }) } - fn binary_op_err(&mut self, op_msg: &str, lhs:ExpressionNodeId, op: NodeId, rhs:ExpressionNodeId) { + fn binary_op_err(&mut self, op_msg: &str, lhs:&ExpressionNodeId, op: &NodeId, rhs:&ExpressionNodeId) { self.error( format!( "type mismatch: unsupported {} between {} and {}", @@ -1678,9 +1652,9 @@ impl<'a> Typechecker<'a> { self.type_to_string(lhs.type_id_of(self)), self.type_to_string(rhs.type_id_of(self)), ), - NodeIndexer::General(op), + op.into_indexer() ); - self.set_node_type_id(op, ERROR_TYPE); + op.set_node_type_id(self, ERROR_TYPE); } fn add_resolved_types(&mut self, types: &mut HashSet, ty: &TypeId) { @@ -1714,7 +1688,7 @@ impl<'a> Typechecker<'a> { let mut simple_types = HashSet::::new(); let mut list_elems = HashSet::new(); - let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + let mut record_fields = HashMap::<&[u8], (ExpressionNodeId, HashSet)>::new(); for ty_id in flattened { if simple_types.contains(&ty_id) { continue; @@ -1748,7 +1722,7 @@ impl<'a> Typechecker<'a> { Type::Record(rec_ty_id) => { let new_fields = &self.record_types[rec_ty_id.0]; for (name_node, ty) in new_fields.iter() { - let name = self.compiler.get_span_contents(*name_node); + let name = name_node.get_span_contents(&self.compiler); if let Some((_, types)) = record_fields.get_mut(&name) { types.insert(*ty); } else { @@ -1791,7 +1765,7 @@ impl<'a> Typechecker<'a> { for (_, (node, types)) in record_fields.into_iter() { fields.push((node, self.create_oneof(types))); } - 
fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + fields.sort_by_cached_key(|(name_node, _)| name_node.get_span_contents(&self.compiler)); let rec_ty_id = RecordTypeId(self.record_types.len()); self.record_types.push(fields); @@ -1832,7 +1806,7 @@ impl<'a> Typechecker<'a> { let mut refs = HashMap::::new(); let mut simple_type: Option = None; let mut list_elems = HashSet::new(); - let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + let mut record_fields = HashMap::<&[u8], (ExpressionNodeId, HashSet)>::new(); let mut oneof_ids = Vec::new(); for ty_id in flattened { let ty = self.types[ty_id.0]; @@ -1860,7 +1834,7 @@ impl<'a> Typechecker<'a> { } let new_fields = &self.record_types[rec_ty_id.0]; for (name_node, ty) in new_fields.iter() { - let name = self.compiler.get_span_contents(*name_node); + let name = name_node.get_span_contents(&self.compiler); if let Some((_, types)) = record_fields.get_mut(&name) { types.insert(*ty); } else { @@ -1907,7 +1881,7 @@ impl<'a> Typechecker<'a> { for (_, (node, types)) in record_fields.into_iter() { fields.push((node, self.create_oneof(types))); } - fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + fields.sort_by_cached_key(|(name_node, _)| name_node.get_span_contents(&self.compiler)); let rec_ty_id = RecordTypeId(self.record_types.len()); self.record_types.push(fields); From d2dd7a668b8fc9db7f2c92f6ccc06eecb005a65d Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Mon, 16 Mar 2026 18:13:20 +0800 Subject: [PATCH 11/12] some little change --- src/typechecker.rs | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/typechecker.rs b/src/typechecker.rs index 9f457a5..382e4ef 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -305,7 +305,7 @@ impl<'a> Typechecker<'a> { self.typecheck_node(*param); } // Params are not supposed to be evaluated - self.set_node_type_id(node_id, 
FORBIDDEN_TYPE); + node_id.set_node_type_id(self, FORBIDDEN_TYPE); } AstNode::Param { name, ty } => { if let Some(ty) = ty { @@ -317,9 +317,9 @@ impl<'a> Typechecker<'a> { .get(name) .expect("missing resolved variable"); self.variable_types[var_id.0] = ty_id; - self.set_node_type_id(node_id, ty_id); + node_id.set_node_type_id(self, ty_id); } else { - self.set_node_type_id(node_id, ANY_TYPE); + node_id.set_node_type_id(self, ANY_TYPE); } } AstNode::TypeArgs(ref args) => { @@ -327,7 +327,7 @@ impl<'a> Typechecker<'a> { self.typecheck_type(*arg); } // Type argument lists are not supposed to be evaluated - self.set_node_type_id(node_id, FORBIDDEN_TYPE); + node_id.set_node_type_id(self, FORBIDDEN_TYPE); } // NOTE: what about AstNode::TypeParams? _ => self.error( @@ -654,7 +654,7 @@ impl<'a> Typechecker<'a> { } fn typecheck_binary_op(&mut self, lhs:&ExpressionNodeId, op: &NodeId, rhs:&ExpressionNodeId) -> TypeId { - self.set_node_type_id(op, FORBIDDEN_TYPE); + op.set_node_type_id(self, FORBIDDEN_TYPE); // TODO: better error messages for type mismatches, the previous messages were better let node = op.get_node(self.compiler); @@ -841,7 +841,7 @@ impl<'a> Typechecker<'a> { self.typecheck_node(params); self.typecheck_node(block); - self.set_node_type_id(node_id, NONE_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); // set input/output types for the command let decl_id = self @@ -862,7 +862,7 @@ impl<'a> Typechecker<'a> { } fn typecheck_alias(&mut self, new_name:&NameOrString, old_name:&NameOrString, node_id:StatementNodeId) { - self.set_node_type_id(node_id, NONE_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); // set input/output types for the command let decl_id_new = self @@ -882,6 +882,8 @@ impl<'a> Typechecker<'a> { ); } + // TODO: something strange inside this function. + // The type of `self.compiler.deco_resolution` is unclear. 
fn typecheck_call(&mut self, head: &[NameNodeId], parts: &[ExpressionNodeId], node_id: &ExpressionNodeId) -> TypeId { if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); @@ -928,7 +930,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(STRING_TYPE, expected) { self.error( format!("Expected {}, got string", self.type_to_string(expected)), - *arg, + arg.into_indexer(), ); } } else { @@ -988,8 +990,8 @@ impl<'a> Typechecker<'a> { .expect("missing declared variable"); self.variable_types[var_id.0] = type_id; - self.set_node_type_id(variable_name, type_id); - self.set_node_type_id(node_id, NONE_TYPE); + variable_name.set_node_type_id(self, type_id); + node_id.set_node_type_id(self, NONE_TYPE); } fn typecheck_type(&mut self, node_id: NodeId) -> TypeId { @@ -1039,7 +1041,7 @@ impl<'a> Typechecker<'a> { ERROR_TYPE } }; - self.set_node_type_id(node_id, ty_id); + node_id.set_node_type_id(self, ty_id); ty_id } From f2e928e27b0395f4c51eb89cb9a6f3179b7adbf6 Mon Sep 17 00:00:00 2001 From: WindSoilder Date: Wed, 18 Mar 2026 17:49:00 +0800 Subject: [PATCH 12/12] more little changes --- src/ast_nodes.rs | 37 ++++++++++++-- src/compiler.rs | 20 +++++--- src/typechecker.rs | 119 +++++++++++++++++++++++---------------------- 3 files changed, 109 insertions(+), 67 deletions(-) diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs index c9bed0d..f76dcf0 100644 --- a/src/ast_nodes.rs +++ b/src/ast_nodes.rs @@ -34,6 +34,31 @@ pub enum NameOrString { Name(NameNodeId), String(StringNodeId), } +impl NameOrString { + pub fn into_indexer(self) -> NodeIndexer { + match self { + Self::Name(x) => x.into_indexer(), + Self::String(x) => x.into_indexer(), + } + } +} + +// A helper enum for block compoments. Compiler doesn't save +// this as an individual id. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NameOrVariable { + Name(NameNodeId), + Variable(VariableNodeId), +} + +impl NameOrVariable { + pub fn into_indexer(self) -> NodeIndexer { + match self { + Self::Name(x) => x.into_indexer(), + Self::Variable(x) => x.into_indexer(), + } + } +} #[derive(Debug, Clone)] pub struct BlockNode { @@ -330,7 +355,9 @@ impl NodePusher for NameNode { compiler.name_nodes.push(span, self); let result = NameNodeId(compiler.name_nodes.len() - 1); - let indexer = NodeIndexer::Name(result); + // let's push expression to indexer. + let expr_node_id = ExpressionNode::Name(result).push_node(span, compiler); + let indexer = NodeIndexer::Expression(expr_node_id); compiler.indexer.push(indexer); result @@ -364,7 +391,9 @@ impl NodePusher for StringNode { compiler.string_nodes.push(span, self); let result = StringNodeId(compiler.string_nodes.len() - 1); - let indexer = NodeIndexer::String(result); + // let's push expression to Indexer. + let expr_node_id = ExpressionNode::String(result).push_node(span, compiler); + let indexer = NodeIndexer::Expression(expr_node_id); compiler.indexer.push(indexer); result @@ -398,7 +427,9 @@ impl NodePusher for VariableNode { compiler.variable_nodes.push(span, self); let result = VariableNodeId(compiler.variable_nodes.len() - 1); - let indexer = NodeIndexer::Variable(result); + // let's push expression to indexer. 
+ let expr_node_id = ExpressionNode::Variable(result).push_node(span, compiler); + let indexer = NodeIndexer::Expression(expr_node_id); compiler.indexer.push(indexer); result diff --git a/src/compiler.rs b/src/compiler.rs index b03ffc9..a23a176 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,7 +1,7 @@ use crate::ast_nodes::{ AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, - NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineNode, StatementNode, - StatementNodeId, StringNode, StringNodeId, VariableNode, VariableNodeId, + NameOrString, NameOrVariable, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineNode, + StatementNode, StatementNodeId, StringNode, StringNodeId, VariableNode, VariableNodeId, }; use crate::errors::SourceError; use crate::protocol::Command; @@ -89,6 +89,10 @@ impl NodeSpans { pub fn is_empty(&self) -> bool { self.nodes.is_empty() } + + pub fn iter_nodes(&self) -> std::slice::Iter<'_, T> { + self.nodes.iter() + } } #[derive(Clone)] @@ -97,8 +101,8 @@ pub struct Compiler { pub name_nodes: NodeSpans, pub string_nodes: NodeSpans, pub variable_nodes: NodeSpans, - pub ast_nodes: NodeSpans, pub expression_nodes: NodeSpans, + pub ast_nodes: NodeSpans, pub statement_nodes: NodeSpans, pub block_nodes: NodeSpans, // Blocks, indexed by BlockId pub pipeline_nodes: NodeSpans, // Pipelines, indexed by PipelineId @@ -117,17 +121,19 @@ pub struct Compiler { /// Variables, indexed by VarId pub variables: Vec, /// Mapping of variable's name node -> Variable - pub var_resolution: HashMap, + pub var_resolution: HashMap, /// Type declarations, indexed by TypeDeclId pub type_decls: Vec, /// Mapping of type decl's name node -> TypeDecl - pub type_resolution: HashMap, + pub type_resolution: HashMap, /// Declarations (commands, aliases, externs), indexed by DeclId pub decls: Vec>, /// Declaration NodeIds, indexed by DeclId - pub decl_nodes: Vec, + pub decl_nodes: Vec, /// Mapping of decl's name node -> 
Command - pub decl_resolution: HashMap, + /// It can be NameOrString, or an AstNode::Call. + // NOTE: not sure why it can be ExpressionNode::Call, but let's keep the original behavior. + pub decl_resolution: HashMap, // Definitions: // indexed by FunId diff --git a/src/typechecker.rs b/src/typechecker.rs index 382e4ef..2a029d6 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -2,7 +2,8 @@ //! how the typechecker works use crate::ast_nodes::{ - AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineId, PipelineNode, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, VariableNode, VariableNodeId + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineId, PipelineNode, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, VariableNode, VariableNodeId, + NameOrVariable, }; use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; @@ -247,8 +248,8 @@ impl<'a> Typechecker<'a> { let length = self.compiler.indexer.len(); let last_indexer = self.compiler.indexer[length - 1]; match last_indexer { - NodeIndexer::General(node_id) => self.typecheck_node(node_id), - NodeIndexer::Block(block_id) => {self.typecheck_block(block_id, TOP_TYPE);} + NodeIndexer::General(node_id) => self.typecheck_node(&node_id), + NodeIndexer::Block(block_id) => {self.typecheck_block(&block_id, TOP_TYPE);} _ => return; } for i in 0..self.type_vars.len() { @@ -297,24 +298,24 @@ impl<'a> Typechecker<'a> { self.types[type_id.0] } - fn typecheck_node(&mut self, node_id: NodeId) { + fn typecheck_node(&mut self, node_id: &NodeId) { let node = node_id.get_node(self.compiler); match node { AstNode::Params(ref params) => { for param in params { - self.typecheck_node(*param); + self.typecheck_node(param); } // Params are not supposed 
to be evaluated node_id.set_node_type_id(self, FORBIDDEN_TYPE); } AstNode::Param { name, ty } => { if let Some(ty) = ty { - let ty_id = self.typecheck_type(*ty); + let ty_id = self.typecheck_type(ty); let var_id = self .compiler .var_resolution - .get(name) + .get(&NameOrVariable::Name(*name)) .expect("missing resolved variable"); self.variable_types[var_id.0] = ty_id; node_id.set_node_type_id(self, ty_id); @@ -324,7 +325,7 @@ impl<'a> Typechecker<'a> { } AstNode::TypeArgs(ref args) => { for arg in args { - self.typecheck_type(*arg); + self.typecheck_type(arg); } // Type argument lists are not supposed to be evaluated node_id.set_node_type_id(self, FORBIDDEN_TYPE); @@ -335,7 +336,7 @@ impl<'a> Typechecker<'a> { "unsupported/unexpected ast node '{:?}' in typechecker", node ), - NodeIndexer::General(node_id) + node_id.into_indexer() ), } } @@ -351,7 +352,7 @@ impl<'a> Typechecker<'a> { }; match inner_node_id { StatementOrExpression::Statement(stmt_id) => self.typecheck_stmt(*stmt_id), - StatementOrExpression::Expression(expr_id) => self.typecheck_expr(*expr_id, expected_type), + StatementOrExpression::Expression(expr_id) => {self.typecheck_expr(expr_id, expected_type);} } } @@ -395,7 +396,7 @@ impl<'a> Typechecker<'a> { let var_id = self .compiler .var_resolution - .get(&variable) + .get(&NameOrVariable::Variable(*variable)) .expect("missing resolved variable"); if let Type::List(type_id) = self.type_of(range) { self.variable_types[var_id.0] = type_id; @@ -507,7 +508,7 @@ impl<'a> Typechecker<'a> { ExpressionNode::Closure { params, block } => { // TODO: input/output types if let Some(params_node_id) = params { - self.typecheck_node(*params_node_id); + self.typecheck_node(params_node_id); } self.typecheck_block(block, expected); @@ -518,7 +519,7 @@ impl<'a> Typechecker<'a> { let var_id = self .compiler .var_resolution - .get(variable_node_id) + .get(&NameOrVariable::Variable(*variable_node_id)) .expect("missing resolved variable"); self.variable_types[var_id.0] @@ 
-555,7 +556,10 @@ impl<'a> Typechecker<'a> { NONE_TYPE } } - ExpressionNode::Call { head, parts } => self.typecheck_call(head, parts, node_id), + ExpressionNode::Call { head, parts } => { + // need to make sure that the node_id is either Name or String. + self.typecheck_call(head, parts, node_id) + } ExpressionNode::Match { ref target, ref match_arms, @@ -748,7 +752,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(lhs_ty, STRING_TYPE) { self.error( format!("Expected string, got {}", self.type_to_string(lhs_ty)), - NodeIndexer::Expression(lhs), + lhs.into_indexer() ); } STRING_TYPE @@ -756,7 +760,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(lhs_ty, NUMBER_TYPE) { self.error( format!("Expected number, got {}", self.type_to_string(lhs_ty)), - NodeIndexer::Expression(lhs) + lhs.into_indexer() ); } self.numeric_op_type(lhs_ty, rhs_ty) @@ -831,8 +835,8 @@ impl<'a> Typechecker<'a> { panic!("internal error: return type is not a return type"); }; InOutType { - in_type: self.typecheck_type(*in_ty), - out_type: self.typecheck_type(*out_ty), + in_type: self.typecheck_type(in_ty), + out_type: self.typecheck_type(out_ty), } }) .collect::>() @@ -840,20 +844,20 @@ impl<'a> Typechecker<'a> { .unwrap_or_default(); self.typecheck_node(params); - self.typecheck_node(block); + self.typecheck_block(block, TOP_TYPE); node_id.set_node_type_id(self, NONE_TYPE); // set input/output types for the command let decl_id = self .compiler .decl_resolution - .get(&name) + .get(&name.into_indexer()) .expect("missing declared decl"); if in_out_types.is_empty() { self.decl_types[decl_id.0] = vec![InOutType { in_type: ANY_TYPE, - out_type: block.get_id_of(self.compiler), + out_type: block.type_id_of(self), }]; } else { // TODO check that block output type matches expected type @@ -868,10 +872,10 @@ impl<'a> Typechecker<'a> { let decl_id_new = self .compiler .decl_resolution - .get(&new_name) + .get(&new_name.into_indexer()) .expect("missing declared new name for alias"); - let 
decl_id_old = self.compiler.decl_resolution.get(&old_name); + let decl_id_old = self.compiler.decl_resolution.get(&old_name.into_indexer()); self.decl_types[decl_id_new.0] = decl_id_old.map_or( vec![InOutType { @@ -885,23 +889,22 @@ impl<'a> Typechecker<'a> { // TODO: something strange inside this function. // The type of `self.compiler.deco_resolution` is unclear. fn typecheck_call(&mut self, head: &[NameNodeId], parts: &[ExpressionNodeId], node_id: &ExpressionNodeId) -> TypeId { - if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { - let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); + if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id.into_indexer()) { let decl_node_id = self.compiler.decl_nodes[decl_id.0]; - let AstNode::Def { + let StatementNode::Def { type_params, params, .. - } = self.compiler.get_node(decl_node_id) + } = decl_node_id.get_node(&self.compiler) else { panic!("Internal error: Expected def") }; - let AstNode::Params(params) = self.compiler.get_node(*params) else { + let AstNode::Params(params) = params.get_node(self.compiler) else { panic!("Internal error: Expected params") }; let type_substs = if let Some(type_params) = type_params { - let AstNode::Params(type_params) = self.compiler.get_node(*type_params) else { + let AstNode::Params(type_params) = type_params.get_node(self.compiler) else { panic!("Internal error: expected type params"); }; let mut type_substs = HashMap::new(); @@ -915,18 +918,18 @@ impl<'a> Typechecker<'a> { HashMap::new() }; - let num_args = parts.len() - num_name_parts; + let num_args = parts.len() - head.len(); if params.len() != num_args { self.error( format!("Expected {} argument(s), got {}", params.len(), num_args), - node_id, + node_id.into_indexer(), ); } - for (param, arg) in params.iter().zip(&parts[num_name_parts..]) { + for (param, arg) in params.iter().zip(parts) { let expected = self.type_id_of(*param); let expected = self.subst(expected, &type_substs); 
- if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { - self.set_node_type_id(*arg, STRING_TYPE); + if matches!(arg.get_node(&self.compiler), ExpressionNode::Name(_)) { + arg.set_node_type_id(self, STRING_TYPE); if !self.constrain_subtype(STRING_TYPE, expected) { self.error( format!("Expected {}, got string", self.type_to_string(expected)), @@ -934,16 +937,16 @@ impl<'a> Typechecker<'a> { ); } } else { - self.typecheck_expr(*arg, expected); + self.typecheck_expr(arg, expected); } } if num_args > params.len() { // Typecheck extra arguments too - for arg in &parts[num_name_parts + params.len()..] { - if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { - self.set_node_type_id(*arg, STRING_TYPE); + for arg in &parts[params.len()..] { + if matches!(arg.get_node(&self.compiler), ExpressionNode::Name(_)) { + arg.set_node_type_id(self,STRING_TYPE); } else { - self.typecheck_expr(*arg, TOP_TYPE); + self.typecheck_expr(arg, TOP_TYPE); } } } @@ -994,8 +997,8 @@ impl<'a> Typechecker<'a> { node_id.set_node_type_id(self, NONE_TYPE); } - fn typecheck_type(&mut self, node_id: NodeId) -> TypeId { - let ty_id = match self.compiler.ast_nodes[node_id.0] { + fn typecheck_type(&mut self, node_id: &NodeId) -> TypeId { + let ty_id = match node_id.get_node(&self.compiler) { AstNode::Type { name, args, @@ -1005,27 +1008,29 @@ impl<'a> Typechecker<'a> { fields, optional: _, // TODO handle optional record types } => { - let AstNode::Params(field_nodes) = self.compiler.get_node(fields) else { + let AstNode::Params(field_nodes) = fields.get_node(&self.compiler) else { panic!("internal error: record fields aren't Params"); }; let mut fields = field_nodes .iter() .map(|field| { - let AstNode::Param { name, ty } = self.compiler.get_node(*field) else { + let AstNode::Param { name, ty } = field.get_node(&self.compiler) else { panic!("internal error: record field isn't Param"); }; let ty_id = match ty { Some(ty) => { - self.typecheck_type(*ty); + self.typecheck_type(ty); 
self.type_id_of(*ty) } None => ANY_TYPE, }; - (*name, ty_id) + // NOTE: a bad way to convert from NameNodeId to ExpressionNodeId + let expr_node_id = self.compiler.expression_nodes.iter_nodes().position(|expr_node| *expr_node == ExpressionNode::Name(*name)).expect("the Expression::Name should exist"); + (ExpressionNodeId(expr_node_id), ty_id) }) .collect::>(); // Store fields sorted by name - fields.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); + fields.sort_by_cached_key(|(name, _)| name.get_span_contents(&self.compiler)); self.record_types.push(fields); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) @@ -1034,9 +1039,9 @@ impl<'a> Typechecker<'a> { self.error( format!( "Internal error: expected type, got '{:?}'", - self.compiler.ast_nodes[node_id.0] + node_id.get_node(&self.compiler) ), - node_id, + node_id.into_indexer() ); ERROR_TYPE } @@ -1047,11 +1052,11 @@ impl<'a> Typechecker<'a> { fn typecheck_type_ref( &mut self, - name_id: NodeId, - args_id: Option, - _optional: bool, + name_id: &NameNodeId, + args_id: &Option, + _optional: &bool, ) -> TypeId { - let name = self.compiler.get_span_contents(name_id); + let name = name_id.get_span_contents(&self.compiler); // taken from parse_shape_name() in Nushell: match name { @@ -1062,14 +1067,14 @@ impl<'a> Typechecker<'a> { if let Some(args_id) = args_id { self.typecheck_node(args_id); - if let AstNode::TypeArgs(args) = self.compiler.get_node(args_id) { + if let AstNode::TypeArgs(args) = args_id.get_node(&self.compiler) { if args.len() > 1 { let types = - String::from_utf8_lossy(self.compiler.get_span_contents(args_id)); - self.error(format!("list must have only one type argument (to allow selection of types, use oneof{} -- WIP)", types), args_id); + String::from_utf8_lossy(args_id.get_span_contents(&self.compiler)); + self.error(format!("list must have only one type argument (to allow selection of types, use oneof{} -- WIP)", types), args_id.into_indexer()); 
self.push_type(Type::List(UNKNOWN_TYPE)) } else if args.is_empty() { - self.error("list must have one type argument", args_id); + self.error("list must have one type argument", args_id.into_indexer()); self.push_type(Type::List(UNKNOWN_TYPE)) } else { let args_ty_id = self.type_id_of(args[0]); @@ -1109,7 +1114,7 @@ impl<'a> Typechecker<'a> { // if bytes.contains(&b'@') { // // type with completion // } else { - if let Some(type_decl) = self.compiler.type_resolution.get(&name_id) { + if let Some(type_decl) = self.compiler.type_resolution.get(&NameOrVariable::Name(*name_id)) { self.push_type(Type::Ref(*type_decl)) } else { UNKNOWN_TYPE