diff --git a/src/ast_nodes.rs b/src/ast_nodes.rs new file mode 100644 index 0000000..ac17a07 --- /dev/null +++ b/src/ast_nodes.rs @@ -0,0 +1,603 @@ +use super::compiler::{Compiler, Span}; +use nu_protocol::{ast::Expression, engine::Variable}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct NameNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct NameNode; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StringNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StringNode; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct VariableNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct VariableNode; + +// A helper enum for block compoments. Compiler doesn't save +// this as an individual id. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StatementOrExpression { + Statement(StatementNodeId), + Expression(ExpressionNodeId), +} + +// A helper enum for block compoments. Compiler doesn't save +// this as an individual id. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NameOrString { + Name(NameNodeId), + String(StringNodeId), +} +impl NameOrString { + pub fn into_indexer(self) -> NodeIndexer { + match self { + Self::Name(x) => x.into_indexer(), + Self::String(x) => x.into_indexer(), + } + } +} + +// A helper enum for block compoments. Compiler doesn't save +// this as an individual id. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NameOrVariable { + Name(NameNodeId), + Variable(VariableNodeId), +} + +impl NameOrVariable { + pub fn into_indexer(self) -> NodeIndexer { + match self { + Self::Name(x) => x.into_indexer(), + Self::Variable(x) => x.into_indexer(), + } + } +} + +#[derive(Debug, Clone)] +pub struct BlockNode { + pub nodes: Vec, +} + +impl BlockNode { + pub fn new(nodes: Vec) -> BlockNode { + BlockNode { nodes } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BlockId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct PipelineId(pub usize); + +// Pipeline just contains a list of expressions +// +// It's not allowed if there is only one element in pipeline, in that +// case, it's just an expression. +// +// Making such restriction can reduce indirect access on expression, which +// can improve performance in parse time. +#[derive(Debug, Clone, PartialEq)] +pub struct PipelineNode { + pub nodes: Vec, +} + +impl PipelineNode { + pub fn new(nodes: Vec) -> Self { + debug_assert!( + nodes.len() > 1, + "a pipeline must contain at least 2 nodes, or else it's actually an expression" + ); + Self { nodes } + } + + pub fn get_expressions(&self) -> &Vec { + &self.nodes + } +} +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ExpressionNode { + Int, + Float, + String(StringNodeId), + Name(NameNodeId), + Variable(VariableNodeId), + + // Booleans + True, + False, + + // Empty values + Null, + + VarRef, + + Closure { + params: Option, + block: BlockId, + }, + + Call { + head: Vec, + parts: Vec, + }, + NamedValue { + name: NodeId, + value: NodeId, + }, + BinaryOp { + lhs: ExpressionNodeId, + op: NodeId, + rhs: ExpressionNodeId, + }, + Range { + lhs: ExpressionNodeId, + rhs: ExpressionNodeId, + }, + List(Vec), + Table { + header: ExpressionNodeId, + rows: Vec, + }, + Record { + pairs: Vec<(ExpressionNodeId, ExpressionNodeId)>, + }, + MemberAccess { + target: ExpressionNodeId, + field: ExpressionNodeId, + }, + // Pipeline is also an expression, and it contains a list of expressions. + Pipeline(PipelineId), + If { + condition: ExpressionNodeId, + then_block: BlockId, + else_block: Option, // it can be a block, or another if expression (else if) + }, + Try { + try_block: BlockId, + catch_block: Option, + finally_block: Option, + }, + Match { + target: ExpressionNodeId, + match_arms: Vec<(ExpressionNodeId, ExpressionNodeId)>, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ExpressionNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StatementNode { + // Definitions + Def { + name: NameOrString, + type_params: Option, + params: NodeId, + in_out_types: Option, + block: BlockId, + env: bool, + wrapped: bool, + }, + Extern { + name: NameOrString, + params: NodeId, + }, + Alias { + new_name: NameOrString, + old_name: NameOrString, + }, + Let { + variable_name: VariableNodeId, + ty: Option, + initializer: ExpressionNodeId, + is_mutable: bool, + }, + + While { + condition: ExpressionNodeId, + block: BlockId, + }, + For { + variable: VariableNodeId, + range: ExpressionNodeId, + block: BlockId, + }, + Loop { + block: BlockId, + }, + Return(Option), + Break, + Continue, + + Expr(ExpressionNodeId), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StatementNodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct NodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum NodeIndexer { + Name(NameNodeId), + String(StringNodeId), + Variable(VariableNodeId), + Expression(ExpressionNodeId), + Statement(StatementNodeId), + Block(BlockId), + Pipeline(PipelineId), + General(NodeId), +} + +// TODO: All nodes with Vec<...> should be moved to their own ID (like BlockId) to allow Copy trait +#[derive(Debug, Clone, PartialEq)] +pub enum AstNode { + Type { + name: NameNodeId, + args: Option, + optional: bool, + }, + TypeArgs(Vec), + RecordType { + /// Contains [AstNode::Params] + fields: NodeId, + optional: bool, + }, + VarDecl, + + // Operators + Pow, + Multiply, + Divide, + FloorDiv, + Modulo, + Plus, + Minus, + Equal, + NotEqual, + LessThan, + GreaterThan, + LessThanOrEqual, + GreaterThanOrEqual, + RegexMatch, + NotRegexMatch, + In, + Append, + And, + Xor, + Or, + + // Assignments + Assignment, + AddAssignment, + SubtractAssignment, + MultiplyAssignment, + DivideAssignment, + AppendAssignment, + + TypeParams(Vec), + Params(Vec), + Param { + name: NameNodeId, + ty: Option, + }, + InOutTypes(Vec), + /// Input/output type pair for a command + InOutType(NodeId, NodeId), + + /// Long flag ('--' + one or more letters) + FlagLong, + /// Short flag ('-' + single letter) + FlagShort, + /// Group of short flags ('-' + more than 1 letters) + FlagShortGroup, + + // ??? should statement belongs to AstNode? + Statement(StatementNodeId), + + Garbage, +} + +pub trait NodeIdGetter { + type Output; + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output; + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output; + fn get_span(&self, compiler: &Compiler) -> Span; + fn get_span_contents<'a>(&self, compiler: &'a Compiler) -> &'a [u8] { + let span = self.get_span(compiler); + compiler + .source + .get(span.start..span.end) + .expect("internal error: missing source of span") + } + fn into_indexer(self) -> NodeIndexer; +} + +pub trait NodePusher { + type Output; + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output; +} + +impl NodeIdGetter for NameNodeId { + type Output = NameNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.name_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.name_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.name_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Name(self) + } +} + +impl NodePusher for NameNode { + type Output = NameNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.name_nodes.push(span, self); + + let result = NameNodeId(compiler.name_nodes.len() - 1); + // let's push expression to indexer. + let expr_node_id = ExpressionNode::Name(result).push_node(span, compiler); + let indexer = NodeIndexer::Expression(expr_node_id); + compiler.indexer.push(indexer); + + result + } +} + +impl NodeIdGetter for StringNodeId { + type Output = StringNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.string_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.string_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.string_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::String(self) + } +} + +impl NodePusher for StringNode { + type Output = StringNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.string_nodes.push(span, self); + + let result = StringNodeId(compiler.string_nodes.len() - 1); + // let's push expression to Indexer. + let expr_node_id = ExpressionNode::String(result).push_node(span, compiler); + let indexer = NodeIndexer::Expression(expr_node_id); + compiler.indexer.push(indexer); + + result + } +} + +impl NodeIdGetter for VariableNodeId { + type Output = VariableNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.variable_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.variable_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.variable_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Variable(self) + } +} + +impl NodePusher for VariableNode { + type Output = VariableNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.variable_nodes.push(span, self); + + let result = VariableNodeId(compiler.variable_nodes.len() - 1); + // let's push expression to indexer. + let expr_node_id = ExpressionNode::Variable(result).push_node(span, compiler); + let indexer = NodeIndexer::Expression(expr_node_id); + compiler.indexer.push(indexer); + + result + } +} + +impl NodeIdGetter for BlockId { + type Output = BlockNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.block_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.block_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.block_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Block(self) + } +} + +impl NodePusher for BlockNode { + type Output = BlockId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.block_nodes.push(span, self); + + let result = BlockId(compiler.block_nodes.len() - 1); + let indexer = NodeIndexer::Block(result); + compiler.indexer.push(indexer); + + result + } +} + +impl NodeIdGetter for StatementNodeId { + type Output = StatementNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.statement_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.statement_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.statement_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Statement(self) + } +} + +impl NodePusher for StatementNode { + type Output = StatementNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.statement_nodes.push(span, self); + + let result = StatementNodeId(compiler.statement_nodes.len() - 1); + let indexer = NodeIndexer::Statement(result); + compiler.indexer.push(indexer); + + result + } +} + +impl NodeIdGetter for PipelineId { + type Output = PipelineNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.pipeline_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.pipeline_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.pipeline_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Pipeline(self) + } +} + +impl NodePusher for PipelineNode { + type Output = PipelineId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.pipeline_nodes.push(span, self); + + let result = PipelineId(compiler.pipeline_nodes.len() - 1); + let indexer = NodeIndexer::Pipeline(result); + compiler.indexer.push(indexer); + + result + } +} +impl NodePusher for ExpressionNode { + type Output = ExpressionNodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.expression_nodes.push(span, self); + + let result = ExpressionNodeId(compiler.expression_nodes.len() - 1); + let indexer = NodeIndexer::Expression(result); + compiler.indexer.push(indexer); + + result + } +} + +impl NodeIdGetter for ExpressionNodeId { + type Output = ExpressionNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.expression_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.expression_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.expression_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::Expression(self) + } +} +impl NodePusher for AstNode { + type Output = NodeId; + + fn push_node(self, span: Span, compiler: &mut Compiler) -> Self::Output { + compiler.ast_nodes.push(span, self); + + let result = NodeId(compiler.ast_nodes.len() - 1); + let indexer = NodeIndexer::General(result); + compiler.indexer.push(indexer); + + result + } +} +impl NodeIdGetter for NodeId { + type Output = AstNode; + + fn get_node<'a>(&self, compiler: &'a Compiler) -> &'a Self::Output { + compiler.ast_nodes.get_node(self.0) + } + + fn get_node_mut<'a>(&self, compiler: &'a mut Compiler) -> &'a mut Self::Output { + compiler.ast_nodes.get_node_mut(self.0) + } + + fn get_span(&self, compiler: &Compiler) -> Span { + compiler.ast_nodes.get_span(self.0) + } + + fn into_indexer(self) -> NodeIndexer { + NodeIndexer::General(self) + } +} diff --git a/src/compiler.rs b/src/compiler.rs index 11136e0..a23a176 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -1,5 +1,9 @@ +use crate::ast_nodes::{ + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, + NameOrString, NameOrVariable, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineNode, + StatementNode, StatementNodeId, StringNode, StringNodeId, VariableNode, VariableNodeId, +}; use crate::errors::SourceError; -use crate::parser::{AstNode, Block, NodeId, Pipeline}; use crate::protocol::Command; use crate::resolver::{ DeclId, Frame, NameBindings, ScopeId, TypeDecl, TypeDeclId, VarId, Variable, @@ -8,8 +12,12 @@ use crate::typechecker::{TypeId, Types}; use std::collections::HashMap; pub struct RollbackPoint { - idx_span_start: usize, idx_nodes: usize, + idx_name_nodes: usize, + idx_string_nodes: usize, + idx_variable_nodes: usize, + idx_expression_nodes: usize, + idx_statment_nodes: usize, idx_errors: usize, idx_blocks: usize, token_pos: usize, @@ -39,15 +47,69 @@ impl Spanned { } } +#[derive(Clone, Debug)] +pub struct NodeSpans { + nodes: Vec, // indexed by relative nodeId + spans: Vec, +} + +impl NodeSpans { + pub fn new() -> Self { + Self { + nodes: vec![], + spans: vec![], + } + } + pub fn get_span(&self, i: usize) -> Span { + self.spans[i] + } + + pub fn get_node(&self, i: usize) -> &T { + &self.nodes[i] + } + + pub fn get_node_mut(&mut self, i: usize) -> &mut T { + &mut self.nodes[i] + } + + pub fn push(&mut self, span: Span, node: T) { + self.spans.push(span); + self.nodes.push(node); + } + + pub fn len(&self) -> usize { + self.nodes.len() + } + + pub fn truncate(&mut self, len: usize) { + self.nodes.truncate(len); + self.spans.truncate(len); + } + + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + pub fn iter_nodes(&self) -> std::slice::Iter<'_, T> { + self.nodes.iter() + } +} + #[derive(Clone)] pub struct Compiler { - // Core information, indexed by NodeId: - pub spans: Vec, - pub ast_nodes: Vec, + // different types of nodes. + pub name_nodes: NodeSpans, + pub string_nodes: NodeSpans, + pub variable_nodes: NodeSpans, + pub expression_nodes: NodeSpans, + pub ast_nodes: NodeSpans, + pub statement_nodes: NodeSpans, + pub block_nodes: NodeSpans, // Blocks, indexed by BlockId + pub pipeline_nodes: NodeSpans, // Pipelines, indexed by PipelineId + pub indexer: Vec, + pub node_types: Vec, // node_lifetimes: Vec, - pub blocks: Vec, // Blocks, indexed by BlockId - pub pipelines: Vec, // Pipelines, indexed by PipelineId pub source: Vec, pub file_offsets: Vec<(String, usize, usize)>, // fname, start, end @@ -59,17 +121,19 @@ pub struct Compiler { /// Variables, indexed by VarId pub variables: Vec, /// Mapping of variable's name node -> Variable - pub var_resolution: HashMap, + pub var_resolution: HashMap, /// Type declarations, indexed by TypeDeclId pub type_decls: Vec, /// Mapping of type decl's name node -> TypeDecl - pub type_resolution: HashMap, + pub type_resolution: HashMap, /// Declarations (commands, aliases, externs), indexed by DeclId pub decls: Vec>, /// Declaration NodeIds, indexed by DeclId - pub decl_nodes: Vec, + pub decl_nodes: Vec, /// Mapping of decl's name node -> Command - pub decl_resolution: HashMap, + /// It can be NameOrString, or an AstNode::Call. + // NOTE: not sure why it can be ExpressionNode::Call, but let's keep the original behavior. + pub decl_resolution: HashMap, // Definitions: // indexed by FunId @@ -91,11 +155,16 @@ impl Default for Compiler { impl Compiler { pub fn new() -> Self { Self { - spans: vec![], - ast_nodes: vec![], + string_nodes: NodeSpans::new(), + variable_nodes: NodeSpans::new(), + ast_nodes: NodeSpans::new(), + name_nodes: NodeSpans::new(), + expression_nodes: NodeSpans::new(), + statement_nodes: NodeSpans::new(), + pipeline_nodes: NodeSpans::new(), node_types: vec![], - blocks: vec![], - pipelines: vec![], + indexer: vec![], + block_nodes: NodeSpans::new(), source: vec![], file_offsets: vec![], @@ -128,20 +197,62 @@ impl Compiler { // TODO: This should say PARSER, not COMPILER let mut result = "==== COMPILER ====\n".to_string(); - for (idx, ast_node) in self.ast_nodes.iter().enumerate() { + for (idx, indexer) in self.indexer.iter().enumerate() { + let (node_dbg_string, span) = match indexer { + NodeIndexer::String(i) => ( + format!("{:?}", self.string_nodes.get_node(i.0)), + self.string_nodes.get_span(i.0), + ), + NodeIndexer::Name(i) => ( + format!("{:?}", self.name_nodes.get_node(i.0)), + self.name_nodes.get_span(i.0), + ), + NodeIndexer::Variable(i) => ( + format!("{:?}", self.variable_nodes.get_node(i.0)), + self.variable_nodes.get_span(i.0), + ), + NodeIndexer::Expression(i) => ( + format!("{:?}", self.expression_nodes.get_node(i.0)), + self.expression_nodes.get_span(i.0), + ), + NodeIndexer::Statement(i) => ( + format!("{:?}", self.statement_nodes.get_node(i.0)), + self.statement_nodes.get_span(i.0), + ), + NodeIndexer::General(i) => ( + format!("{:?}", self.ast_nodes.get_node(i.0)), + self.ast_nodes.get_span(i.0), + ), + NodeIndexer::Block(i) => ( + format!("{:?}", self.block_nodes.get_node(i.0)), + self.block_nodes.get_span(i.0), + ), + NodeIndexer::Pipeline(i) => ( + format!("{:?}", self.pipeline_nodes.get_node(i.0)), + self.pipeline_nodes.get_span(i.0), + ), + }; result.push_str(&format!( - "{}: {:?} ({} to {})", - idx, ast_node, self.spans[idx].start, self.spans[idx].end + "{}: {} ({} to {})", + idx, node_dbg_string, span.start, span.end )); if matches!( - ast_node, - AstNode::Name | AstNode::Variable | AstNode::Int | AstNode::Float | AstNode::String + indexer, + NodeIndexer::Name(_) | NodeIndexer::Variable(_) | NodeIndexer::String(_) ) { result.push_str(&format!( " \"{}\"", - String::from_utf8_lossy(self.get_span_contents(NodeId(idx))) + String::from_utf8_lossy(self.get_span_contents(*indexer)) )); + } else if let NodeIndexer::Expression(i) = indexer { + let node = self.expression_nodes.get_node(i.0); + if matches!(node, ExpressionNode::Int | ExpressionNode::Float) { + result.push_str(&format!( + " \"{}\"", + String::from_utf8_lossy(self.get_span_contents(*indexer)) + )); + } } result.push('\n'); @@ -151,8 +262,8 @@ impl Compiler { result.push_str("==== COMPILER ERRORS ====\n"); for error in &self.errors { result.push_str(&format!( - "{:?} (NodeId {}): {}\n", - error.severity, error.node_id.0, error.message + "{:?} (NodeId {:?}): {}\n", + error.severity, error.node_id, error.message )); } } @@ -164,12 +275,12 @@ impl Compiler { self.scope.extend(name_bindings.scope); self.scope_stack.extend(name_bindings.scope_stack); self.variables.extend(name_bindings.variables); - self.var_resolution.extend(name_bindings.var_resolution); + // self.var_resolution.extend(name_bindings.var_resolution); self.type_decls.extend(name_bindings.type_decls); - self.type_resolution.extend(name_bindings.type_resolution); + // self.type_resolution.extend(name_bindings.type_resolution); self.decls.extend(name_bindings.decls); - self.decl_nodes.extend(name_bindings.decl_nodes); - self.decl_resolution.extend(name_bindings.decl_resolution); + // self.decl_nodes.extend(name_bindings.decl_nodes); + // self.decl_resolution.extend(name_bindings.decl_resolution); self.errors.extend(name_bindings.errors); } @@ -191,50 +302,62 @@ impl Compiler { self.source.len() } - pub fn get_node(&self, node_id: NodeId) -> &AstNode { - &self.ast_nodes[node_id.0] + pub fn get_node(&self, node_id: T) -> &T::Output { + node_id.get_node(self) } - pub fn get_node_mut(&mut self, node_id: NodeId) -> &mut AstNode { - &mut self.ast_nodes[node_id.0] + pub fn get_node_mut(&mut self, node_id: T) -> &mut T::Output { + node_id.get_node_mut(self) } - pub fn push_node(&mut self, ast_node: AstNode) -> NodeId { - self.ast_nodes.push(ast_node); - - NodeId(self.ast_nodes.len() - 1) + pub fn push_node(&mut self, span: Span, ast_node: T) -> T::Output { + ast_node.push_node(span, self) } pub fn get_rollback_point(&self, token_pos: usize) -> RollbackPoint { RollbackPoint { - idx_span_start: self.spans.len(), idx_nodes: self.ast_nodes.len(), + idx_name_nodes: self.name_nodes.len(), + idx_string_nodes: self.string_nodes.len(), + idx_variable_nodes: self.variable_nodes.len(), + idx_expression_nodes: self.expression_nodes.len(), + idx_statment_nodes: self.statement_nodes.len(), idx_errors: self.errors.len(), - idx_blocks: self.blocks.len(), + idx_blocks: self.block_nodes.len(), token_pos, } } pub fn apply_compiler_rollback(&mut self, rbp: RollbackPoint) -> usize { - self.blocks.truncate(rbp.idx_blocks); + self.block_nodes.truncate(rbp.idx_blocks); self.ast_nodes.truncate(rbp.idx_nodes); + self.name_nodes.truncate(rbp.idx_name_nodes); + self.string_nodes.truncate(rbp.idx_string_nodes); + self.variable_nodes.truncate(rbp.idx_variable_nodes); self.errors.truncate(rbp.idx_errors); - self.spans.truncate(rbp.idx_span_start); rbp.token_pos } /// Get span of node - pub fn get_span(&self, node_id: NodeId) -> Span { - *self - .spans - .get(node_id.0) - .expect("internal error: missing span of node") + /// TODO: no need this. + pub fn get_span(&self, node_indexer: NodeIndexer) -> Span { + match node_indexer { + NodeIndexer::String(i) => self.string_nodes.get_span(i.0), + NodeIndexer::Name(i) => self.name_nodes.get_span(i.0), + NodeIndexer::Variable(i) => self.variable_nodes.get_span(i.0), + NodeIndexer::General(i) => self.ast_nodes.get_span(i.0), + NodeIndexer::Expression(i) => self.expression_nodes.get_span(i.0), + NodeIndexer::Block(i) => self.block_nodes.get_span(i.0), + NodeIndexer::Statement(i) => self.statement_nodes.get_span(i.0), + NodeIndexer::Pipeline(i) => self.pipeline_nodes.get_span(i.0), + } } /// Get the source contents of a span of a node - pub fn get_span_contents(&self, node_id: NodeId) -> &[u8] { - let span = self.get_span(node_id); + /// TODO: no need this. + pub fn get_span_contents(&self, node_indexer: NodeIndexer) -> &[u8] { + let span = self.get_span(node_indexer); self.source .get(span.start..span.end) .expect("internal error: missing source of span") @@ -248,14 +371,14 @@ impl Compiler { } /// Get the source contents of a node - pub fn node_as_str(&self, node_id: NodeId) -> &str { - std::str::from_utf8(self.get_span_contents(node_id)) + pub fn node_as_str(&self, node_indexer: NodeIndexer) -> &str { + std::str::from_utf8(self.get_span_contents(node_indexer)) .expect("internal error: expected utf8 string") } /// Get the source contents of a node as i64 - pub fn node_as_i64(&self, node_id: NodeId) -> i64 { - self.node_as_str(node_id) + pub fn node_as_i64(&self, node_indexer: NodeIndexer) -> i64 { + self.node_as_str(node_indexer) .parse::() .expect("internal error: expected i64") } diff --git a/src/errors.rs b/src/errors.rs index 6f0d562..b395c59 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,4 +1,4 @@ -use crate::parser::NodeId; +use crate::{ast_nodes::NodeIndexer, parser::NodeId}; #[derive(Debug, Clone, Copy)] pub enum Severity { @@ -9,6 +9,6 @@ pub enum Severity { #[derive(Debug, Clone)] pub struct SourceError { pub message: String, - pub node_id: NodeId, + pub node_id: NodeIndexer, pub severity: Severity, } diff --git a/src/ir_generator.rs b/src/ir_generator.rs index 74a61f4..93663cd 100644 --- a/src/ir_generator.rs +++ b/src/ir_generator.rs @@ -110,7 +110,7 @@ impl<'a> IrGenerator<'a> { Some(next_reg) } AstNode::Block(block_id) => { - let block = &self.compiler.blocks[block_id.0]; + let block = &self.compiler.block_nodes[block_id.0]; let mut last = None; for id in &block.nodes { last = self.generate_node(*id); diff --git a/src/lib.rs b/src/lib.rs index 6d74421..5c9dd37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod ast_nodes; pub mod compiler; pub mod errors; pub mod ir_generator; diff --git a/src/parser.rs b/src/parser.rs index 0bece02..91db8f9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,3 +1,9 @@ +use crate::ast_nodes::{ + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, + NameOrString, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineId, PipelineNode, + StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, VariableNode, + VariableNodeId, +}; use crate::compiler::{Compiler, RollbackPoint, Span}; use crate::errors::{Severity, SourceError}; use crate::lexer::{Token, Tokens}; @@ -9,52 +15,6 @@ pub struct Parser { tokens: Tokens, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct NodeId(pub usize); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct BlockId(pub usize); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct PipelineId(pub usize); - -#[derive(Debug, Clone)] -pub struct Block { - pub nodes: Vec, -} - -impl Block { - pub fn new(nodes: Vec) -> Block { - Block { nodes } - } -} - -// Pipeline just contains a list of expressions -// -// It's not allowed if there is only one element in pipeline, in that -// case, it's just an expression. -// -// Making such restriction can reduce indirect access on expression, which -// can improve performance in parse time. -#[derive(Debug, Clone, PartialEq)] -pub struct Pipeline { - pub nodes: Vec, -} - -impl Pipeline { - pub fn new(nodes: Vec) -> Self { - debug_assert!( - nodes.len() > 1, - "a pipeline must contain at least 2 nodes, or else it's actually an expression" - ); - Self { nodes } - } - - pub fn get_expressions(&self) -> &Vec { - &self.nodes - } -} - #[derive(Debug, Clone, PartialEq)] pub enum BlockContext { /// This block is a whole block of code not wrapped in curlies (e.g., a file) @@ -96,173 +56,6 @@ impl AssignmentOrExpression { } } -// TODO: All nodes with Vec<...> should be moved to their own ID (like BlockId) to allow Copy trait -#[derive(Debug, PartialEq, Clone)] -pub enum AstNode { - Int, - Float, - String, - Name, - Type { - name: NodeId, - args: Option, - optional: bool, - }, - TypeArgs(Vec), - RecordType { - /// Contains [AstNode::Params] - fields: NodeId, - optional: bool, - }, - Variable, - - // Booleans - True, - False, - - // Empty values - Null, - - // Operators - Pow, - Multiply, - Divide, - FloorDiv, - Modulo, - Plus, - Minus, - Equal, - NotEqual, - LessThan, - GreaterThan, - LessThanOrEqual, - GreaterThanOrEqual, - RegexMatch, - NotRegexMatch, - In, - Append, - And, - Xor, - Or, - - // Assignments - Assignment, - AddAssignment, - SubtractAssignment, - MultiplyAssignment, - DivideAssignment, - AppendAssignment, - - // Statements - Let { - variable_name: NodeId, - ty: Option, - initializer: NodeId, - is_mutable: bool, - }, - While { - condition: NodeId, - block: NodeId, - }, - For { - variable: NodeId, - range: NodeId, - block: NodeId, - }, - Loop { - block: NodeId, - }, - Return(Option), - Break, - Continue, - - // Definitions - Def { - name: NodeId, - type_params: Option, - params: NodeId, - in_out_types: Option, - block: NodeId, - env: bool, - wrapped: bool, - }, - Extern { - name: NodeId, - params: NodeId, - }, - Params(Vec), - Param { - name: NodeId, - ty: Option, - }, - InOutTypes(Vec), - /// Input/output type pair for a command - InOutType(NodeId, NodeId), - Closure { - params: Option, - block: NodeId, - }, - Alias { - new_name: NodeId, - old_name: NodeId, - }, - - /// Long flag ('--' + one or more letters) - FlagLong, - /// Short flag ('-' + single letter) - FlagShort, - /// Group of short flags ('-' + more than 1 letters) - FlagShortGroup, - - // Expressions - Call { - parts: Vec, - }, - NamedValue { - name: NodeId, - value: NodeId, - }, - BinaryOp { - lhs: NodeId, - op: NodeId, - rhs: NodeId, - }, - Range { - lhs: NodeId, - rhs: NodeId, - }, - List(Vec), - Table { - header: NodeId, - rows: Vec, - }, - Record { - pairs: Vec<(NodeId, NodeId)>, - }, - MemberAccess { - target: NodeId, - field: NodeId, - }, - Block(BlockId), - Pipeline(PipelineId), - If { - condition: NodeId, - then_block: NodeId, - else_block: Option, - }, - Try { - try_block: NodeId, - catch_block: Option, - finally_block: Option, - }, - Match { - target: NodeId, - match_arms: Vec<(NodeId, NodeId)>, - }, - Statement(NodeId), - Garbage, -} - pub const ASSIGNMENT_PRECEDENCE: usize = 10; impl AstNode { @@ -304,23 +97,27 @@ impl Parser { self.tokens.peek_span().start } - fn get_span_end(&self, node_id: NodeId) -> usize { - self.compiler.spans[node_id.0].end + fn get_span_end(&self, node_id: T) -> usize { + node_id.get_span(&self.compiler).end } pub fn parse(mut self) -> Compiler { let _span = span!(); - self.block(BlockContext::Bare); + let _ = self.block(BlockContext::Bare); self.compiler } - pub fn expression(&mut self) -> NodeId { + pub fn expression(&mut self) -> Option { let _span = span!(); - self.math_expression(false).get_node_id() + self.math_expression(false) } - fn pipeline(&mut self, first_element: NodeId, span_start: usize) -> NodeId { + fn pipeline( + &mut self, + first_element: ExpressionNodeId, + span_start: usize, + ) -> Option { let mut expressions = vec![first_element]; while self.is_pipe() { self.pipe(); @@ -328,46 +125,67 @@ impl Parser { if self.is_newline() { self.tokens.advance() } - expressions.push(self.expression()); + expressions.push(self.expression()?); } - self.compiler.pipelines.push(Pipeline::new(expressions)); let span_end = self.position(); - self.create_node( - AstNode::Pipeline(PipelineId(self.compiler.pipelines.len() - 1)), - span_start, - span_end, - ) + let pipeline_id = PipelineNode::new(expressions).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); + // pipeline itself is an expression, so we push an expression node for it. + // It may make more overhead but it simpifies this `pipeline` interface. + Some(ExpressionNode::Pipeline(pipeline_id).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn pipeline_or_expression_or_assignment(&mut self) -> NodeId { + + pub fn pipeline_or_expression_or_assignment(&mut self) -> Option { // get the first expression let _span = span!(); let span_start = self.position(); - let first = self.math_expression(true); - let first_id = first.get_node_id(); - if let AssignmentOrExpression::Assignment(_) = &first { - return first_id; + let first = self.math_expression(true)?; + if let ExpressionNode::BinaryOp { op, .. } = self.compiler.get_node(first) { + if matches!( + self.compiler.get_node(*op), + AstNode::Assignment + | AstNode::AddAssignment + | AstNode::SubtractAssignment + | AstNode::MultiplyAssignment + | AstNode::DivideAssignment + | AstNode::AppendAssignment + ) { + return Some(first); + } } // pipeline with one element is an expression actually if !self.is_pipe() { - return first_id; + return Some(first); } - self.pipeline(first_id, span_start) + self.pipeline(first, span_start) } - pub fn pipeline_or_expression(&mut self) -> NodeId { + // Can be a pipeline or expression + pub fn pipeline_or_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); - let first_id = self.expression(); + let first_id = self.expression()?; // pipeline with one element is an expression actually. if !self.is_pipe() { - return first_id; + return Some(first_id); } self.pipeline(first_id, span_start) } - fn math_expression(&mut self, allow_assignment: bool) -> AssignmentOrExpression { + fn math_expression(&mut self, allow_assignment: bool) -> Option { let _span = span!(); - let mut expr_stack = Vec::<(NodeId, NodeId)>::new(); + let mut expr_stack = Vec::<(NodeId, ExpressionNodeId)>::new(); let mut last_prec = 1000000; @@ -375,63 +193,69 @@ impl Parser { // Check for special forms if self.is_keyword(b"if") { - return AssignmentOrExpression::Expression(self.if_expression()); + return self.if_expression(); } else if self.is_keyword(b"match") { - return AssignmentOrExpression::Expression(self.match_expression()); + return self.match_expression(); } else if self.is_keyword(b"try") { - return AssignmentOrExpression::Expression(self.try_expression()); + return self.try_expression(); } // TODO // } else if self.is_keyword(b"where") { // } // Otherwise assume a math expression - let mut leftmost = self.simple_expression(BarewordContext::Call); + let mut leftmost = self.simple_expression(BarewordContext::Call)?; if self.is_equals() { if !allow_assignment { self.error("assignment found in expression"); } - let op = self.operator(); + let op = self.operator()?; - let rhs = self.pipeline_or_expression(); + let rhs = self.pipeline_or_expression()?; let span_end = self.get_span_end(rhs); - return AssignmentOrExpression::Assignment(self.create_node( - AstNode::BinaryOp { + return Some( + ExpressionNode::BinaryOp { lhs: leftmost, op, rhs, - }, - span_start, - span_end, - )); + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), + ); } while self.has_tokens() { if self.is_operator() { let missing_space_before_op = !self.is_horizontal_space(); - let op = self.operator(); + let op = self.operator()?; let missing_space_after_op = !self.is_horizontal_space(); if missing_space_before_op { - self.error_on_node("missing space before operator", op); + self.error_on_node("missing space before operator", NodeIndexer::General(op)); } if missing_space_after_op { - self.error_on_node("missing space after operator", op); + self.error_on_node("missing space after operator", NodeIndexer::General(op)); } let op_prec = self.operator_precedence(op); if op_prec == ASSIGNMENT_PRECEDENCE && !allow_assignment { - self.error_on_node("assignment found in expression", op); + self.error_on_node("assignment found in expression", NodeIndexer::General(op)); } let rhs = if self.is_simple_expression() { - self.simple_expression(BarewordContext::Call) + self.simple_expression(BarewordContext::Call)? } else { - self.error("incomplete math expression") + self.error("incomplete math expression"); + return None; }; while op_prec <= last_prec { @@ -449,10 +273,12 @@ impl Parser { let lhs = expr_stack.last_mut().map_or(&mut leftmost, |l| &mut l.1); let (span_start, span_end) = self.spanning(*lhs, rhs); - *lhs = self.create_node( - AstNode::BinaryOp { lhs: *lhs, op, rhs }, - span_start, - span_end, + *lhs = ExpressionNode::BinaryOp { lhs: *lhs, op, rhs }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, ); } @@ -469,17 +295,22 @@ impl Parser { let (span_start, span_end) = self.spanning(*lhs, rhs); - *lhs = self.create_node( - AstNode::BinaryOp { lhs: *lhs, op, rhs }, - span_start, - span_end, + *lhs = ExpressionNode::BinaryOp { lhs: *lhs, op, rhs }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, ); } - AssignmentOrExpression::Expression(leftmost) + Some(leftmost) } - pub fn simple_expression(&mut self, bareword_context: BarewordContext) -> NodeId { + pub fn simple_expression( + &mut self, + bareword_context: BarewordContext, + ) -> Option { let _span = span!(); // skip comments and newlines @@ -492,42 +323,55 @@ impl Parser { let (token, span) = self.tokens.peek(); let mut expr = match token { - Token::LCurly => self.record_or_closure(), + Token::LCurly => self.record_or_closure()?, Token::LParen => { self.tokens.advance(); if self.tokens.peek_token() == Token::RParen { - self.error("use null instead of ()") + self.error("use null instead of ()"); + return None; } else { - let output = self.expression(); + let output = self.expression()?; self.rparen(); output } } - Token::LSquare => self.list_or_table(), - Token::Int => self.advance_node(AstNode::Int, span), - Token::Float => self.advance_node(AstNode::Float, span), - Token::DoubleQuotedString => self.advance_node(AstNode::String, span), - Token::SingleQuotedString => self.advance_node(AstNode::String, span), - Token::Dollar => self.variable(), + Token::LSquare => self.list_or_table()?, + Token::Int => self.advance_node(ExpressionNode::Int, span), + Token::Float => self.advance_node(ExpressionNode::Float, span), + Token::DoubleQuotedString => { + let string_node_id = self.advance_node(StringNode, span); + self.advance_node(ExpressionNode::String(string_node_id), span) + } + Token::SingleQuotedString => { + let string_node_id = self.advance_node(StringNode, span); + self.advance_node(ExpressionNode::String(string_node_id), span) + } + Token::Dollar => { + let var_id = self.variable()?; + self.advance_node(ExpressionNode::Variable(var_id), span) + } Token::Bareword => match self.compiler.get_span_contents_manual(span.start, span.end) { - b"true" => self.advance_node(AstNode::True, span), - b"false" => self.advance_node(AstNode::False, span), - b"null" => self.advance_node(AstNode::Null, span), + b"true" => self.advance_node(ExpressionNode::True, span), + b"false" => self.advance_node(ExpressionNode::False, span), + b"null" => self.advance_node(ExpressionNode::Null, span), _ => match bareword_context { BarewordContext::String => { - let node_id = self.name(); - self.compiler.ast_nodes[node_id.0] = AstNode::String; - node_id + // it's a string, so just make a string. + let string_node_id = self.advance_node(StringNode, span); + self.advance_node(ExpressionNode::String(string_node_id), span) } - BarewordContext::Call => self.call(), + BarewordContext::Call => self.call()?, }, }, - _ => self.error("incomplete expression"), + _ => { + self.error("incomplete expression"); + return None; + } }; loop { if self.is_horizontal_space() { - return expr; + return Some(expr); } else if self.is_dotdot() { // Range self.tokens.advance(); @@ -537,13 +381,18 @@ impl Parser { // // TODO: tweak the garbage location. self.error("incomplete range"); - return expr; + return Some(expr); } else { - let rhs = self.simple_expression(BarewordContext::String); + let rhs = self.simple_expression(BarewordContext::String)?; let span_end = self.get_span_end(rhs); - expr = - self.create_node(AstNode::Range { lhs: expr, rhs }, span_start, span_end); + ExpressionNode::Range { lhs: expr, rhs }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); } } else if self.is_dot() { // Member access @@ -551,61 +400,69 @@ impl Parser { if self.is_horizontal_space() { self.error("missing path name"); - return expr; + return Some(expr); } - let name = self.name(); + let name = self.name()?; let field_or_call = if self.is_lparen() { - self.variable() + let var_id = self.variable()?; + self.advance_node( + ExpressionNode::Variable(var_id), + name.get_span(&self.compiler), + ) } else { - name + self.advance_node(ExpressionNode::Name(name), name.get_span(&self.compiler)) }; let span_end = self.get_span_end(field_or_call); - match self.compiler.get_node_mut(field_or_call) { - AstNode::Variable | AstNode::Name => { - expr = self.create_node( - AstNode::MemberAccess { - target: expr, - field: field_or_call, - }, - span_start, - span_end, - ); - } - _ => { - self.error("expected field"); - } + expr = ExpressionNode::MemberAccess { + target: expr, + field: field_or_call, } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ); } else { - return expr; + return Some(expr); } } } - pub fn advance_node(&mut self, node: AstNode, span: Span) -> NodeId { + pub fn advance_node(&mut self, node: T, span: Span) -> T::Output { self.tokens.advance(); - self.create_node(node, span.start, span.end) + node.push_node(span, &mut self.compiler) } - pub fn variable(&mut self) -> NodeId { + pub fn variable(&mut self) -> Option { if self.is_dollar() { let span_start = self.position(); self.tokens.advance(); if let (Token::Bareword, name_span) = self.tokens.peek() { self.tokens.advance(); - self.create_node(AstNode::Variable, span_start, name_span.end) + Some(VariableNode.push_node( + Span { + start: span_start, + end: name_span.end, + }, + &mut self.compiler, + )) } else { - self.error("variable name must be a bareword") + self.error("variable name must be a bareword"); + None } } else { - self.error("expected variable starting with '$'") + self.error("expected variable starting with '$'"); + None } } - pub fn variable_decl(&mut self) -> NodeId { + pub fn variable_decl(&mut self) -> Option { let _span = span!(); let span_start = self.position(); @@ -616,15 +473,29 @@ impl Parser { if let (Token::Bareword, name_span) = self.tokens.peek() { self.tokens.advance(); - self.create_node(AstNode::Variable, span_start, name_span.end) + Some(VariableNode.push_node( + Span { + start: span_start, + end: name_span.end, + }, + &mut self.compiler, + )) } else { - self.error("variable assignment name must be a bareword") + self.error("variable assignment name must be a bareword"); + None } } - pub fn call(&mut self) -> NodeId { + pub fn advance_unchecked(&mut self, node: T) -> T::Output { + let span = self.tokens.peek_span(); + self.tokens.advance(); + node.push_node(span, &mut self.compiler) + } + + pub fn call(&mut self) -> Option { let _span = span!(); - let mut parts = vec![self.call_name()]; + let mut head = vec![self.call_name()]; + let mut parts = vec![]; let mut is_head = true; let span_start = self.position(); @@ -634,23 +505,29 @@ impl Parser { } if self.is_name() && is_head { - parts.push(self.name()); + head.push(self.advance_unchecked(NameNode)); continue; } // TODO: Add flags is_head = false; - let arg_id = self.simple_expression(BarewordContext::String); + let arg_id = self.simple_expression(BarewordContext::String)?; parts.push(arg_id); } let span_end = self.position(); - self.create_node(AstNode::Call { parts }, span_start, span_end) + Some(ExpressionNode::Call { head, parts }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn list_or_table(&mut self) -> NodeId { + pub fn list_or_table(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let mut is_table = false; @@ -670,15 +547,18 @@ impl Parser { } else if self.is_semicolon() { if items.len() != 1 { self.error("semicolon to create table should immediately follow headers"); - } else if !matches!(self.compiler.get_node(items[0]), AstNode::List(_)) { - self.error_on_node("tables require a list for their headers", items[0]) + } else if !matches!(self.compiler.get_node(items[0]), ExpressionNode::List(_)) { + self.error_on_node( + "tables require a list for their headers", + NodeIndexer::Expression(items[0]), + ) } self.tokens.advance(); is_table = true; } else if self.is_simple_expression() { - items.push(self.simple_expression(BarewordContext::String)); + items.push(self.simple_expression(BarewordContext::String)?); } else { - items.push(self.error("expected list item")); + self.error("expected list item"); if self.is_eof() { // prevent forever looping if there is no token to put the error on break; @@ -688,20 +568,31 @@ impl Parser { if is_table { let header = items.remove(0); - self.create_node( - AstNode::Table { + Some( + ExpressionNode::Table { header, rows: items, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } else { - self.create_node(AstNode::List(items), span_start, span_end) + Some(ExpressionNode::List(items).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } } - pub fn record_or_closure(&mut self) -> NodeId { + pub fn record_or_closure(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let mut span_end = self.position(); // TODO: make sure we only initialize it expectedly @@ -716,12 +607,18 @@ impl Parser { // Explicit closure case if self.is_pipe() { - let params = Some(self.signature_params(ParamsContext::Pipes)); - let block = self.block(BlockContext::Closure); + let params = self.signature_params(ParamsContext::Pipes); + let block = self.block(BlockContext::Closure)?; self.rcurly(); span_end = self.position(); - return self.create_node(AstNode::Closure { params, block }, span_start, span_end); + return Some(ExpressionNode::Closure { params, block }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )); } let rollback_point = self.get_rollback_point(); @@ -732,7 +629,7 @@ impl Parser { span_end = self.position(); break; } - let key = self.simple_expression(BarewordContext::String); + let key = self.simple_expression(BarewordContext::String)?; self.skip_newlines(); if first_pass && !self.is_colon() { is_closure = true; @@ -740,7 +637,7 @@ impl Parser { } self.colon(); self.skip_newlines(); - let val = self.simple_expression(BarewordContext::String); + let val = self.simple_expression(BarewordContext::String)?; items.push((key, val)); first_pass = false; @@ -755,61 +652,78 @@ impl Parser { if is_closure { self.apply_rollback(rollback_point); - let block = self.block(BlockContext::Closure); + let block = self.block(BlockContext::Closure)?; self.rcurly(); span_end = self.position(); - self.create_node( - AstNode::Closure { + Some( + ExpressionNode::Closure { params: None, block, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } else { - self.create_node(AstNode::Record { pairs: items }, span_start, span_end) + Some(ExpressionNode::Record { pairs: items }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } } - pub fn operator(&mut self) -> NodeId { + pub fn operator(&mut self) -> Option { let (token, span) = self.tokens.peek(); match token { - Token::Plus => self.advance_node(AstNode::Plus, span), - Token::PlusPlus => self.advance_node(AstNode::Append, span), - Token::Dash => self.advance_node(AstNode::Minus, span), - Token::Asterisk => self.advance_node(AstNode::Multiply, span), - Token::ForwardSlash => self.advance_node(AstNode::Divide, span), - Token::ForwardSlashForwardSlash => self.advance_node(AstNode::FloorDiv, span), - Token::LessThan => self.advance_node(AstNode::LessThan, span), - Token::LessThanEqual => self.advance_node(AstNode::LessThanOrEqual, span), - Token::GreaterThan => self.advance_node(AstNode::GreaterThan, span), - Token::GreaterThanEqual => self.advance_node(AstNode::GreaterThanOrEqual, span), - Token::EqualsEquals => self.advance_node(AstNode::Equal, span), - Token::ExclamationEquals => self.advance_node(AstNode::NotEqual, span), - Token::EqualsTilde => self.advance_node(AstNode::RegexMatch, span), - Token::ExclamationTilde => self.advance_node(AstNode::NotRegexMatch, span), - Token::AsteriskAsterisk => self.advance_node(AstNode::Pow, span), - Token::Equals => self.advance_node(AstNode::Assignment, span), - Token::PlusEquals => self.advance_node(AstNode::AddAssignment, span), - Token::DashEquals => self.advance_node(AstNode::SubtractAssignment, span), - Token::AsteriskEquals => self.advance_node(AstNode::MultiplyAssignment, span), - Token::ForwardSlashEquals => self.advance_node(AstNode::DivideAssignment, span), - Token::PlusPlusEquals => self.advance_node(AstNode::AppendAssignment, span), + Token::Plus => Some(self.advance_node(AstNode::Plus, span)), + Token::PlusPlus => Some(self.advance_node(AstNode::Append, span)), + Token::Dash => Some(self.advance_node(AstNode::Minus, span)), + Token::Asterisk => Some(self.advance_node(AstNode::Multiply, span)), + Token::ForwardSlash => Some(self.advance_node(AstNode::Divide, span)), + Token::ForwardSlashForwardSlash => Some(self.advance_node(AstNode::FloorDiv, span)), + Token::LessThan => Some(self.advance_node(AstNode::LessThan, span)), + Token::LessThanEqual => Some(self.advance_node(AstNode::LessThanOrEqual, span)), + Token::GreaterThan => Some(self.advance_node(AstNode::GreaterThan, span)), + Token::GreaterThanEqual => Some(self.advance_node(AstNode::GreaterThanOrEqual, span)), + Token::EqualsEquals => Some(self.advance_node(AstNode::Equal, span)), + Token::ExclamationEquals => Some(self.advance_node(AstNode::NotEqual, span)), + Token::EqualsTilde => Some(self.advance_node(AstNode::RegexMatch, span)), + Token::ExclamationTilde => Some(self.advance_node(AstNode::NotRegexMatch, span)), + Token::AsteriskAsterisk => Some(self.advance_node(AstNode::Pow, span)), + Token::Equals => Some(self.advance_node(AstNode::Assignment, span)), + Token::PlusEquals => Some(self.advance_node(AstNode::AddAssignment, span)), + Token::DashEquals => Some(self.advance_node(AstNode::SubtractAssignment, span)), + Token::AsteriskEquals => Some(self.advance_node(AstNode::MultiplyAssignment, span)), + Token::ForwardSlashEquals => Some(self.advance_node(AstNode::DivideAssignment, span)), + Token::PlusPlusEquals => Some(self.advance_node(AstNode::AppendAssignment, span)), Token::Bareword => match self.compiler.get_span_contents_manual(span.start, span.end) { - b"mod" => self.advance_node(AstNode::Modulo, span), - b"in" => self.advance_node(AstNode::In, span), - b"and" => self.advance_node(AstNode::And, span), - b"xor" => self.advance_node(AstNode::Xor, span), - b"or" => self.advance_node(AstNode::Or, span), - op => self.error(format!( - "Unknown operator: '{}'", - String::from_utf8_lossy(op) - )), + b"mod" => Some(self.advance_node(AstNode::Modulo, span)), + b"in" => Some(self.advance_node(AstNode::In, span)), + b"and" => Some(self.advance_node(AstNode::And, span)), + b"xor" => Some(self.advance_node(AstNode::Xor, span)), + b"or" => Some(self.advance_node(AstNode::Or, span)), + op => { + self.error(format!( + "Unknown operator: '{}'", + String::from_utf8_lossy(op) + )); + None + } }, - _ => self.error("expected: operator"), + _ => { + self.error("expected: operator"); + None + } } } @@ -817,29 +731,35 @@ impl Parser { self.compiler.get_node(operator).precedence() } - pub fn spanning(&mut self, from: NodeId, to: NodeId) -> (usize, usize) { + pub fn spanning(&mut self, from: T, to: T) -> (usize, usize) { ( - self.compiler.spans[from.0].start, - self.compiler.spans[to.0].end, + from.get_span(&self.compiler).start, + to.get_span(&self.compiler).end, ) } - pub fn string(&mut self) -> NodeId { + pub fn string(&mut self) -> Option { match self.tokens.peek() { - (Token::DoubleQuotedString, span) => self.advance_node(AstNode::String, span), - (Token::SingleQuotedString, span) => self.advance_node(AstNode::String, span), - _ => self.error("expected: string"), + (Token::DoubleQuotedString, span) => Some(self.advance_node(StringNode, span)), + (Token::SingleQuotedString, span) => Some(self.advance_node(StringNode, span)), + _ => { + self.error("expected: string"); + None + } } } - pub fn name(&mut self) -> NodeId { + pub fn name(&mut self) -> Option { match self.tokens.peek() { - (Token::Bareword, span) => self.advance_node(AstNode::Name, span), - _ => self.error("expected: name"), + (Token::Bareword, span) => Some(self.advance_node(NameNode, span)), + _ => { + self.error("expected: name"); + None + } } } - pub fn call_name(&mut self) -> NodeId { + pub fn call_name(&mut self) -> NameNodeId { let (mut token, mut span) = self.tokens.peek(); loop { @@ -859,25 +779,26 @@ impl Parser { span.end = next_span.end; } - self.create_node(AstNode::Name, span.start, span.end) + NameNode.push_node(span, &mut self.compiler) } pub fn has_tokens(&mut self) -> bool { self.tokens.peek_token() != Token::Eof } - pub fn match_expression(&mut self) -> NodeId { + pub fn match_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; self.keyword(b"match"); - let target = self.simple_expression(BarewordContext::String); + let target = self.simple_expression(BarewordContext::String)?; let mut match_arms = vec![]; if !self.is_lcurly() { - return self.error("expected left curly brace '{'"); + self.error("expected left curly brace '{'"); + return None; } self.lcurly(); @@ -888,14 +809,15 @@ impl Parser { self.rcurly(); break; } else if self.is_simple_expression() { - let pattern = self.simple_expression(BarewordContext::String); + let pattern = self.simple_expression(BarewordContext::String)?; if !self.is_thick_arrow() { - return self.error("expected thick arrow (=>) between match cases"); + self.error("expected thick arrow (=>) between match cases"); + return None; } self.tokens.advance(); - let pattern_result = self.simple_expression(BarewordContext::String); + let pattern_result = self.simple_expression(BarewordContext::String)?; if self.is_comma() { self.tokens.advance(); @@ -905,24 +827,31 @@ impl Parser { } else if self.is_newline() { self.tokens.advance(); } else { - return self.error("expected match arm in match"); + self.error("expected match arm in match"); + return None; } } - self.create_node(AstNode::Match { target, match_arms }, span_start, span_end) + Some(ExpressionNode::Match { target, match_arms }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn if_expression(&mut self) -> NodeId { + pub fn if_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; self.keyword(b"if"); - let condition = self.expression(); + let condition = self.expression()?; self.skip_newlines(); - let then_block = self.block(BlockContext::Curlies); + let then_block = self.block(BlockContext::Curlies)?; self.skip_newlines(); let else_block = if self.is_keyword(b"else") { @@ -930,37 +859,47 @@ impl Parser { self.skip_newlines(); let block = if self.is_keyword(b"if") { - self.if_expression() + let exp = self.if_expression()?; + span_end = self.get_span_end(exp); + NodeIndexer::Expression(self.if_expression()?) } else if self.is_keyword(b"match") { - self.match_expression() + let match_exp = self.match_expression()?; + span_end = self.get_span_end(match_exp); + NodeIndexer::Expression(match_exp) } else { - self.block(BlockContext::Curlies) + let exp = self.block(BlockContext::Curlies)?; + span_end = self.get_span_end(exp); + NodeIndexer::Block(exp) }; - span_end = self.get_span_end(block); Some(block) } else { span_end = self.get_span_end(then_block); None }; - self.create_node( - AstNode::If { + Some( + ExpressionNode::If { condition, then_block, else_block, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } - pub fn try_expression(&mut self) -> NodeId { + pub fn try_expression(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"try"); - let try_block = self.block(BlockContext::Curlies); + let try_block = self.block(BlockContext::Curlies)?; let mut span_end = self.get_span_end(try_block); self.skip_newlines(); @@ -969,7 +908,7 @@ impl Parser { self.tokens.advance(); self.skip_newlines(); - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; span_end = self.get_span_end(block); Some(block) @@ -982,27 +921,32 @@ impl Parser { self.tokens.advance(); self.skip_newlines(); - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; span_end = self.get_span_end(block); Some(block) } else { None }; - self.create_node( - AstNode::Try { + Some( + ExpressionNode::Try { try_block, catch_block, finally_block, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } // directly ripped from `type_params` just changed delimiters // FIXME: simplify if appropriate - pub fn signature_params(&mut self, params_context: ParamsContext) -> NodeId { + pub fn signature_params(&mut self, params_context: ParamsContext) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -1039,26 +983,31 @@ impl Parser { continue; } - let name = self.name(); + let name = self.name()?; let ty = if self.is_colon() { // We have a type self.colon(); - Some(self.typename()) + Some(self.typename()?) } else { None }; - let name_span = self.compiler.spans[name.0]; + let name_span = name.get_span(&self.compiler); let param_span_end = if let Some(ty_id) = ty { - self.compiler.spans[ty_id.0].end + ty_id.get_span(&self.compiler).end } else { name_span.end }; - let param = - self.create_node(AstNode::Param { name, ty }, name_span.start, param_span_end); + let param = AstNode::Param { name, ty }.push_node( + Span { + start: name_span.start, + end: param_span_end, + }, + &mut self.compiler, + ); // output.push(self.name()); output.push(param); @@ -1075,7 +1024,13 @@ impl Parser { output }; - self.create_node(AstNode::Params(param_list), span_start, span_end) + Some(AstNode::Params(param_list).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } pub fn type_params(&mut self) -> NodeId { @@ -1095,16 +1050,24 @@ impl Parser { continue; } - param_list.push(self.name()); + if let Some(name) = self.name() { + param_list.push(name) + } } let span_end = self.position() + 1; self.greater_than(); - self.create_node(AstNode::Params(param_list), span_start, span_end) + AstNode::TypeParams(param_list).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn type_args(&mut self) -> NodeId { + pub fn type_args(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -1123,7 +1086,7 @@ impl Parser { continue; } - output.push(self.typename()); + output.push(self.typename()?); } span_end = self.position() + 1; @@ -1132,17 +1095,23 @@ impl Parser { output }; - self.create_node(AstNode::TypeArgs(arg_list), span_start, span_end) + Some(AstNode::TypeArgs(arg_list).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn typename(&mut self) -> NodeId { + pub fn typename(&mut self) -> Option { let _span = span!(); if let (Token::Bareword, span) = self.tokens.peek() { - let name = self.name(); - let name_text = self.compiler.get_span_contents(name); + let name = self.name()?; + let name_text = name.get_span_contents(&self.compiler); if name_text == b"record" { - let fields = self.signature_params(ParamsContext::Angles); + let fields = self.signature_params(ParamsContext::Angles)?; let optional = if self.is_question_mark() { // We have an optional type self.tokens.advance(); @@ -1151,17 +1120,19 @@ impl Parser { false }; let span_end = self.position(); - return self.create_node( - AstNode::RecordType { fields, optional }, - span.start, - span_end, - ); + return Some(AstNode::RecordType { fields, optional }.push_node( + Span { + start: span.start, + end: span_end, + }, + &mut self.compiler, + )); } let mut args = None; if self.is_less_than() { // We have generics - args = Some(self.type_args()); + args = Some(self.type_args()?); } let optional = if self.is_question_mark() { @@ -1171,33 +1142,46 @@ impl Parser { } else { false }; - self.create_node( + + Some( AstNode::Type { name, args, optional, - }, - span.start, - span.end, // FIXME: this uses the end of the name as its end + } + .push_node( + Span { + start: span.start, + end: span.end, + }, + &mut self.compiler, + ), ) } else { - self.error("expect name") + self.error("expect name"); + None } } - pub fn in_out_type(&mut self) -> NodeId { + pub fn in_out_type(&mut self) -> Option { let _span = span!(); let span_start = self.position(); - let in_ty = self.typename(); + let in_ty = self.typename()?; self.thin_arrow(); - let out_ty = self.typename(); + let out_ty = self.typename()?; let span_end = self.position(); - self.create_node(AstNode::InOutType(in_ty, out_ty), span_start, span_end) + Some(AstNode::InOutType(in_ty, out_ty).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn in_out_types(&mut self) -> NodeId { + pub fn in_out_types(&mut self) -> Option { let _span = span!(); self.colon(); @@ -1217,21 +1201,33 @@ impl Parser { continue; } - output.push(self.in_out_type()); + output.push(self.in_out_type()?); } self.rsquare(); let span_end = self.position(); - self.create_node(AstNode::InOutTypes(output), span_start, span_end) + Some(AstNode::InOutTypes(output).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } else { - let ty = self.in_out_type(); - let span = self.compiler.get_span(ty); - self.create_node(AstNode::InOutTypes(vec![ty]), span.start, span.end) + let ty = self.in_out_type()?; + let span = ty.get_span(&self.compiler); + Some(AstNode::InOutType(ty, ty).push_node( + Span { + start: span.start, + end: span.end, + }, + &mut self.compiler, + )) } } - pub fn def_statement(&mut self) -> NodeId { + pub fn def_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); @@ -1248,29 +1244,38 @@ impl Parser { let flag_name = self.compiler.get_span_contents_manual(span.start, span.end); if flag_name == b"env" { if has_env_flag { - return self.error("duplicated --env flag"); + self.error("duplicated --env flag"); + return None; } has_env_flag = true; } else if flag_name == b"wrapped" { if has_wrapped_flag { - return self.error("duplicated --wrapped flag"); + self.error("duplicated --wrapped flag"); + return None; } has_wrapped_flag = true } else { - return self.error("expect --env or --wrapped"); + self.error("expect --env or --wrapped"); + return None; } self.tokens.advance(); } - _ => return self.error("incomplete flag name"), + _ => { + self.error("incomplete flag name"); + return None; + } } } let name = match self.tokens.peek() { - (Token::Bareword, span) => self.advance_node(AstNode::Name, span), + (Token::Bareword, span) => NameOrString::Name(self.advance_node(NameNode, span)), (Token::DoubleQuotedString | Token::SingleQuotedString, span) => { - self.advance_node(AstNode::String, span) + NameOrString::String(self.advance_node(StringNode, span)) + } + _ => { + self.error("expected def name"); + return None; } - _ => return self.error("expected def name"), }; let type_params = if self.is_less_than() { @@ -1279,18 +1284,18 @@ impl Parser { None }; - let params = self.signature_params(ParamsContext::Squares); + let params = self.signature_params(ParamsContext::Squares)?; let in_out_types = if self.is_colon() { - Some(self.in_out_types()) + Some(self.in_out_types()?) } else { None }; - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); - self.create_node( - AstNode::Def { + Some( + StatementNode::Def { name, type_params, params, @@ -1298,103 +1303,127 @@ impl Parser { block, env: has_env_flag, wrapped: has_wrapped_flag, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } - pub fn extern_statement(&mut self) -> NodeId { + pub fn extern_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"extern"); let name = match self.tokens.peek() { - (Token::Bareword, span) => self.advance_node(AstNode::Name, span), + (Token::Bareword, span) => NameOrString::Name(self.advance_node(NameNode, span)), (Token::DoubleQuotedString | Token::SingleQuotedString, span) => { - self.advance_node(AstNode::String, span) + NameOrString::String(self.advance_node(StringNode, span)) + } + _ => { + self.error("expected def name"); + return None; } - _ => return self.error("expected def name"), }; - let params = self.signature_params(ParamsContext::Squares); + let params = self.signature_params(ParamsContext::Squares)?; let span_end = self.position(); - self.create_node(AstNode::Extern { name, params }, span_start, span_end) + Some(StatementNode::Extern { name, params }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } // TODO: Deduplicate code between let/mut/const assignments - pub fn let_statement(&mut self) -> NodeId { + pub fn let_statement(&mut self) -> Option { let _span = span!(); let is_mutable = false; let span_start = self.position(); self.keyword(b"let"); - let variable_name = self.variable_decl(); + let variable_name = self.variable_decl()?; let ty = if self.is_colon() { // We have a type self.colon(); - Some(self.typename()) + Some(self.typename()?) } else { None }; self.equals(); - let initializer = self.pipeline_or_expression(); + let initializer = self.pipeline_or_expression()?; let span_end = self.get_span_end(initializer); - self.create_node( - AstNode::Let { + Some( + StatementNode::Let { variable_name, ty, initializer, is_mutable, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } // TODO: Deduplicate code between let/mut/const assignments - pub fn mut_statement(&mut self) -> NodeId { + pub fn mut_statement(&mut self) -> Option { let _span = span!(); let is_mutable = true; let span_start = self.position(); self.keyword(b"mut"); - let variable_name = self.variable_decl(); + let variable_name = self.variable_decl()?; let ty = if self.is_colon() { // We have a type self.colon(); - Some(self.typename()) + Some(self.typename()?) } else { None }; self.equals(); - let initializer = self.pipeline_or_expression(); + let initializer = self.pipeline_or_expression()?; let span_end = self.get_span_end(initializer); - self.create_node( - AstNode::Let { + Some( + StatementNode::Let { variable_name, ty, initializer, is_mutable, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } @@ -1410,7 +1439,7 @@ impl Parser { } } - pub fn block(&mut self, context: BlockContext) -> NodeId { + pub fn block(&mut self, context: BlockContext) -> Option { let _span = span!(); let span_start = self.position(); @@ -1419,6 +1448,9 @@ impl Parser { self.lcurly(); } + // NOTE: Here we make early returns on statements, so there is only one error pop up. + // We can also consider parsing as much as possible and collect all errors, but that might + // cause a lot of cascading errors, so let's see how this goes first. while self.has_tokens() { if self.is_rcurly() && context == BlockContext::Curlies { self.rcurly(); @@ -1430,57 +1462,61 @@ impl Parser { self.tokens.advance(); continue; } else if self.is_keyword(b"def") { - code_body.push(self.def_statement()); + code_body.push(StatementOrExpression::Statement(self.def_statement()?)); } else if self.is_keyword(b"let") { - code_body.push(self.let_statement()); + code_body.push(StatementOrExpression::Statement(self.let_statement()?)); } else if self.is_keyword(b"mut") { - code_body.push(self.mut_statement()); + code_body.push(StatementOrExpression::Statement(self.mut_statement()?)); } else if self.is_keyword(b"while") { - code_body.push(self.while_statement()); + code_body.push(StatementOrExpression::Statement(self.while_statement()?)); } else if self.is_keyword(b"for") { - code_body.push(self.for_statement()); + code_body.push(StatementOrExpression::Statement(self.for_statement()?)); } else if self.is_keyword(b"loop") { - code_body.push(self.loop_statement()); + code_body.push(StatementOrExpression::Statement(self.loop_statement()?)); } else if self.is_keyword(b"return") { - code_body.push(self.return_statement()); + code_body.push(StatementOrExpression::Statement(self.return_statement()?)); } else if self.is_keyword(b"continue") { - code_body.push(self.continue_statement()); + code_body.push(StatementOrExpression::Statement(self.continue_statement())); } else if self.is_keyword(b"break") { - code_body.push(self.break_statement()); + code_body.push(StatementOrExpression::Statement(self.break_statement())); } else if self.is_keyword(b"alias") { - code_body.push(self.alias_statement()); + code_body.push(StatementOrExpression::Statement(self.alias_statement()?)); } else if self.is_keyword(b"extern") { - code_body.push(self.extern_statement()); + code_body.push(StatementOrExpression::Statement(self.extern_statement()?)); } else { let exp_span_start = self.position(); - let pipeline = self.pipeline_or_expression_or_assignment(); + let pipeline = self.pipeline_or_expression_or_assignment()?; let exp_span_end = self.get_span_end(pipeline); if self.is_semicolon() { // This is a statement, not an expression self.tokens.advance(); - code_body.push(self.create_node( - AstNode::Statement(pipeline), - exp_span_start, - exp_span_end, - )) + code_body.push(StatementOrExpression::Statement( + StatementNode::Expr(pipeline).push_node( + Span { + start: exp_span_start, + end: exp_span_end, + }, + &mut self.compiler, + ), + )); } else { - code_body.push(pipeline); + code_body.push(StatementOrExpression::Expression(pipeline)); } } } - self.compiler.blocks.push(Block::new(code_body)); let span_end = self.position(); - - self.create_node( - AstNode::Block(BlockId(self.compiler.blocks.len() - 1)), - span_start, - span_end, - ) + Some(BlockNode { nodes: code_body }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn while_statement(&mut self) -> NodeId { + pub fn while_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"while"); @@ -1491,47 +1527,64 @@ impl Parser { self.tokens.advance(); } - let condition = self.expression(); - let block = self.block(BlockContext::Curlies); + let condition = self.expression()?; + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); - self.create_node(AstNode::While { condition, block }, span_start, span_end) + Some(StatementNode::While { condition, block }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn for_statement(&mut self) -> NodeId { + pub fn for_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"for"); - let variable = self.variable_decl(); + let variable = self.variable_decl()?; self.keyword(b"in"); - let range = self.simple_expression(BarewordContext::String); - let block = self.block(BlockContext::Curlies); + let range = self.simple_expression(BarewordContext::String)?; + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); - self.create_node( - AstNode::For { + Some( + StatementNode::For { variable, range, block, - }, - span_start, - span_end, + } + .push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ), ) } - pub fn loop_statement(&mut self) -> NodeId { + pub fn loop_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"loop"); - let block = self.block(BlockContext::Curlies); + let block = self.block(BlockContext::Curlies)?; let span_end = self.get_span_end(block); - self.create_node(AstNode::Loop { block }, span_start, span_end) + Some(StatementNode::Loop { block }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn return_statement(&mut self) -> NodeId { + pub fn return_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); let span_end; @@ -1539,7 +1592,7 @@ impl Parser { self.keyword(b"return"); let ret_val = if self.is_expression() { - let expr = self.expression(); + let expr = self.expression()?; span_end = self.get_span_end(expr); Some(expr) } else { @@ -1547,44 +1600,69 @@ impl Parser { None }; - self.create_node(AstNode::Return(ret_val), span_start, span_end) + Some(StatementNode::Return(ret_val).push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } - pub fn continue_statement(&mut self) -> NodeId { + pub fn continue_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = self.position(); self.keyword(b"continue"); let span_end = span_start + b"continue".len(); - self.create_node(AstNode::Continue, span_start, span_end) + StatementNode::Continue.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn break_statement(&mut self) -> NodeId { + pub fn break_statement(&mut self) -> StatementNodeId { let _span = span!(); let span_start = self.position(); self.keyword(b"break"); let span_end = span_start + b"break".len(); - self.create_node(AstNode::Break, span_start, span_end) + StatementNode::Break.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + ) } - pub fn alias_statement(&mut self) -> NodeId { + pub fn alias_statement(&mut self) -> Option { let _span = span!(); let span_start = self.position(); self.keyword(b"alias"); let new_name = if self.is_string() { - self.string() + NameOrString::String(self.string()?) } else { - self.name() + NameOrString::Name(self.name()?) }; self.equals(); - let old_name = if self.is_string() { - self.string() + let (old_name, span_end) = if self.is_string() { + let s = self.string()?; + (NameOrString::String(s), self.get_span_end(s)) } else { - self.name() + let s = self.name()?; + (NameOrString::Name(s), self.get_span_end(s)) }; - let span_end = self.get_span_end(old_name); - self.create_node(AstNode::Alias { new_name, old_name }, span_start, span_end) + Some(StatementNode::Alias { new_name, old_name }.push_node( + Span { + start: span_start, + end: span_end, + }, + &mut self.compiler, + )) } pub fn is_operator(&mut self) -> bool { @@ -1766,7 +1844,7 @@ impl Parser { || self.is_name() } - pub fn error_on_node(&mut self, message: impl Into, node_id: NodeId) { + pub fn error_on_node(&mut self, message: impl Into, node_id: NodeIndexer) { self.compiler.errors.push(SourceError { message: message.into(), node_id, @@ -1774,29 +1852,19 @@ impl Parser { }); } - pub fn error(&mut self, message: impl Into) -> NodeId { + pub fn error(&mut self, message: impl Into) { let (token, span) = self.tokens.peek(); if token != Token::Eof { self.tokens.advance(); } - let node_id = self.create_node(AstNode::Garbage, span.start, span.end); + let node_id = NodeIndexer::General(AstNode::Garbage.push_node(span, &mut self.compiler)); self.compiler.errors.push(SourceError { message: message.into(), node_id, severity: Severity::Error, }); - - node_id - } - - pub fn create_node(&mut self, ast_node: AstNode, span_start: usize, span_end: usize) -> NodeId { - self.compiler.spans.push(Span { - start: span_start, - end: span_end, - }); - self.compiler.push_node(ast_node) } pub fn lparen(&mut self) { diff --git a/src/resolver.rs b/src/resolver.rs index 5c281fb..56c390e 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -1,8 +1,13 @@ +use crate::ast_nodes::{ + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, + NameOrString, NameOrVariable, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineId, + PipelineNode, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, + VariableNode, VariableNodeId, +}; use crate::protocol::{Command, Declaration}; use crate::{ compiler::Compiler, errors::{Severity, SourceError}, - parser::{AstNode, BlockId, NodeId, PipelineId}, }; use std::collections::HashMap; @@ -67,12 +72,12 @@ pub struct NameBindings { pub scope: Vec, pub scope_stack: Vec, pub variables: Vec, - pub var_resolution: HashMap, + pub var_resolution: HashMap, pub type_decls: Vec, - pub type_resolution: HashMap, + pub type_resolution: HashMap, pub decls: Vec>, - pub decl_nodes: Vec, - pub decl_resolution: HashMap, + pub decl_nodes: Vec, + pub decl_resolution: HashMap, pub errors: Vec, } @@ -110,17 +115,19 @@ pub struct Resolver<'a> { /// Variables, indexed by VarId pub variables: Vec, /// Mapping of variable's name node -> Variable - pub var_resolution: HashMap, + pub var_resolution: HashMap, /// Type declarations, indexed by TypeDeclId pub type_decls: Vec, /// Mapping of type decl's name node -> TypeDecl - pub type_resolution: HashMap, + pub type_resolution: HashMap, /// Declarations (commands, aliases, etc.), indexed by DeclId pub decls: Vec>, /// Declaration nodes, indexed by DeclId - pub decl_nodes: Vec, + pub decl_nodes: Vec, /// Mapping of decl's name node -> Command - pub decl_resolution: HashMap, + /// It can be NameOrString, or an AstNode::Call. + // NOTE: not sure why it can be ExpressionNode::Call, but let's keep the original behavior. + pub decl_resolution: HashMap, /// Errors encountered during name binding pub errors: Vec, } @@ -231,20 +238,31 @@ impl<'a> Resolver<'a> { } pub fn resolve(&mut self) { - if !self.compiler.ast_nodes.is_empty() { - let last = self.compiler.ast_nodes.len() - 1; - let last_node_id = NodeId(last); - self.resolve_node(last_node_id) + if !self.compiler.indexer.is_empty() { + let length = self.compiler.indexer.len(); + let last_indexer = self.compiler.indexer[length - 1]; + match last_indexer { + NodeIndexer::General(node_id) => self.resolve_node(node_id), + NodeIndexer::Block(block_id) => self.resolve_block(, block_id,None), + NodeIndexer::Expression(expr_id) => self.resolve_expr(expr_id), + NodeIndexer::Statement(stmt_id) => self.resolve_statement(stmt_id), + NodeIndexer::Pipeline(pipeline_id) => self.resolve_pipeline(pipeline_id), + _ => return, + } } + // if !self.compiler.ast_nodes.is_empty() { + // let last = self.compiler.ast_nodes.len() - 1; + // let last_node_id = NodeId(last); + // self.resolve_node(last_node_id) + // } } - pub fn resolve_node(&mut self, node_id: NodeId) { - // TODO: Move node_id param to the end, same as in typechecker - match self.compiler.ast_nodes[node_id.0] { - AstNode::Variable => self.resolve_variable(node_id), - AstNode::Call { ref parts } => self.resolve_call(node_id, parts), - AstNode::Block(block_id) => self.resolve_block(node_id, block_id, None), - AstNode::Closure { params, block } => { + pub fn resolve_expression(&mut self, expr_id: ExpressionNodeId) { + let node = expr_id.get_node(&self.compiler); + match node { + ExpressionNode::Variable(node_id) => self.resolve_variable(node_id), + ExpressionNode::Call { head, parts } => self.resolve_call(expr_id, head, parts), + ExpressionNode::Closure { params, block } => { // making sure the closure parameters and body end up in the same scope frame let closure_scope = if let Some(params) = params { self.enter_scope(block); @@ -260,15 +278,65 @@ impl<'a> Resolver<'a> { self.resolve_block(block, block_id, closure_scope); } - AstNode::Def { - name, - type_params, - params, - in_out_types, - block, - env: _, - wrapped: _, + ExpressionNode::BinaryOp { lhs, op: _, rhs } => { + self.resolve_node(lhs); + self.resolve_node(rhs); + } + ExpressionNode::Range { lhs, rhs } => { + self.resolve_node(lhs); + self.resolve_node(rhs); + } + ExpressionNode::List(ref nodes) => { + for node in nodes { + self.resolve_node(*node); + } + } + ExpressionNode::Table { header, ref rows } => { + self.resolve_node(header); + for row in rows { + self.resolve_node(*row); + } + } + ExpressionNode::Record { ref pairs } => { + for (key, val) in pairs { + self.resolve_node(*key); + self.resolve_node(*val); + } + } + ExpressionNode::MemberAccess { target, field } => { + self.resolve_node(target); + self.resolve_node(field); + } + ExpressionNode::If { + condition, + then_block, + else_block, + } => { + self.resolve_node(condition); + self.resolve_node(then_block); + if let Some(block) = else_block { + self.resolve_node(block); + } + } + ExpressionNode::Match { + target, + ref match_arms, } => { + self.resolve_node(target); + for (arm_lhs, arm_rhs) in match_arms { + self.resolve_node(*arm_lhs); + self.resolve_node(*arm_rhs); + } + } + ExpressionNode::Pipeline(pipeline_id) => self.resolve_pipeline(pipeline_id), + AstNode::NamedValue { .. } => (/* seems unused for now */), + } + } + + pub fn resolve_statement(&mut self, stmt_id: StatementNodeId) { + let node = stmt_id.get_node(&self.compiler); + match node { + StatementNode::Def { name, type_params, params, in_out_types, block, env, wrapped } => { // define the command before the block to enable recursive calls self.define_decl(name, node_id); @@ -293,25 +361,10 @@ impl<'a> Resolver<'a> { }; self.resolve_block(block, block_id, Some(def_scope)); - } - AstNode::Alias { - new_name, - old_name: _, - } => { - self.define_decl(new_name, node_id); - } - AstNode::Params(ref params) => { - for param in params { - let AstNode::Param { name, ty } = self.compiler.ast_nodes[param.0] else { - panic!("param is not a param"); - }; - self.define_variable(name, false); - if let Some(ty) = ty { - self.resolve_node(ty); - } - } - } - AstNode::Let { + }, + StatementNode::Alias { new_name, old_name } => {self.define_decl(new_name, node_id); + }, + StatementNode::Let { variable_name, ty, initializer, @@ -323,11 +376,11 @@ impl<'a> Resolver<'a> { self.resolve_node(initializer); self.define_variable(variable_name, is_mutable) } - AstNode::While { condition, block } => { + StatementNode::While { condition, block } => { self.resolve_node(condition); self.resolve_node(block); } - AstNode::For { + StatementNode::For { variable, range, block, @@ -339,66 +392,32 @@ impl<'a> Resolver<'a> { self.resolve_node(range); - let AstNode::Block(block_id) = self.compiler.ast_nodes[block.0] else { + let StatementNode::Block(block_id) = self.compiler.ast_nodes[block.0] else { panic!("internal error: for's body is not a block"); }; self.resolve_block(block, block_id, Some(for_body_scope)); } - AstNode::Loop { block } => { + StatementNode::Loop { block } => { self.resolve_node(block); } - AstNode::BinaryOp { lhs, op: _, rhs } => { - self.resolve_node(lhs); - self.resolve_node(rhs); - } - AstNode::Range { lhs, rhs } => { - self.resolve_node(lhs); - self.resolve_node(rhs); - } - AstNode::List(ref nodes) => { - for node in nodes { - self.resolve_node(*node); - } - } - AstNode::Table { header, ref rows } => { - self.resolve_node(header); - for row in rows { - self.resolve_node(*row); - } - } - AstNode::Record { ref pairs } => { - for (key, val) in pairs { - self.resolve_node(*key); - self.resolve_node(*val); - } - } - AstNode::MemberAccess { target, field } => { - self.resolve_node(target); - self.resolve_node(field); - } - AstNode::If { - condition, - then_block, - else_block, - } => { - self.resolve_node(condition); - self.resolve_node(then_block); - if let Some(block) = else_block { - self.resolve_node(block); - } - } - AstNode::Match { - target, - ref match_arms, - } => { - self.resolve_node(target); - for (arm_lhs, arm_rhs) in match_arms { - self.resolve_node(*arm_lhs); - self.resolve_node(*arm_rhs); + + } + } + pub fn resolve_node(&mut self, node_id: NodeId) { + // TODO: Move node_id param to the end, same as in typechecker + match self.compiler.ast_nodes[node_id.0] { + AstNode::Params(ref params) => { + for param in params { + let AstNode::Param { name, ty } = self.compiler.ast_nodes[param.0] else { + panic!("param is not a param"); + }; + self.define_variable(name, false); + if let Some(ty) = ty { + self.resolve_node(ty); + } } } - AstNode::Statement(node) => self.resolve_node(node), AstNode::Type { name, args, .. } => { self.resolve_type(name); if let Some(args) = args { @@ -429,29 +448,27 @@ impl<'a> Resolver<'a> { self.resolve_node(in_ty); self.resolve_node(out_ty); } - AstNode::Pipeline(pipeline_id) => self.resolve_pipeline(pipeline_id), AstNode::Param { .. } => (/* seems unused for now */), - AstNode::NamedValue { .. } => (/* seems unused for now */), // All remaining matches do not contain NodeId => there is nothing to resolve _ => (), } } pub fn resolve_pipeline(&mut self, pipeline_id: PipelineId) { - let pipeline = &self.compiler.pipelines[pipeline_id.0]; + let pipeline = &self.compiler.pipeline_nodes[pipeline_id.0]; for exp in pipeline.get_expressions() { self.resolve_node(*exp) } } - pub fn resolve_variable(&mut self, unbound_node_id: NodeId) { - let var_name = trim_var_name(self.compiler.get_span_contents(unbound_node_id)); + pub fn resolve_variable(&mut self, unbound_node_id: &VariableNodeId) { + let var_name = trim_var_name(unbound_node_id.get_span_contents(&self.compiler)); if let Some(node_id) = self.find_variable(var_name) { let var_id = self .var_resolution - .get(&node_id) + .get(&NameOrVariable::Variable(*node_id)) .expect("internal error: missing resolved variable"); self.var_resolution.insert(unbound_node_id, *var_id); @@ -533,8 +550,8 @@ impl<'a> Resolver<'a> { ) { let block = self .compiler - .blocks - .get(block_id.0) + .block_nodes + .get(block_id) .expect("internal error: missing block"); if let Some(scope_id) = reused_scope { diff --git a/src/typechecker.rs b/src/typechecker.rs index 9364d81..9dff335 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -1,9 +1,14 @@ //! See typechecking.md in the contributing/ folder for more information on //! how the typechecker works +use crate::ast_nodes::{ + AstNode, BlockId, BlockNode, ExpressionNode, ExpressionNodeId, NameNode, NameNodeId, + NameOrString, NameOrVariable, NodeId, NodeIdGetter, NodeIndexer, NodePusher, PipelineId, + PipelineNode, StatementNode, StatementNodeId, StatementOrExpression, StringNode, StringNodeId, + VariableNode, VariableNodeId, +}; use crate::compiler::Compiler; use crate::errors::{Severity, SourceError}; -use crate::parser::{AstNode, NodeId}; use crate::resolver::{TypeDecl, TypeDeclId}; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; @@ -77,7 +82,14 @@ pub enum Type { pub struct Types { pub types: Vec, + pub name_node_types: Vec, + pub string_node_types: Vec, + pub variable_node_types: Vec, pub node_types: Vec, + pub expression_node_types: Vec, + pub statement_node_types: Vec, + pub block_node_types: Vec, + pub pipeline_node_types: Vec, pub errors: Vec, } @@ -113,10 +125,17 @@ pub struct Typechecker<'a> { types: Vec, /// Types of nodes. Each type in this vector matches a node in compiler.ast_nodes at the same position. + pub name_node_types: Vec, + pub string_node_types: Vec, + pub variable_node_types: Vec, pub node_types: Vec, + pub expression_node_types: Vec, + pub statement_node_types: Vec, + pub block_node_types: Vec, + pub pipeline_node_types: Vec, /// Record fields used for `RecordType`. Each value in this vector matches with the index in RecordTypeId. /// The individual field lists are stored sorted by field name. - pub record_types: Vec>, + pub record_types: Vec>, /// Types used for `OneOf`. Each value in this vector matches with the index in OneOfId. pub oneof_types: Vec>, /// Types used for `AllOf`. Each value in this vector matches with the index in AllOfId. @@ -155,7 +174,14 @@ impl<'a> Typechecker<'a> { Type::Top, Type::Bottom, ], + name_node_types: vec![UNKNOWN_TYPE; compiler.name_nodes.len()], + string_node_types: vec![UNKNOWN_TYPE; compiler.string_nodes.len()], + variable_node_types: vec![UNKNOWN_TYPE; compiler.variable_nodes.len()], node_types: vec![UNKNOWN_TYPE; compiler.ast_nodes.len()], + expression_node_types: vec![UNKNOWN_TYPE; compiler.expression_nodes.len()], + statement_node_types: vec![UNKNOWN_TYPE; compiler.statement_nodes.len()], + block_node_types: vec![UNKNOWN_TYPE; compiler.block_nodes.len()], + pipeline_node_types: vec![UNKNOWN_TYPE; compiler.pipeline_nodes.len()], record_types: Vec::new(), oneof_types: Vec::new(), allof_types: Vec::new(), @@ -175,7 +201,14 @@ impl<'a> Typechecker<'a> { pub fn to_types(self) -> Types { Types { types: self.types, + name_node_types: self.name_node_types, + string_node_types: self.string_node_types, + variable_node_types: self.variable_node_types, node_types: self.node_types, + expression_node_types: self.expression_node_types, + statement_node_types: self.statement_node_types, + block_node_types: self.block_node_types, + pipeline_node_types: self.pipeline_node_types, errors: self.errors, } } @@ -213,11 +246,16 @@ impl<'a> Typechecker<'a> { /// Typecheck AST nodes, starting from the last node pub fn typecheck(&mut self) { - if !self.compiler.ast_nodes.is_empty() { - let last = self.compiler.ast_nodes.len() - 1; - let last_node_id = NodeId(last); - self.typecheck_node(last_node_id); - + if !self.compiler.indexer.is_empty() { + let length = self.compiler.indexer.len(); + let last_indexer = self.compiler.indexer[length - 1]; + match last_indexer { + NodeIndexer::General(node_id) => self.typecheck_node(&node_id), + NodeIndexer::Block(block_id) => { + self.typecheck_block(&block_id, TOP_TYPE); + } + _ => return, + } for i in 0..self.type_vars.len() { let var = &self.type_vars[i]; let bound = var.lower_bound; @@ -240,19 +278,20 @@ impl<'a> Typechecker<'a> { } /// Get type of node - pub fn type_of(&self, node_id: NodeId) -> Type { - let type_id = self.type_id_of(node_id); + pub fn type_of(&self, node_id: &T) -> Type { + let type_id = node_id.type_id_of(self); self.types[type_id.0] } - fn typecheck_node(&mut self, node_id: NodeId) { - match self.compiler.ast_nodes[node_id.0] { + fn typecheck_node(&mut self, node_id: &NodeId) { + let node = node_id.get_node(self.compiler); + match node { AstNode::Params(ref params) => { for param in params { - self.typecheck_node(*param); + self.typecheck_node(param); } // Params are not supposed to be evaluated - self.set_node_type_id(node_id, FORBIDDEN_TYPE); + node_id.set_node_type_id(self, FORBIDDEN_TYPE); } AstNode::Param { name, ty } => { if let Some(ty) = ty { @@ -261,48 +300,46 @@ impl<'a> Typechecker<'a> { let var_id = self .compiler .var_resolution - .get(&name) + .get(&NameOrVariable::Name(*name)) .expect("missing resolved variable"); self.variable_types[var_id.0] = ty_id; - self.set_node_type_id(node_id, ty_id); + node_id.set_node_type_id(self, ty_id); } else { - self.set_node_type_id(node_id, ANY_TYPE); + node_id.set_node_type_id(self, ANY_TYPE); } } AstNode::TypeArgs(ref args) => { for arg in args { - self.typecheck_type(*arg); + self.typecheck_type(arg); } // Type argument lists are not supposed to be evaluated - self.set_node_type_id(node_id, FORBIDDEN_TYPE); - } - AstNode::Block(_) => { - self.typecheck_block(node_id, TOP_TYPE); + node_id.set_node_type_id(self, FORBIDDEN_TYPE); } + // NOTE: what about AstNode::TypeParams? _ => self.error( format!( "unsupported/unexpected ast node '{:?}' in typechecker", - self.compiler.ast_nodes[node_id.0] + node ), - node_id, + node_id.into_indexer(), ), } } - fn typecheck_block(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { - let AstNode::Block(block_id) = self.compiler.ast_nodes[node_id.0] else { - panic!( - "Expected block to typecheck, got '{:?}'", - self.compiler.ast_nodes[node_id.0] - ); - }; - let block = &self.compiler.blocks[block_id.0]; + fn typecheck_block(&mut self, node_id: &BlockId, expected: TypeId) -> TypeId { + let block = node_id.get_node(self.compiler); for (i, inner_node_id) in block.nodes.iter().enumerate() { - if i == block.nodes.len() - 1 && self.is_expr(*inner_node_id) { - self.typecheck_expr(*inner_node_id, expected); + let expected_type = if i == block.nodes.len() - 1 { + expected } else { - self.typecheck_stmt(*inner_node_id); + TOP_TYPE + }; + match inner_node_id { + StatementOrExpression::Statement(stmt_id) => self.typecheck_stmt(*stmt_id), + StatementOrExpression::Expression(expr_id) => { + self.typecheck_expr(expr_id, expected_type); + } } } @@ -311,30 +348,31 @@ impl<'a> Typechecker<'a> { let block_type = block .nodes .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)); - self.set_node_type_id(node_id, block_type); + .map_or(NONE_TYPE, |node_id| node_id.type_id_of(self)); + node_id.set_node_type_id(self, block_type); block_type } - fn typecheck_stmt(&mut self, node_id: NodeId) { - match self.compiler.ast_nodes[node_id.0] { - AstNode::Let { + fn typecheck_stmt(&mut self, node_id: StatementNodeId) { + let node = node_id.get_node(self.compiler); + match node { + StatementNode::Let { variable_name, ty, initializer, is_mutable: _, } => self.typecheck_let(variable_name, ty, initializer, node_id), - AstNode::Def { + StatementNode::Def { name, params, in_out_types, block, .. } => self.typecheck_def(name, params, in_out_types, block, node_id), - AstNode::Alias { new_name, old_name } => { + StatementNode::Alias { new_name, old_name } => { self.typecheck_alias(new_name, old_name, node_id) } - AstNode::For { + StatementNode::For { variable, range, block, @@ -345,72 +383,70 @@ impl<'a> Typechecker<'a> { let var_id = self .compiler .var_resolution - .get(&variable) + .get(&NameOrVariable::Variable(*variable)) .expect("missing resolved variable"); if let Type::List(type_id) = self.type_of(range) { self.variable_types[var_id.0] = type_id; - self.set_node_type_id(variable, type_id); + variable.set_node_type_id(self, type_id); } else { self.variable_types[var_id.0] = ANY_TYPE; - self.set_node_type_id(variable, ERROR_TYPE); - self.error("For loop range is not a list", range); + variable.set_node_type_id(self, ERROR_TYPE); + self.error("For loop range is not a list", range.into_indexer()); } - self.typecheck_node(block); - if self.type_id_of(block) != NONE_TYPE { - self.error("Blocks in looping constructs cannot return values", block); + self.typecheck_block(block, TOP_TYPE); + if block.type_id_of(self) != NONE_TYPE { + self.error( + "Blocks in looping constructs cannot return values", + block.into_indexer(), + ); } - if self.type_id_of(node_id) != ERROR_TYPE { - self.set_node_type_id(node_id, NONE_TYPE); + if node_id.type_id_of(self) != ERROR_TYPE { + node_id.set_node_type_id(self, NONE_TYPE); } } - AstNode::While { condition, block } => { + StatementNode::While { condition, block } => { self.typecheck_expr(condition, BOOL_TYPE); - self.typecheck_node(block); - self.set_node_type_id(node_id, NONE_TYPE); + self.typecheck_block(block, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); } - AstNode::Loop { block } => { - self.typecheck_node(block); - self.set_node_type_id(node_id, NONE_TYPE); + StatementNode::Loop { block } => { + self.typecheck_block(block, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); } - AstNode::Break | AstNode::Continue => { + StatementNode::Break | StatementNode::Continue => { // TODO make sure we're in a loop - self.set_node_type_id(node_id, NONE_TYPE); - } - _ if self.is_expr(node_id) => { - self.typecheck_expr(node_id, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); } _ => self.error( - format!( - "Expected statement to typecheck, got '{:?}'", - self.compiler.ast_nodes[node_id.0] - ), - node_id, + format!("unsupported statement node '{:?}' in typechecker", node), + node_id.into_indexer(), ), } } - fn typecheck_expr(&mut self, node_id: NodeId, expected: TypeId) -> TypeId { - let ty_id = match self.compiler.ast_nodes[node_id.0] { - AstNode::Null => NOTHING_TYPE, - AstNode::Int => INT_TYPE, - AstNode::Float => FLOAT_TYPE, - AstNode::True | AstNode::False => BOOL_TYPE, - AstNode::String => STRING_TYPE, - AstNode::List(ref items) => { + fn typecheck_expr(&mut self, node_id: &ExpressionNodeId, expected: TypeId) -> TypeId { + let node = node_id.get_node(self.compiler); + let ty_id = match node { + ExpressionNode::Null => NOTHING_TYPE, + ExpressionNode::Int => INT_TYPE, + ExpressionNode::Float => FLOAT_TYPE, + ExpressionNode::True | ExpressionNode::False => BOOL_TYPE, + ExpressionNode::String(_) => STRING_TYPE, + ExpressionNode::List(ref items) => { // TODO infer a union type instead if let Some(first_id) = items.first() { let expected_elem = self.extract_elem_type(expected); - self.typecheck_expr(*first_id, expected_elem.unwrap_or(TOP_TYPE)); - let first_type = self.type_of(*first_id); + self.typecheck_expr(first_id, expected_elem.unwrap_or(TOP_TYPE)); + let first_type = self.type_of(first_id); let mut all_numbers = self.is_type_compatible(first_type, Type::Number); let mut all_same = true; for item_id in items.iter().skip(1) { - self.typecheck_expr(*item_id, TOP_TYPE); - let item_type = self.type_of(*item_id); + self.typecheck_expr(item_id, TOP_TYPE); + let item_type = self.type_of(item_id); if all_numbers && !self.is_type_compatible(item_type, Type::Number) { all_numbers = false; @@ -422,7 +458,7 @@ impl<'a> Typechecker<'a> { } if all_same { - self.push_type(Type::List(self.type_id_of(*first_id))) + self.push_type(Type::List(first_id.type_id_of(self))) } else if all_numbers { self.push_type(Type::List(NUMBER_TYPE)) } else { @@ -432,50 +468,52 @@ impl<'a> Typechecker<'a> { LIST_ANY_TYPE } } - AstNode::Record { ref pairs } => { + ExpressionNode::Record { ref pairs } => { // TODO take expected type into account let mut field_types = pairs .iter() - .map(|(name, value)| (*name, self.typecheck_expr(*value, TOP_TYPE))) + .map(|(name, value)| (*name, self.typecheck_expr(value, TOP_TYPE))) .collect::>(); - field_types.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); + field_types.sort_by_cached_key(|(name, _)| { + self.compiler.get_span_contents(name.into_indexer()) + }); self.record_types.push(field_types); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) } - AstNode::Pipeline(pipeline_id) => { - let pipeline = &self.compiler.pipelines[pipeline_id.0]; + ExpressionNode::Pipeline(pipeline_id) => { + let pipeline = self.compiler.pipeline_nodes.get_node(pipeline_id.0); let expressions = pipeline.get_expressions(); for inner in expressions { - self.typecheck_expr(*inner, TOP_TYPE); + self.typecheck_expr(inner, TOP_TYPE); } // pipeline type is the type of the last expression, since blocks // by themselves aren't supposed to be typed expressions .last() - .map_or(NONE_TYPE, |node_id| self.type_id_of(*node_id)) + .map_or(NONE_TYPE, |node_id| node_id.type_id_of(self)) } - AstNode::Closure { params, block } => { + ExpressionNode::Closure { params, block } => { // TODO: input/output types if let Some(params_node_id) = params { self.typecheck_node(params_node_id); } - self.typecheck_node(block); + self.typecheck_block(block, expected); CLOSURE_TYPE } - AstNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs), - AstNode::Variable => { + ExpressionNode::BinaryOp { lhs, op, rhs } => self.typecheck_binary_op(lhs, op, rhs), + ExpressionNode::Variable(variable_node_id) => { let var_id = self .compiler .var_resolution - .get(&node_id) + .get(&NameOrVariable::Variable(*variable_node_id)) .expect("missing resolved variable"); self.variable_types[var_id.0] } - AstNode::If { + ExpressionNode::If { condition, then_block, else_block, @@ -485,12 +523,18 @@ impl<'a> Typechecker<'a> { let then_type_id = self.typecheck_block(then_block, expected); if let Some(else_blk) = else_block { - let else_type_id = - if let AstNode::Block(_) = self.compiler.ast_nodes[else_blk.0] { - self.typecheck_block(else_blk, expected) - } else { - self.typecheck_expr(else_blk, expected) - }; + let else_type_id = match else_blk { + NodeIndexer::Expression(else_expr_id) => { + self.typecheck_expr(else_expr_id, expected) + } + NodeIndexer::Block(else_block_id) => { + self.typecheck_block(else_block_id, expected) + } + _ => { + self.error("Else block of an if expression must be either a block or an expression", *else_blk); + return ERROR_TYPE; + } + }; let mut types = HashSet::new(); types.insert(then_type_id); types.insert(else_type_id); @@ -500,8 +544,11 @@ impl<'a> Typechecker<'a> { NONE_TYPE } } - AstNode::Call { ref parts } => self.typecheck_call(parts, node_id), - AstNode::Match { + ExpressionNode::Call { head, parts } => { + // need to make sure that the node_id is either Name or String. + self.typecheck_call(head, parts, node_id) + } + ExpressionNode::Match { ref target, ref match_arms, } => { @@ -515,16 +562,13 @@ impl<'a> Typechecker<'a> { } _ => { self.error( - format!( - "Expected an expression to typecheck, got '{:?}'", - self.compiler.ast_nodes[node_id.0] - ), - node_id, + format!("Expected an expression to typecheck, got '{:?}'", node), + node_id.into_indexer(), ); ERROR_TYPE } }; - self.set_node_type_id(node_id, ty_id); + node_id.set_node_type_id(self, ty_id); if !self.constrain_subtype(ty_id, expected) { self.error( @@ -533,67 +577,45 @@ impl<'a> Typechecker<'a> { self.type_to_string(expected), self.type_to_string(ty_id) ), - node_id, + node_id.into_indexer(), ); } ty_id } - fn is_expr(&mut self, node_id: NodeId) -> bool { - matches!( - self.compiler.ast_nodes[node_id.0], - AstNode::Null - | AstNode::Int - | AstNode::Float - | AstNode::True - | AstNode::False - | AstNode::String - | AstNode::Variable - | AstNode::List(_) - | AstNode::Record { .. } - | AstNode::Table { .. } - | AstNode::Pipeline(_) - | AstNode::Closure { .. } - | AstNode::BinaryOp { .. } - | AstNode::If { .. } - | AstNode::Call { .. } - | AstNode::Match { .. } - ) - } - fn typecheck_match( &mut self, - target: &NodeId, - match_arms: &Vec<(NodeId, NodeId)>, + target: &ExpressionNodeId, + match_arms: &Vec<(ExpressionNodeId, ExpressionNodeId)>, expected: TypeId, ) -> HashSet { - self.typecheck_expr(*target, TOP_TYPE); + self.typecheck_expr(target, TOP_TYPE); let mut output_types = HashSet::new(); // typecheck each node - let target_id = self.type_id_of(*target); + let target_id = target.type_id_of(self); for (match_node, result_node) in match_arms { - self.typecheck_node(*match_node); - self.typecheck_expr(*result_node, expected); + self.typecheck_expr(match_node, expected); + self.typecheck_expr(result_node, expected); - let match_id = self.type_id_of(*match_node); - match (self.type_of(*target), self.type_of(*match_node)) { + let match_id = match_node.type_id_of(self); + match (self.type_of(target), self.type_of(match_node)) { // First is of type Any which will always match (Type::Any, _) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // Same as above but for second (_, Type::Any) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // the second is one of the possible types of the first (Type::OneOf(id), _) if self.oneof_types[id.0].contains(&match_id) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // the first is one of the possible types of the second (_, Type::OneOf(id)) if self.oneof_types[id.0].contains(&target_id) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } // the both the target and the one matched against are // oneof then we need to check if they have any type in common @@ -603,28 +625,32 @@ impl<'a> Typechecker<'a> { .count() != 0 { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } else { - self.error("The target to be matched against and the possible types of the matched arm are completely disjoint", *match_node); + self.error("The target to be matched against and the possible types of the matched arm are completely disjoint", NodeIndexer::Expression(*match_node)); } } // Check if the two types can be matched (target_id, match_id) if self.is_type_compatible(target_id, match_id) => { - self.add_resolved_types(&mut output_types, &self.type_id_of(*result_node)); - } - _ => { - self.error("The types do not match", *match_node); + self.add_resolved_types(&mut output_types, &result_node.type_id_of(self)); } + _ => self.error("The types do not match", match_node.into_indexer()), } } output_types } - fn typecheck_binary_op(&mut self, lhs: NodeId, op: NodeId, rhs: NodeId) -> TypeId { - self.set_node_type_id(op, FORBIDDEN_TYPE); + fn typecheck_binary_op( + &mut self, + lhs: &ExpressionNodeId, + op: &NodeId, + rhs: &ExpressionNodeId, + ) -> TypeId { + op.set_node_type_id(self, FORBIDDEN_TYPE); // TODO: better error messages for type mismatches, the previous messages were better - match self.compiler.ast_nodes[op.0] { + let node = op.get_node(self.compiler); + match node { AstNode::Equal | AstNode::NotEqual => { let lhs_ty = self.typecheck_expr(lhs, TOP_TYPE); let rhs_ty = self.typecheck_expr(rhs, TOP_TYPE); @@ -714,7 +740,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(lhs_ty, STRING_TYPE) { self.error( format!("Expected string, got {}", self.type_to_string(lhs_ty)), - lhs, + lhs.into_indexer(), ); } STRING_TYPE @@ -722,7 +748,7 @@ impl<'a> Typechecker<'a> { if !self.constrain_subtype(lhs_ty, NUMBER_TYPE) { self.error( format!("Expected number, got {}", self.type_to_string(lhs_ty)), - lhs, + lhs.into_indexer(), ); } self.numeric_op_type(lhs_ty, rhs_ty) @@ -779,11 +805,11 @@ impl<'a> Typechecker<'a> { fn typecheck_def( &mut self, - name: NodeId, - params: NodeId, - in_out_types: Option, - block: NodeId, - node_id: NodeId, + name: &NameOrString, + params: &NodeId, + in_out_types: &Option, + block: &BlockId, + node_id: StatementNodeId, ) { let in_out_types = in_out_types .map(|ty| { @@ -797,8 +823,8 @@ impl<'a> Typechecker<'a> { panic!("internal error: return type is not a return type"); }; InOutType { - in_type: self.typecheck_type(*in_ty), - out_type: self.typecheck_type(*out_ty), + in_type: self.typecheck_type(in_ty), + out_type: self.typecheck_type(out_ty), } }) .collect::>() @@ -806,20 +832,20 @@ impl<'a> Typechecker<'a> { .unwrap_or_default(); self.typecheck_node(params); - self.typecheck_node(block); - self.set_node_type_id(node_id, NONE_TYPE); + self.typecheck_block(block, TOP_TYPE); + node_id.set_node_type_id(self, NONE_TYPE); // set input/output types for the command let decl_id = self .compiler .decl_resolution - .get(&name) + .get(&name.into_indexer()) .expect("missing declared decl"); if in_out_types.is_empty() { self.decl_types[decl_id.0] = vec![InOutType { in_type: ANY_TYPE, - out_type: self.type_id_of(block), + out_type: block.type_id_of(self), }]; } else { // TODO check that block output type matches expected type @@ -827,17 +853,22 @@ impl<'a> Typechecker<'a> { } } - fn typecheck_alias(&mut self, new_name: NodeId, old_name: NodeId, node_id: NodeId) { - self.set_node_type_id(node_id, NONE_TYPE); + fn typecheck_alias( + &mut self, + new_name: &NameOrString, + old_name: &NameOrString, + node_id: StatementNodeId, + ) { + node_id.set_node_type_id(self, NONE_TYPE); // set input/output types for the command let decl_id_new = self .compiler .decl_resolution - .get(&new_name) + .get(&new_name.into_indexer()) .expect("missing declared new name for alias"); - let decl_id_old = self.compiler.decl_resolution.get(&old_name); + let decl_id_old = self.compiler.decl_resolution.get(&old_name.into_indexer()); self.decl_types[decl_id_new.0] = decl_id_old.map_or( vec![InOutType { @@ -848,29 +879,39 @@ impl<'a> Typechecker<'a> { ); } - fn typecheck_call(&mut self, parts: &[NodeId], node_id: NodeId) -> TypeId { - if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id) { - let num_name_parts = self.compiler.decls[decl_id.0].name().split(' ').count(); + // TODO: something strange inside this function. + // The type of `self.compiler.deco_resolution` is unclear. + fn typecheck_call( + &mut self, + head: &[NameNodeId], + parts: &[ExpressionNodeId], + node_id: &ExpressionNodeId, + ) -> TypeId { + if let Some(decl_id) = self.compiler.decl_resolution.get(&node_id.into_indexer()) { let decl_node_id = self.compiler.decl_nodes[decl_id.0]; - let AstNode::Def { + let StatementNode::Def { type_params, params, .. - } = self.compiler.get_node(decl_node_id) + } = decl_node_id.get_node(&self.compiler) else { panic!("Internal error: Expected def") }; - let AstNode::Params(params) = self.compiler.get_node(*params) else { + let AstNode::Params(params) = params.get_node(self.compiler) else { panic!("Internal error: Expected params") }; let type_substs = if let Some(type_params) = type_params { - let AstNode::Params(type_params) = self.compiler.get_node(*type_params) else { + let AstNode::TypeParams(type_params) = type_params.get_node(self.compiler) else { panic!("Internal error: expected type params"); }; let mut type_substs = HashMap::new(); for type_param in type_params.iter() { - let type_decl_id = self.compiler.type_resolution[type_param]; + let type_decl_id = *self + .compiler + .type_resolution + .get(&NameOrVariable::Name(*type_param)) + .expect("should already resolved in resolver"); let var = self.new_typevar(BOTTOM_TYPE, TOP_TYPE); type_substs.insert(type_decl_id, var); } @@ -879,35 +920,35 @@ impl<'a> Typechecker<'a> { HashMap::new() }; - let num_args = parts.len() - num_name_parts; + let num_args = parts.len() - head.len(); if params.len() != num_args { self.error( format!("Expected {} argument(s), got {}", params.len(), num_args), - node_id, + node_id.into_indexer(), ); } - for (param, arg) in params.iter().zip(&parts[num_name_parts..]) { + for (param, arg) in params.iter().zip(parts) { let expected = self.type_id_of(*param); let expected = self.subst(expected, &type_substs); - if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { - self.set_node_type_id(*arg, STRING_TYPE); + if matches!(arg.get_node(&self.compiler), ExpressionNode::Name(_)) { + arg.set_node_type_id(self, STRING_TYPE); if !self.constrain_subtype(STRING_TYPE, expected) { self.error( format!("Expected {}, got string", self.type_to_string(expected)), - *arg, + arg.into_indexer(), ); } } else { - self.typecheck_expr(*arg, expected); + self.typecheck_expr(arg, expected); } } if num_args > params.len() { // Typecheck extra arguments too - for arg in &parts[num_name_parts + params.len()..] { - if matches!(self.compiler.ast_nodes[arg.0], AstNode::Name) { - self.set_node_type_id(*arg, STRING_TYPE); + for arg in &parts[params.len()..] { + if matches!(arg.get_node(&self.compiler), ExpressionNode::Name(_)) { + arg.set_node_type_id(self, STRING_TYPE); } else { - self.typecheck_expr(*arg, TOP_TYPE); + self.typecheck_expr(arg, TOP_TYPE); } } } @@ -921,12 +962,11 @@ impl<'a> Typechecker<'a> { self.create_oneof(out_types) } else { // external call - for part in &parts[1..] { - if matches!(self.compiler.ast_nodes[part.0], AstNode::Name) { - self.set_node_type_id(*part, STRING_TYPE); - } else { - self.typecheck_expr(*part, TOP_TYPE); - } + for h in head { + h.set_node_type_id(self, STRING_TYPE); + } + for part in parts { + self.typecheck_expr(part, TOP_TYPE); } BYTE_STREAM_TYPE @@ -935,10 +975,10 @@ impl<'a> Typechecker<'a> { fn typecheck_let( &mut self, - variable_name: NodeId, - ty: Option, - initializer: NodeId, - node_id: NodeId, + variable_name: &VariableNodeId, + ty: &Option, + initializer: &ExpressionNodeId, + node_id: StatementNodeId, ) { let type_id = if let Some(ty) = ty { let ty_id = self.typecheck_type(ty); @@ -951,16 +991,16 @@ impl<'a> Typechecker<'a> { let var_id = self .compiler .var_resolution - .get(&variable_name) + .get(&NameOrVariable::Variable(*variable_name)) .expect("missing declared variable"); self.variable_types[var_id.0] = type_id; - self.set_node_type_id(variable_name, type_id); - self.set_node_type_id(node_id, NONE_TYPE); + variable_name.set_node_type_id(self, type_id); + node_id.set_node_type_id(self, NONE_TYPE); } - fn typecheck_type(&mut self, node_id: NodeId) -> TypeId { - let ty_id = match self.compiler.ast_nodes[node_id.0] { + fn typecheck_type(&mut self, node_id: &NodeId) -> TypeId { + let ty_id = match node_id.get_node(&self.compiler) { AstNode::Type { name, args, @@ -970,27 +1010,34 @@ impl<'a> Typechecker<'a> { fields, optional: _, // TODO handle optional record types } => { - let AstNode::Params(field_nodes) = self.compiler.get_node(fields) else { + let AstNode::Params(field_nodes) = fields.get_node(&self.compiler) else { panic!("internal error: record fields aren't Params"); }; let mut fields = field_nodes .iter() .map(|field| { - let AstNode::Param { name, ty } = self.compiler.get_node(*field) else { + let AstNode::Param { name, ty } = field.get_node(&self.compiler) else { panic!("internal error: record field isn't Param"); }; let ty_id = match ty { Some(ty) => { - self.typecheck_type(*ty); + self.typecheck_type(ty); self.type_id_of(*ty) } None => ANY_TYPE, }; - (*name, ty_id) + // NOTE: a bad way to convert from NameNodeId to ExpressionNodeId + let expr_node_id = self + .compiler + .expression_nodes + .iter_nodes() + .position(|expr_node| *expr_node == ExpressionNode::Name(*name)) + .expect("the Expression::Name should exist"); + (ExpressionNodeId(expr_node_id), ty_id) }) .collect::>(); // Store fields sorted by name - fields.sort_by_cached_key(|(name, _)| self.compiler.get_span_contents(*name)); + fields.sort_by_cached_key(|(name, _)| name.get_span_contents(&self.compiler)); self.record_types.push(fields); self.push_type(Type::Record(RecordTypeId(self.record_types.len() - 1))) @@ -999,24 +1046,24 @@ impl<'a> Typechecker<'a> { self.error( format!( "Internal error: expected type, got '{:?}'", - self.compiler.ast_nodes[node_id.0] + node_id.get_node(&self.compiler) ), - node_id, + node_id.into_indexer(), ); ERROR_TYPE } }; - self.set_node_type_id(node_id, ty_id); + node_id.set_node_type_id(self, ty_id); ty_id } fn typecheck_type_ref( &mut self, - name_id: NodeId, - args_id: Option, - _optional: bool, + name_id: &NameNodeId, + args_id: &Option, + _optional: &bool, ) -> TypeId { - let name = self.compiler.get_span_contents(name_id); + let name = name_id.get_span_contents(&self.compiler); // taken from parse_shape_name() in Nushell: match name { @@ -1027,14 +1074,14 @@ impl<'a> Typechecker<'a> { if let Some(args_id) = args_id { self.typecheck_node(args_id); - if let AstNode::TypeArgs(args) = self.compiler.get_node(args_id) { + if let AstNode::TypeArgs(args) = args_id.get_node(&self.compiler) { if args.len() > 1 { let types = - String::from_utf8_lossy(self.compiler.get_span_contents(args_id)); - self.error(format!("list must have only one type argument (to allow selection of types, use oneof{} -- WIP)", types), args_id); + String::from_utf8_lossy(args_id.get_span_contents(&self.compiler)); + self.error(format!("list must have only one type argument (to allow selection of types, use oneof{} -- WIP)", types), args_id.into_indexer()); self.push_type(Type::List(UNKNOWN_TYPE)) } else if args.is_empty() { - self.error("list must have one type argument", args_id); + self.error("list must have one type argument", args_id.into_indexer()); self.push_type(Type::List(UNKNOWN_TYPE)) } else { let args_ty_id = self.type_id_of(args[0]); @@ -1074,7 +1121,11 @@ impl<'a> Typechecker<'a> { // if bytes.contains(&b'@') { // // type with completion // } else { - if let Some(type_decl) = self.compiler.type_resolution.get(&name_id) { + if let Some(type_decl) = self + .compiler + .type_resolution + .get(&NameOrVariable::Name(*name_id)) + { self.push_type(Type::Ref(*type_decl)) } else { UNKNOWN_TYPE @@ -1232,8 +1283,8 @@ impl<'a> Typechecker<'a> { while i < sub_fields.len() && j < supe_fields.len() { let (sub_name, sub_ty) = sub_fields[i]; let (supe_name, supe_ty) = supe_fields[j]; - let sub_text = self.compiler.get_span_contents(sub_name); - let supe_text = self.compiler.get_span_contents(supe_name); + let sub_text = sub_name.get_span_contents(self.compiler); + let supe_text = supe_name.get_span_contents(self.compiler); match sub_text.cmp(supe_text) { Ordering::Less => { i += 1; @@ -1334,8 +1385,8 @@ impl<'a> Typechecker<'a> { while i < sub_fields.len() && j < supe_fields.len() { let (sub_name, sub_ty) = sub_fields[i]; let (supe_name, supe_ty) = supe_fields[j]; - let sub_text = self.compiler.get_span_contents(sub_name); - let supe_text = self.compiler.get_span_contents(supe_name); + let sub_text = sub_name.get_span_contents(self.compiler); + let supe_text = supe_name.get_span_contents(self.compiler); match sub_text.cmp(supe_text) { Ordering::Less => { i += 1; @@ -1497,7 +1548,7 @@ impl<'a> Typechecker<'a> { let mut fmt = "record<".to_string(); let types = &self.record_types[id.0]; for (name, ty) in types { - fmt += &String::from_utf8_lossy(self.compiler.get_span_contents(*name)); + fmt += &String::from_utf8_lossy(name.get_span_contents(&self.compiler)); fmt += ": "; fmt += &self.type_to_string(*ty); fmt += ", "; @@ -1578,8 +1629,8 @@ impl<'a> Typechecker<'a> { while l < lhs_fields.len() && r < rhs_fields.len() { let (lhs_name, lhs_ty) = lhs_fields[l]; let (rhs_name, rhs_ty) = rhs_fields[r]; - let lhs_text = self.compiler.get_span_contents(lhs_name); - let rhs_text = self.compiler.get_span_contents(rhs_name); + let lhs_text = lhs_name.get_span_contents(&self.compiler); + let rhs_text = rhs_name.get_span_contents(&self.compiler); match lhs_text.cmp(rhs_text) { Ordering::Less => { l += 1; @@ -1603,7 +1654,7 @@ impl<'a> Typechecker<'a> { } } - fn error(&mut self, msg: impl Into, node_id: NodeId) { + fn error(&mut self, msg: impl Into, node_id: NodeIndexer) { self.errors.push(SourceError { message: msg.into(), node_id, @@ -1611,17 +1662,23 @@ impl<'a> Typechecker<'a> { }) } - fn binary_op_err(&mut self, op_msg: &str, lhs: NodeId, op: NodeId, rhs: NodeId) { + fn binary_op_err( + &mut self, + op_msg: &str, + lhs: &ExpressionNodeId, + op: &NodeId, + rhs: &ExpressionNodeId, + ) { self.error( format!( "type mismatch: unsupported {} between {} and {}", op_msg, - self.type_to_string(self.type_id_of(lhs)), - self.type_to_string(self.type_id_of(rhs)), + self.type_to_string(lhs.type_id_of(self)), + self.type_to_string(rhs.type_id_of(self)), ), - op, + op.into_indexer(), ); - self.set_node_type_id(op, ERROR_TYPE); + op.set_node_type_id(self, ERROR_TYPE); } fn add_resolved_types(&mut self, types: &mut HashSet, ty: &TypeId) { @@ -1655,7 +1712,7 @@ impl<'a> Typechecker<'a> { let mut simple_types = HashSet::::new(); let mut list_elems = HashSet::new(); - let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + let mut record_fields = HashMap::<&[u8], (ExpressionNodeId, HashSet)>::new(); for ty_id in flattened { if simple_types.contains(&ty_id) { continue; @@ -1689,7 +1746,7 @@ impl<'a> Typechecker<'a> { Type::Record(rec_ty_id) => { let new_fields = &self.record_types[rec_ty_id.0]; for (name_node, ty) in new_fields.iter() { - let name = self.compiler.get_span_contents(*name_node); + let name = name_node.get_span_contents(&self.compiler); if let Some((_, types)) = record_fields.get_mut(&name) { types.insert(*ty); } else { @@ -1732,7 +1789,7 @@ impl<'a> Typechecker<'a> { for (_, (node, types)) in record_fields.into_iter() { fields.push((node, self.create_oneof(types))); } - fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + fields.sort_by_cached_key(|(name_node, _)| name_node.get_span_contents(&self.compiler)); let rec_ty_id = RecordTypeId(self.record_types.len()); self.record_types.push(fields); @@ -1773,7 +1830,7 @@ impl<'a> Typechecker<'a> { let mut refs = HashMap::::new(); let mut simple_type: Option = None; let mut list_elems = HashSet::new(); - let mut record_fields = HashMap::<&[u8], (NodeId, HashSet)>::new(); + let mut record_fields = HashMap::<&[u8], (ExpressionNodeId, HashSet)>::new(); let mut oneof_ids = Vec::new(); for ty_id in flattened { let ty = self.types[ty_id.0]; @@ -1801,7 +1858,7 @@ impl<'a> Typechecker<'a> { } let new_fields = &self.record_types[rec_ty_id.0]; for (name_node, ty) in new_fields.iter() { - let name = self.compiler.get_span_contents(*name_node); + let name = name_node.get_span_contents(&self.compiler); if let Some((_, types)) = record_fields.get_mut(&name) { types.insert(*ty); } else { @@ -1848,7 +1905,7 @@ impl<'a> Typechecker<'a> { for (_, (node, types)) in record_fields.into_iter() { fields.push((node, self.create_oneof(types))); } - fields.sort_by_cached_key(|(name_node, _)| self.compiler.get_span_contents(*name_node)); + fields.sort_by_cached_key(|(name_node, _)| name_node.get_span_contents(&self.compiler)); let rec_ty_id = RecordTypeId(self.record_types.len()); self.record_types.push(fields); @@ -1892,3 +1949,108 @@ impl<'a> Typechecker<'a> { self.create_oneof(inters) } } + +trait NodeTypeSetter { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId); + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId; +} + +impl NodeTypeSetter for NameNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.name_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.name_node_types[self.0] + } +} + +impl NodeTypeSetter for StringNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.string_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.string_node_types[self.0] + } +} + +impl NodeTypeSetter for VariableNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.variable_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.variable_node_types[self.0] + } +} + +impl NodeTypeSetter for NodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.node_types[self.0] + } +} + +impl NodeTypeSetter for ExpressionNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.expression_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.expression_node_types[self.0] + } +} + +impl NodeTypeSetter for StatementNodeId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.statement_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.statement_node_types[self.0] + } +} + +impl NodeTypeSetter for BlockId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.block_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.block_node_types[self.0] + } +} + +impl NodeTypeSetter for PipelineId { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + typechecker.pipeline_node_types[self.0] = type_id; + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + typechecker.pipeline_node_types[self.0] + } +} + +impl NodeTypeSetter for StatementOrExpression { + fn set_node_type_id(&self, typechecker: &mut Typechecker, type_id: TypeId) { + match self { + StatementOrExpression::Statement(stmt_id) => { + stmt_id.set_node_type_id(typechecker, type_id) + } + StatementOrExpression::Expression(expr_id) => { + expr_id.set_node_type_id(typechecker, type_id) + } + } + } + + fn type_id_of(&self, typechecker: &Typechecker) -> TypeId { + match self { + StatementOrExpression::Statement(stmt_id) => stmt_id.type_id_of(typechecker), + StatementOrExpression::Expression(expr_id) => expr_id.type_id_of(typechecker), + } + } +}