From cd405f4153135565e58181ab49439b8122f9c40e Mon Sep 17 00:00:00 2001 From: dank-openai Date: Tue, 23 Jun 2026 15:56:50 -0400 Subject: [PATCH] dslx: parse semantic sum declarations and constructors Adds frontend parsing, formatting, visitors, and diagnostics for semantic sum syntax while preserving typed empty numeric enum parsing until semantic-sum runtime/type support lands higher in the stack. branch:dank/upstream/semantic-sum-frontend --- xls/dslx/bytecode/bytecode_emitter.cc | 11 + xls/dslx/bytecode/bytecode_emitter.h | 1 + xls/dslx/constexpr_evaluator.cc | 4 + xls/dslx/constexpr_evaluator.h | 1 + xls/dslx/cpp_transpiler/cpp_type_generator.cc | 5 + xls/dslx/dslx_fmt_test.py | 19 + .../match_exhaustiveness_checker.cc | 9 + xls/dslx/fmt/ast_fmt.cc | 371 ++++++++++++ xls/dslx/fmt/ast_fmt.h | 14 + xls/dslx/fmt/ast_fmt_test.cc | 144 ++++- xls/dslx/frontend/ast.cc | 307 ++++++++++ xls/dslx/frontend/ast.h | 296 +++++++++- xls/dslx/frontend/ast_cloner.cc | 117 ++++ xls/dslx/frontend/ast_cloner_test.cc | 37 ++ xls/dslx/frontend/ast_node.h | 4 + xls/dslx/frontend/ast_utils.cc | 4 + xls/dslx/frontend/bindings.cc | 1 + xls/dslx/frontend/bindings.h | 5 +- xls/dslx/frontend/module.cc | 11 + xls/dslx/frontend/module.h | 4 +- xls/dslx/frontend/parser.cc | 532 ++++++++++++++++-- xls/dslx/frontend/parser.h | 23 +- xls/dslx/frontend/parser_test.cc | 306 +++++++++- xls/dslx/frontend/scanner_test.cc | 6 + xls/dslx/frontend/semantics_analysis.h | 2 + .../ir_convert/extract_conversion_order.cc | 11 + xls/dslx/ir_convert/function_converter.cc | 12 + xls/dslx/lsp/document_symbols.cc | 10 + xls/dslx/tests/errors/error_modules_test.py | 4 +- xls/dslx/translators/dslx_to_verilog.cc | 9 + xls/dslx/translators/dslx_to_verilog_main.cc | 1 + xls/dslx/type_system/type_info.cc | 4 + xls/dslx/type_system/type_info.h | 3 +- xls/dslx/type_system/type_info.proto | 4 + xls/dslx/type_system/type_info_to_proto.cc | 16 + .../typecheck_module_v2_enum_test.cc | 21 +- .../type_system_v2/validate_concrete_type.cc | 27 +- xls/fuzzer/value_generator.cc | 4 + 38 files changed, 2265 insertions(+), 95 deletions(-) diff --git a/xls/dslx/bytecode/bytecode_emitter.cc b/xls/dslx/bytecode/bytecode_emitter.cc index f3b8498230..f7181b1bad 100644 --- a/xls/dslx/bytecode/bytecode_emitter.cc +++ b/xls/dslx/bytecode/bytecode_emitter.cc @@ -1300,6 +1300,12 @@ absl::StatusOr BytecodeEmitter::HandleNameDefTreeExpr( [&](WildcardPattern* n) -> absl::StatusOr { return Bytecode::MatchArmItem::MakeWildcard(); }, + [&](SumVariantPayloadPattern* /*n*/) + -> absl::StatusOr { + return absl::UnimplementedError( + "Semantic sum patterns are not supported by the bytecode " + "runtime."); + }, [&](RestOfTuple* n) -> absl::StatusOr { return Bytecode::MatchArmItem::MakeRestOfTuple(); }, @@ -1676,6 +1682,11 @@ absl::Status BytecodeEmitter::HandleStructInstance(const StructInstance* node) { return absl::OkStatus(); } +absl::Status BytecodeEmitter::HandleSumInstance(const SumInstance*) { + return absl::UnimplementedError( + "Semantic sum execution is not supported by the bytecode runtime."); +} + absl::Status BytecodeEmitter::HandleSplatStructInstance( const SplatStructInstance* node) { // For each field in the struct: diff --git a/xls/dslx/bytecode/bytecode_emitter.h b/xls/dslx/bytecode/bytecode_emitter.h index fe2b819cf4..840c302edf 100644 --- a/xls/dslx/bytecode/bytecode_emitter.h +++ b/xls/dslx/bytecode/bytecode_emitter.h @@ -152,6 +152,7 @@ class BytecodeEmitter : public ExprVisitor { absl::Status HandleSpawn(const Spawn* node) override; absl::Status HandleString(const String* node) override; absl::Status HandleStructInstance(const StructInstance* node) override; + absl::Status HandleSumInstance(const SumInstance* node) override; absl::Status HandleSplatStructInstance( const SplatStructInstance* node) override; absl::Status HandleConditional(const Conditional* node) override; diff --git a/xls/dslx/constexpr_evaluator.cc b/xls/dslx/constexpr_evaluator.cc index 600ba47ecb..df0e2c4e7b 100644 --- a/xls/dslx/constexpr_evaluator.cc +++ b/xls/dslx/constexpr_evaluator.cc @@ -518,6 +518,10 @@ absl::Status ConstexprEvaluator::HandleStructInstance( return InterpretExpr(expr); } +absl::Status ConstexprEvaluator::HandleSumInstance(const SumInstance*) { + return absl::UnimplementedError("Semantic sum constants are not supported."); +} + absl::Status ConstexprEvaluator::HandleConditional(const Conditional* expr) { // Simple enough that we don't need to invoke the interpreter. EVAL_AS_CONSTEXPR_OR_RETURN(expr->test()); diff --git a/xls/dslx/constexpr_evaluator.h b/xls/dslx/constexpr_evaluator.h index d3346eb2b6..93f59edc7f 100644 --- a/xls/dslx/constexpr_evaluator.h +++ b/xls/dslx/constexpr_evaluator.h @@ -99,6 +99,7 @@ class ConstexprEvaluator : public xls::dslx::ExprVisitor { absl::Status HandleStatementBlock(const StatementBlock* expr) override; absl::Status HandleString(const String* expr) override; absl::Status HandleStructInstance(const StructInstance* expr) override; + absl::Status HandleSumInstance(const SumInstance* expr) override; absl::Status HandleTupleIndex(const TupleIndex* expr) override; absl::Status HandleUnop(const Unop* expr) override; absl::Status HandleConstFor(const ConstFor* expr) override; diff --git a/xls/dslx/cpp_transpiler/cpp_type_generator.cc b/xls/dslx/cpp_transpiler/cpp_type_generator.cc index 5ba229580e..ab80a34de4 100644 --- a/xls/dslx/cpp_transpiler/cpp_type_generator.cc +++ b/xls/dslx/cpp_transpiler/cpp_type_generator.cc @@ -753,6 +753,11 @@ CppTypeGenerator::Create(const TypeDefinition& type_definition, return EnumCppTypeGenerator::Create(enum_def, type_info, import_data, namespaces); }, + [&](const SumDef* sum_def) + -> absl::StatusOr> { + return absl::UnimplementedError( + absl::StrFormat("Unsupported type: %s", sum_def->ToString())); + }, [&](const ColonRef* colon_ref) -> absl::StatusOr> { return absl::UnimplementedError( diff --git a/xls/dslx/dslx_fmt_test.py b/xls/dslx/dslx_fmt_test.py index 0105bc30aa..ea609cf721 100644 --- a/xls/dslx/dslx_fmt_test.py +++ b/xls/dslx/dslx_fmt_test.py @@ -58,6 +58,25 @@ def test_small_no_spaces_example_in_place(self): """) self.assertEqual(self._run(contents, in_place=True), want) + def test_semantic_sum_payload_comment_is_preserved(self): + contents = textwrap.dedent("""\ + enum Option { + Some( + // Payload comment. + u32, + ), + } + """) + want = textwrap.dedent("""\ + enum Option { + Some( + // Payload comment. + u32 + ), + } + """) + self.assertEqual(self._run(contents), want) + def test_multi_file_in_place_fmt(self): contents0 = 'fn f()->u32{u32:42}' contents1 = 'fn g()->u32{u32:42}' diff --git a/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc b/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc index ca78ad677c..ebfc89136f 100644 --- a/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc +++ b/xls/dslx/exhaustiveness/match_exhaustiveness_checker.cc @@ -202,6 +202,11 @@ PatternLeaf ToPatternLeaf(const NameDefTree::Leaf& leaf) { return SomeWildcard(); }, [&](Number* number) -> PatternLeaf { return number; }, + [&](SumVariantPayloadPattern* /*constructor_pattern*/) + -> PatternLeaf { + LOG(FATAL) << "SumVariantPayloadPattern not yet supported in " + "MatchExhaustivenessChecker"; + }, [&](RestOfTuple* rest_of_tuple) -> PatternLeaf { LOG(FATAL) << "RestOfTuple not valid for conversion to PatternLeaf"; }}, @@ -293,6 +298,10 @@ std::vector ExpandPatternLeaves(const NameDefTree& pattern, result.push_back(ToPatternLeaf(leaf)); types_index += 1; }, + [&](const SumVariantPayloadPattern* /*unused*/) { + result.push_back(ToPatternLeaf(leaf)); + types_index += 1; + }, [&](const RestOfTuple* /*unused*/) { // Instead of using flattened_index here, use types_index (the // number of tuple elements already matched) to figure out how diff --git a/xls/dslx/fmt/ast_fmt.cc b/xls/dslx/fmt/ast_fmt.cc index 18440007dc..3f032a9bfc 100644 --- a/xls/dslx/fmt/ast_fmt.cc +++ b/xls/dslx/fmt/ast_fmt.cc @@ -1577,6 +1577,44 @@ DocRef Formatter::FormatStructInstance(const StructInstance& n) { leader, arena_.MakeGroup(arena_.MakeFlatChoice(on_flat, on_break))); } +DocRef Formatter::FormatSumInstance(const SumInstance& n) { + if (n.is_unit()) { + return FormatColonRef(*n.constructor_ref()); + } + if (n.is_tuple()) { + DocRef args = + FormatJoin(n.tuple_payload_args(), + Joiner::kCommaBreak1AsGroupNoTrailingComma, + [this](const Expr* arg) { + return FormatExpr(*arg); + }); + return ConcatNGroup(arena_, {FormatColonRef(*n.constructor_ref()), + arena_.oparen(), args, arena_.cparen()}); + } + + std::vector pieces = {FormatColonRef(*n.constructor_ref()), + arena_.space(), arena_.ocurl()}; + if (n.struct_payload_field_args().empty()) { + pieces.push_back(arena_.space()); + } else { + pieces.push_back(arena_.space()); + for (int64_t i = 0; i < n.struct_payload_field_args().size(); ++i) { + const auto& [name, arg] = n.struct_payload_field_args()[i]; + pieces.push_back(arena_.MakeText(name)); + pieces.push_back(arena_.colon()); + pieces.push_back(arena_.space()); + pieces.push_back(FormatExpr(*arg)); + if (i != n.struct_payload_field_args().size() - 1) { + pieces.push_back(arena_.comma()); + pieces.push_back(arena_.break1()); + } + } + pieces.push_back(arena_.space()); + } + pieces.push_back(arena_.ccurl()); + return ConcatNGroup(arena_, pieces); +} + DocRef Formatter::FormatSplatStructInstance(const SplatStructInstance& n) { DocRef leader = FormatStructLeader(n.struct_ref()); DocRef splatted = Format(n.splatted()); @@ -1756,6 +1794,9 @@ DocRef Formatter::FormatRange(const Range& n) { DocRef Formatter::FormatNameDefTreeLeaf(const NameDefTree::Leaf& n) { return absl::visit( Visitor{ + [&](const SumVariantPayloadPattern* n) { + return FormatSumVariantPayloadPattern(*n); + }, [&](const NameDef* n) { return FormatNameDef(*n); }, [&](const NameRef* n) { return FormatNameRef(*n); }, [&](const WildcardPattern* n) { return FormatWildcardPattern(*n); }, @@ -1767,6 +1808,44 @@ DocRef Formatter::FormatNameDefTreeLeaf(const NameDefTree::Leaf& n) { n); } +DocRef Formatter::FormatSumVariantPayloadPattern( + const SumVariantPayloadPattern& n) { + std::vector pieces = {FormatColonRef(*n.constructor_ref())}; + if (n.is_tuple()) { + pieces.push_back(arena_.oparen()); + pieces.push_back(FormatJoin( + n.tuple_payload_patterns(), + Joiner::kCommaBreak1AsGroupTrailingCommaOnBreak, + [this](const NameDefTree* pattern) { + return FormatNameDefTree(*pattern); + })); + pieces.push_back(arena_.cparen()); + return ConcatNGroup(arena_, pieces); + } + + pieces.push_back(arena_.space()); + pieces.push_back(arena_.ocurl()); + if (!n.struct_payload_field_patterns().empty()) { + pieces.push_back(arena_.break1()); + pieces.push_back(FormatJoin< + SumVariantPayloadPattern::StructPayloadFieldPattern>( + n.struct_payload_field_patterns(), + Joiner::kCommaBreak1AsGroupTrailingCommaOnBreak, + [this]( + const SumVariantPayloadPattern::StructPayloadFieldPattern& field) { + const auto& [name, pattern] = field; + return ConcatNGroup( + arena_, {arena_.MakeText(name), arena_.colon(), arena_.space(), + FormatNameDefTree(*pattern)}); + })); + pieces.push_back(arena_.break1()); + } else { + pieces.push_back(arena_.space()); + } + pieces.push_back(arena_.ccurl()); + return ConcatNGroup(arena_, pieces); +} + DocRef Formatter::FormatNameDefTree(const NameDefTree& n) { if (n.is_leaf()) { return FormatNameDefTreeLeaf(n.leaf()); @@ -1863,6 +1942,15 @@ DocRef Formatter::FormatBlockedExprLeader(const Expr& e) { return ConcatN(arena_, {FormatStructLeader(n.struct_ref()), arena_.space(), arena_.ocurl()}); } + case AstNodeKind::kSumInstance: { + const SumInstance& n = static_cast(e); + if (n.is_tuple()) { + return arena_.MakeConcat(FormatColonRef(*n.constructor_ref()), + arena_.oparen()); + } + return ConcatN(arena_, {FormatColonRef(*n.constructor_ref()), + arena_.space(), arena_.ocurl()}); + } case AstNodeKind::kSplatStructInstance: { const SplatStructInstance& n = static_cast(e); return ConcatN(arena_, {FormatStructLeader(n.struct_ref()), @@ -2607,6 +2695,288 @@ DocRef Formatter::FormatStructDef(const StructDef& n) { return FormatStructDefBase(n, Keyword::kStruct, n.extern_type_name()); } +bool Formatter::FormatAppendSumCommentsBetween( + const Pos& start_pos, const Pos& limit_pos, std::vector& pieces) { + bool emitted_comments = false; + std::optional last_comment_span; + if (std::optional comments_doc = + FormatCommentsBetween(start_pos, limit_pos, &last_comment_span)) { + pieces.push_back(arena_.hard_line()); + pieces.push_back(comments_doc.value()); + pieces.push_back(arena_.hard_line()); + emitted_comments = true; + } + return emitted_comments; +} + +void Formatter::FormatSumStructMembers( + absl::Span members, const Span& body_span, + std::vector& pieces, bool place_internal_comments) { + if (!members.empty()) { + pieces.push_back(arena_.break1()); + } + + std::vector body_pieces; + Pos last_member_pos = body_span.start(); + for (size_t i = 0; i < members.size(); ++i) { + const StructMemberNode* member = members[i]; + const Span member_span = member->span(); + const Pos& member_start = member_span.start(); + + std::optional last_comment_span; + if (std::optional comments_doc = FormatCommentsBetween( + last_member_pos, member_start, &last_comment_span)) { + body_pieces.push_back(comments_doc.value()); + body_pieces.push_back(arena_.hard_line()); + if (last_comment_span->limit().lineno() != member_start.lineno()) { + body_pieces.push_back(arena_.hard_line()); + } + } + + last_member_pos = member_span.limit(); + + body_pieces.push_back(arena_.MakeText(member->name())); + if (place_internal_comments) { + FormatAppendSumCommentsBetween(member->name_def()->span().limit(), + member->colon_span().start(), + body_pieces); + } + body_pieces.push_back(arena_.colon()); + if (!place_internal_comments || + !FormatAppendSumCommentsBetween(member->colon_span().limit(), + member->type()->span().start(), + body_pieces)) { + body_pieces.push_back(arena_.space()); + } + body_pieces.push_back(FormatTypeAnnotation(*member->type())); + bool last_member = i + 1 == members.size(); + if (last_member) { + body_pieces.push_back(arena_.MakeFlatChoice(/*on_flat=*/arena_.empty(), + /*on_break=*/arena_.comma())); + } else { + body_pieces.push_back(arena_.comma()); + } + + Pos new_last_member_pos = FormatCollectInlineComments( + member_span.limit(), last_member_pos, body_pieces, last_comment_span); + + bool had_inline = new_last_member_pos != last_member_pos; + if (!last_member) { + body_pieces.push_back(had_inline ? arena_.hard_line() : arena_.break1()); + } + last_member_pos = new_last_member_pos; + } + + std::optional last_comment_span; + bool emitted_trailing_comment = false; + if (std::optional comments_doc = + FormatCommentsBetween(last_member_pos, body_span.limit(), + &last_comment_span)) { + body_pieces.push_back(arena_.hard_line()); + body_pieces.push_back(comments_doc.value()); + emitted_trailing_comment = true; + } + + pieces.push_back(arena_.MakeNest(ConcatN(arena_, body_pieces))); + + if (!members.empty() || emitted_trailing_comment) { + pieces.push_back(arena_.break1()); + } +} + +void Formatter::FormatSumTuplePayloadMembers(const SumVariant& variant, + const Span& payload_span, + std::vector& pieces) { + if (!comments_.HasComments(payload_span)) { + pieces.push_back(FormatJoin( + variant.tuple_members(), + Joiner::kCommaBreak1AsGroupTrailingCommaOnBreak, + [this](const TypeAnnotation* member) { + return FormatTypeAnnotation(*member); + })); + return; + } + + std::vector body_pieces; + Pos last_member_pos = payload_span.start(); + for (size_t i = 0; i < variant.tuple_members().size(); ++i) { + const TypeAnnotation* member = variant.tuple_members()[i]; + const bool first_member = i == 0; + const Span comments_span(last_member_pos, member->span().start()); + const std::vector comment_items = + comments_.GetComments(comments_span); + if (std::optional comments_doc = + FormatCommentsBetween(last_member_pos, member->span().start(), + nullptr)) { + if (!first_member) { + body_pieces.push_back(arena_.comma()); + body_pieces.push_back(comment_items.front()->span.start().lineno() == + last_member_pos.lineno() + ? arena_.space() + : arena_.hard_line()); + } + body_pieces.push_back(comments_doc.value()); + body_pieces.push_back(arena_.hard_line()); + } else if (!first_member) { + body_pieces.push_back(arena_.comma()); + body_pieces.push_back(arena_.hard_line()); + } + body_pieces.push_back(FormatTypeAnnotation(*member)); + last_member_pos = member->span().limit(); + } + + const Span comments_span(last_member_pos, payload_span.limit()); + const std::vector comment_items = + comments_.GetComments(comments_span); + if (std::optional comments_doc = + FormatCommentsBetween(last_member_pos, payload_span.limit(), + nullptr)) { + if (!variant.tuple_members().empty()) { + body_pieces.push_back(arena_.comma()); + body_pieces.push_back(comment_items.front()->span.start().lineno() == + last_member_pos.lineno() + ? arena_.space() + : arena_.hard_line()); + } + body_pieces.push_back(comments_doc.value()); + } + + pieces.push_back(arena_.hard_line()); + pieces.push_back(arena_.MakeNest(ConcatN(arena_, body_pieces))); + pieces.push_back(arena_.hard_line()); +} + +DocRef Formatter::FormatSumDef(const SumDef& n) { + std::vector pieces; + std::vector attrs; + for (const Attribute* attribute : n.attributes()) { + attrs.push_back(FormatAttribute(*attribute)); + } + if (n.is_public()) { + pieces.push_back(arena_.Make(Keyword::kPub)); + pieces.push_back(arena_.space()); + } + pieces.push_back(arena_.Make(Keyword::kEnum)); + pieces.push_back(arena_.space()); + pieces.push_back(arena_.MakeText(n.identifier())); + + if (!n.parametric_bindings().empty()) { + pieces.push_back(arena_.oangle()); + pieces.push_back(FormatJoin( + n.parametric_bindings(), Joiner::kCommaSpace, + [this](const ParametricBinding* binding) { + return FormatParametricBindingPtr(binding); + })); + pieces.push_back(arena_.cangle()); + } + + pieces.push_back(arena_.space()); + if (n.tag_type_annotation() != nullptr) { + pieces.push_back(arena_.colon()); + pieces.push_back(arena_.space()); + pieces.push_back(FormatTypeAnnotation(*n.tag_type_annotation())); + pieces.push_back(arena_.space()); + } + pieces.push_back(arena_.ocurl()); + pieces.push_back(arena_.hard_line()); + + std::vector nested; + nested.reserve(n.variants().size() * 2); + Pos last_variant_pos = n.span().start(); + for (size_t i = 0; i < n.variants().size(); ++i) { + const SumVariant* variant = n.variants()[i]; + std::optional last_comment_span; + if (std::optional comments_doc = FormatCommentsBetween( + last_variant_pos, variant->span().start(), &last_comment_span)) { + nested.push_back(comments_doc.value()); + nested.push_back(arena_.hard_line()); + if (last_comment_span->limit().lineno() != + variant->span().start().lineno()) { + nested.push_back(arena_.hard_line()); + } + } + + std::vector variant_pieces = { + arena_.MakeText(variant->identifier())}; + Pos last_piece_limit = variant->name_def()->span().limit(); + if (variant->is_tuple()) { + const Span payload_span = + variant->payload_span().value_or(variant->span()); + FormatAppendSumCommentsBetween(last_piece_limit, payload_span.start(), + variant_pieces); + variant_pieces.push_back(arena_.oparen()); + FormatSumTuplePayloadMembers(*variant, payload_span, variant_pieces); + variant_pieces.push_back(arena_.cparen()); + last_piece_limit = payload_span.limit(); + } else if (variant->is_struct()) { + const Span payload_span = + variant->payload_span().value_or(variant->span()); + if (!FormatAppendSumCommentsBetween(last_piece_limit, + payload_span.start(), + variant_pieces)) { + variant_pieces.push_back(arena_.space()); + } + variant_pieces.push_back(arena_.ocurl()); + if (!variant->struct_members().empty() || + comments_.HasComments(payload_span)) { + FormatSumStructMembers(variant->struct_members(), payload_span, + variant_pieces, + /*place_internal_comments=*/true); + } else { + variant_pieces.push_back(arena_.space()); + } + variant_pieces.push_back(arena_.ccurl()); + last_piece_limit = payload_span.limit(); + } + if (variant->discriminant() != nullptr) { + if (variant->discriminant_equals_span().has_value()) { + if (!FormatAppendSumCommentsBetween( + last_piece_limit, variant->discriminant_equals_span()->start(), + variant_pieces)) { + variant_pieces.push_back(arena_.space()); + } + } else { + variant_pieces.push_back(arena_.space()); + } + variant_pieces.push_back(arena_.equals()); + if (variant->discriminant_equals_span().has_value()) { + if (!FormatAppendSumCommentsBetween( + variant->discriminant_equals_span()->limit(), + variant->discriminant()->span().start(), variant_pieces)) { + variant_pieces.push_back(arena_.space()); + } + } else { + variant_pieces.push_back(arena_.space()); + } + variant_pieces.push_back(FormatExpr(*variant->discriminant())); + } + variant_pieces.push_back(arena_.comma()); + nested.push_back(ConcatNGroup(arena_, variant_pieces)); + + last_variant_pos = variant->span().limit(); + last_variant_pos = FormatCollectInlineComments( + variant->span().limit(), last_variant_pos, nested, last_comment_span); + if (i + 1 != n.variants().size()) { + nested.push_back(arena_.hard_line()); + } + } + + std::optional last_comment_span; + if (std::optional comments_doc = + FormatCommentsBetween(last_variant_pos, n.span().limit(), + &last_comment_span)) { + if (!nested.empty()) { + nested.push_back(arena_.hard_line()); + } + nested.push_back(comments_doc.value()); + } + + pieces.push_back(arena_.MakeNest(ConcatN(arena_, nested))); + pieces.push_back(arena_.hard_line()); + pieces.push_back(arena_.ccurl()); + return FormatJoinWithAttrs(attrs, ConcatN(arena_, pieces)); +} + DocRef Formatter::FormatProcDef(const ProcDef& n) { return FormatStructDefBase(n, Keyword::kProc, /*extern_type_name=*/std::nullopt); @@ -2959,6 +3329,7 @@ DocRef Formatter::FormatModuleMember(const ModuleMember& n) { return arena_.MakeConcat(FormatProcAlias(*n), arena_.semi()); }, [&](const StructDef* n) { return FormatStructDef(*n); }, + [&](const SumDef* n) { return FormatSumDef(*n); }, [&](const ProcDef* n) { return FormatProcDef(*n); }, [&](const Impl* n) { return FormatImpl(*n); }, [&](const Trait* n) { return FormatTrait(*n); }, diff --git a/xls/dslx/fmt/ast_fmt.h b/xls/dslx/fmt/ast_fmt.h index a88797f5c0..2f91563bb7 100644 --- a/xls/dslx/fmt/ast_fmt.h +++ b/xls/dslx/fmt/ast_fmt.h @@ -151,6 +151,10 @@ class Formatter { const StructDefBase& n, Keyword keyword, const std::optional& extern_type_name); virtual DocRef FormatStructInstance(const StructInstance& n); + virtual DocRef FormatSumDef(const SumDef& n); + virtual DocRef FormatSumInstance(const SumInstance& n); + virtual DocRef FormatSumVariantPayloadPattern( + const SumVariantPayloadPattern& n); virtual DocRef FormatTestFunction(const TestFunction& n); virtual DocRef FormatTestProc(const TestProc& n); virtual DocRef FormatTrait(const Trait& n); @@ -204,6 +208,16 @@ class Formatter { absl::Span> members); DocRef FormatStructMembersFlat( absl::Span> members); + bool FormatAppendSumCommentsBetween(const Pos& start_pos, + const Pos& limit_pos, + std::vector& pieces); + void FormatSumStructMembers(absl::Span members, + const Span& body_span, + std::vector& pieces, + bool place_internal_comments = false); + void FormatSumTuplePayloadMembers(const SumVariant& variant, + const Span& payload_span, + std::vector& pieces); DocRef FormatTuple(const XlsTuple& n); DocRef FormatTupleWithoutComments(const XlsTuple& n); template diff --git a/xls/dslx/fmt/ast_fmt_test.cc b/xls/dslx/fmt/ast_fmt_test.cc index ab081d7580..0bdbf54ee4 100644 --- a/xls/dslx/fmt/ast_fmt_test.cc +++ b/xls/dslx/fmt/ast_fmt_test.cc @@ -21,14 +21,14 @@ #include #include -#include "gmock/gmock.h" -#include "gtest/gtest.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" #include "xls/common/logging/log_lines.h" #include "xls/common/status/matchers.h" #include "xls/common/status/status_macros.h" @@ -1646,6 +1646,146 @@ SECOND = 1, kWant); } +TEST_F(ModuleFmtTest, SemanticSumEmptyPayloadShapes) { + DoFmt(R"(enum Option{None,EmptyTuple(),EmptyStruct{},Some(u32),Point{x:u32}} +fn f(x:Option)->Option{match x{Option::EmptyTuple()=>Option::EmptyTuple(),Option::EmptyStruct{}=>Option::EmptyStruct{},Option::Some(v)=>Option::Some(v),Option::Point{x:px}=>Option::Point{x:px},Option::None=>Option::None,}} +)", + R"(enum Option { + None, + EmptyTuple(), + EmptyStruct { }, + Some(u32), + Point { x: u32 }, +} + +fn f(x: Option) -> Option { + match x { + Option::EmptyTuple() => Option::EmptyTuple(), + Option::EmptyStruct { } => Option::EmptyStruct {}, + Option::Some(v) => Option::Some(v), + Option::Point { x: px } => Option::Point { x: px }, + Option::None => Option::None, + } +} +)"); +} + +TEST_F(ModuleFmtTest, SemanticSumCommentsAroundVariants) { + DoFmt(R"(enum Option { + // Before the first variant. + None, // Inline unit comment. + // Before the payload variant. + Some(u32), + // After the last variant. +} +)"); +} + +TEST_F(ModuleFmtTest, EmptySemanticSumWithComment) { + DoFmt(R"(enum Never { + // No value can inhabit this type. +} +)"); +} + +TEST_F(ModuleFmtTest, SemanticSumTuplePayloadComments) { + DoFmt( + R"(enum Option { + Some( + // Before the first tuple member. + u32, + // Before the second tuple member. + u64, // After the final tuple member. + ), + Empty( + // Comment in an empty tuple payload. + ), + Inline( + u32, // After the first tuple member. + u64, + ), + Trailing( + u32, + // After the final tuple member. + ), +} +)", + R"(enum Option { + Some( + // Before the first tuple member. + u32, + // Before the second tuple member. + u64, // After the final tuple member. + ), + Empty( + // Comment in an empty tuple payload. + ), + Inline( + u32, // After the first tuple member. + u64 + ), + Trailing( + u32, + // After the final tuple member. + ), +} +)"); +} + +TEST_F(ModuleFmtTest, SemanticSumStructPayloadComments) { + DoFmt(R"(enum Option { + Point { + // Before the first struct field. + x: u32, // Inline field comment. + // Before the second struct field. + y: u64, + // After the final struct field. + }, + Empty { + // Comment in an empty struct payload. + }, + Separated { + value + // Before the colon. + : + // Before the type. + u8, + }, +} +)"); +} + +TEST_F(ModuleFmtTest, SemanticSumCommentsAroundExplicitDiscriminants) { + DoFmt( + R"(enum Message : u3 { + Idle // Before unit equals. + = // After unit equals. + u3:0, + Request // Before tuple payload. + (u8) // Before equals. + = // After equals. + u3:1, + Empty() = u3:2, +} +)", + R"(enum Message : u3 { + Idle + // Before unit equals. + = + // After unit equals. + u3:0, + Request + // Before tuple payload. + (u8) + // Before equals. + = + // After equals. + u3:1, + Empty() = u3:2, +} +)"); +} + TEST_F(ModuleFmtTest, FunctionRefWithExplicitParametrics) { DoFmt( R"(fn f() -> u32 { X } diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index c2f4db3493..275230e8e9 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -282,6 +282,8 @@ std::string_view AstNodeKindToString(AstNodeKind kind) { return "string"; case AstNodeKind::kStructInstance: return "struct instance"; + case AstNodeKind::kSumInstance: + return "sum instance"; case AstNodeKind::kSplatStructInstance: return "splat struct instance"; case AstNodeKind::kNameDefTree: @@ -332,12 +334,18 @@ std::string_view AstNodeKindToString(AstNodeKind kind) { return "all-ones macro"; case AstNodeKind::kSlice: return "slice"; + case AstNodeKind::kSumVariantPayloadPattern: + return "constructor-pattern"; case AstNodeKind::kEnumDef: return "enum definition"; case AstNodeKind::kStructDef: return "struct definition"; case AstNodeKind::kStructMember: return "struct member"; + case AstNodeKind::kSumDef: + return "sum definition"; + case AstNodeKind::kSumVariant: + return "sum variant"; case AstNodeKind::kProcDef: return "proc definition"; case AstNodeKind::kQuickCheck: @@ -389,6 +397,7 @@ AnyNameDef TypeDefinitionGetNameDef(const TypeDefinition& td) { [](StructDef* n) -> AnyNameDef { return n->name_def(); }, [](ProcDef* n) -> AnyNameDef { return n->name_def(); }, [](EnumDef* n) -> AnyNameDef { return n->name_def(); }, + [](SumDef* n) -> AnyNameDef { return n->name_def(); }, [](ColonRef* n) -> AnyNameDef { return GetSubjectNameDef(n->subject()); }, @@ -405,6 +414,7 @@ AstNode* TypeDefinitionToAstNode(const TypeDefinition& td) { [](StructDef* n) -> AstNode* { return n; }, [](ProcDef* n) -> AstNode* { return n; }, [](EnumDef* n) -> AstNode* { return n; }, + [](SumDef* n) -> AstNode* { return n; }, [](ColonRef* n) -> AstNode* { return n; }, [](UseTreeEntry* n) -> AstNode* { return n; }, }, @@ -424,6 +434,9 @@ absl::StatusOr ToTypeDefinition(AstNode* node) { if (node->kind() == AstNodeKind::kEnumDef) { return absl::down_cast(node); } + if (node->kind() == AstNodeKind::kSumDef) { + return absl::down_cast(node); + } if (node->kind() == AstNodeKind::kColonRef) { return absl::down_cast(node); } @@ -1549,6 +1562,177 @@ std::string EnumDef::ToString() const { return result; } +// -- class SumVariant + +SumVariant::SumVariant(Module* owner, Span span, NameDef* name_def, + PayloadShape payload_shape, + std::vector tuple_members, + std::vector struct_members, + Expr* discriminant, std::optional payload_span, + std::optional discriminant_equals_span) + : AstNode(owner), + span_(std::move(span)), + name_def_(name_def), + payload_shape_(payload_shape), + discriminant_(discriminant), + payload_span_(std::move(payload_span)), + discriminant_equals_span_(std::move(discriminant_equals_span)), + tuple_members_(std::move(tuple_members)), + struct_members_(std::move(struct_members)) { + if (payload_shape_ == PayloadShape::kUnit) { + CHECK(tuple_members_.empty()); + CHECK(struct_members_.empty()); + } else if (payload_shape_ == PayloadShape::kTuple) { + CHECK(struct_members_.empty()); + } else { + CHECK_EQ(payload_shape_, PayloadShape::kStruct); + CHECK(tuple_members_.empty()); + } + CHECK(!discriminant_equals_span_.has_value() || discriminant_ != nullptr); +} + +SumVariant::~SumVariant() = default; + +std::vector SumVariant::GetChildren(bool want_types) const { + std::vector results = {name_def_}; + if (discriminant_ != nullptr) { + results.push_back(discriminant_); + } + if (!want_types) { + return results; + } + if (is_tuple()) { + for (TypeAnnotation* member : tuple_members_) { + results.push_back(member); + } + } else if (is_struct()) { + for (StructMemberNode* member : struct_members_) { + results.push_back(member); + } + } + return results; +} + +std::string SumVariant::ToString() const { + if (is_unit()) { + return discriminant_ == nullptr + ? identifier() + : absl::StrCat(identifier(), " = ", discriminant_->ToString()); + } + if (is_tuple()) { + std::string result = absl::StrCat( + identifier(), "(", + absl::StrJoin(tuple_members_, ", ", + [](std::string* out, TypeAnnotation* member) { + absl::StrAppend(out, member->ToString()); + }), + ")"); + if (discriminant_ != nullptr) { + absl::StrAppend(&result, " = ", discriminant_->ToString()); + } + return result; + } + if (struct_members_.empty()) { + return discriminant_ == nullptr + ? absl::StrCat(identifier(), " { }") + : absl::StrCat(identifier(), + " { } = ", discriminant_->ToString()); + } + std::string result = absl::StrCat( + identifier(), " { ", + absl::StrJoin(struct_members_, ", ", + [](std::string* out, StructMemberNode* member) { + absl::StrAppend(out, member->ToString()); + }), + " }"); + if (discriminant_ != nullptr) { + absl::StrAppend(&result, " = ", discriminant_->ToString()); + } + return result; +} + +// -- class SumDef + +SumDef::SumDef(Module* owner, Span span, NameDef* name_def, + std::vector parametric_bindings, + std::vector variants, bool is_public, + TypeAnnotation* tag_type_annotation) + : AstNode(owner), + span_(std::move(span)), + name_def_(name_def), + tag_type_annotation_(tag_type_annotation), + parametric_bindings_(std::move(parametric_bindings)), + variants_(std::move(variants)), + is_public_(is_public) {} + +SumDef::~SumDef() = default; + +std::vector SumDef::GetChildren(bool want_types) const { + std::vector results = {name_def_}; + if (want_types && tag_type_annotation_ != nullptr) { + results.push_back(tag_type_annotation_); + } + for (ParametricBinding* binding : parametric_bindings_) { + results.push_back(binding); + } + for (SumVariant* variant : variants_) { + results.push_back(variant); + } + return results; +} + +bool SumDef::HasVariant(std::string_view target) const { + return std::any_of(variants_.begin(), variants_.end(), + [&](const SumVariant* variant) { + return variant->identifier() == target; + }); +} + +std::optional SumDef::GetVariant(std::string_view target) { + for (SumVariant* variant : variants_) { + if (variant->identifier() == target) { + return variant; + } + } + return std::nullopt; +} + +std::optional SumDef::GetVariant( + std::string_view target) const { + for (const SumVariant* variant : variants_) { + if (variant->identifier() == target) { + return variant; + } + } + return std::nullopt; +} + +std::string SumDef::ToString() const { + std::string parametric_str; + if (!parametric_bindings_.empty()) { + parametric_str = absl::StrCat( + "<", + absl::StrJoin(parametric_bindings_, ", ", + [](std::string* out, ParametricBinding* binding) { + absl::StrAppend(out, binding->ToString()); + }), + ">"); + } + const std::string tag_type = + tag_type_annotation_ == nullptr + ? "" + : absl::StrCat(" : ", tag_type_annotation_->ToString()); + std::string result = absl::StrFormat( + "%s%senum %s%s%s {\n", MakeExternTypeAttr(extern_type_name_), + is_public_ ? "pub " : "", identifier(), parametric_str, tag_type); + for (const SumVariant* variant : variants_) { + absl::StrAppendFormat(&result, "%s%s,\n", kRustOneIndent, + variant->ToString()); + } + absl::StrAppend(&result, "}"); + return result; +} + // -- class Instantiation Instantiation::Instantiation(Module* owner, Span span, Expr* callee, @@ -2050,6 +2234,10 @@ std::vector StructInstance::GetChildren(bool want_types) const { std::string StructInstance::ToStringInternal() const { std::string type_name = ToAstNode(struct_ref())->ToString(); + if (GetUnorderedMembers().empty()) { + return absl::StrFormat("%s { }", type_name); + } + std::string members_str = absl::StrJoin( GetUnorderedMembers(), ", ", [](std::string* out, const std::pair& member) { @@ -2059,6 +2247,66 @@ std::string StructInstance::ToStringInternal() const { return absl::StrFormat("%s { %s }", type_name, members_str); } +// -- class SumInstance + +SumInstance::SumInstance( + Module* owner, Span span, ColonRef* constructor_ref, + PayloadShape payload_shape, std::vector tuple_payload_args, + std::vector struct_payload_field_args, + bool in_parens) + : Expr(owner, std::move(span), in_parens), + constructor_ref_(constructor_ref), + payload_shape_(payload_shape), + tuple_payload_args_(std::move(tuple_payload_args)), + struct_payload_field_args_(std::move(struct_payload_field_args)) { + if (payload_shape_ == PayloadShape::kUnit) { + CHECK(tuple_payload_args_.empty()); + CHECK(struct_payload_field_args_.empty()); + } else if (payload_shape_ == PayloadShape::kTuple) { + CHECK(struct_payload_field_args_.empty()); + } else { + CHECK_EQ(payload_shape_, PayloadShape::kStruct); + CHECK(tuple_payload_args_.empty()); + } +} + +SumInstance::~SumInstance() = default; + +std::vector SumInstance::GetChildren(bool want_types) const { + std::vector results = {constructor_ref_}; + for (Expr* arg : tuple_payload_args_) { + results.push_back(arg); + } + for (const auto& [_, arg] : struct_payload_field_args_) { + results.push_back(arg); + } + return results; +} + +std::string SumInstance::ToStringInternal() const { + if (is_unit()) { + return constructor_ref_->ToString(); + } else if (is_tuple()) { + return absl::StrCat(constructor_ref_->ToString(), "(", + absl::StrJoin(tuple_payload_args_, ", ", + [](std::string* out, Expr* arg) { + absl::StrAppend(out, arg->ToString()); + }), + ")"); + } else if (struct_payload_field_args_.empty()) { + return absl::StrCat(constructor_ref_->ToString(), " { }"); + } else { + return absl::StrCat( + constructor_ref_->ToString(), " { ", + absl::StrJoin(struct_payload_field_args_, ", ", + [](std::string* out, const StructPayloadFieldArg& arg) { + absl::StrAppendFormat(out, "%s: %s", arg.first, + arg.second->ToString()); + }), + " }"); + } +} + // -- class SplatStructInstance SplatStructInstance::SplatStructInstance( @@ -2797,6 +3045,65 @@ std::string XlsTuple::ToStringInternal() const { return result; } +// -- class SumVariantPayloadPattern + +SumVariantPayloadPattern::SumVariantPayloadPattern( + Module* owner, Span span, ColonRef* constructor_ref, + PayloadShape payload_shape, + std::vector tuple_payload_patterns, + std::vector struct_payload_field_patterns) + : AstNode(owner), + span_(std::move(span)), + constructor_ref_(constructor_ref), + payload_shape_(payload_shape), + tuple_payload_patterns_(std::move(tuple_payload_patterns)), + struct_payload_field_patterns_(std::move(struct_payload_field_patterns)) { + if (payload_shape_ == PayloadShape::kTuple) { + CHECK(struct_payload_field_patterns_.empty()); + } else { + CHECK_EQ(payload_shape_, PayloadShape::kStruct); + CHECK(tuple_payload_patterns_.empty()); + } +} + +SumVariantPayloadPattern::~SumVariantPayloadPattern() = default; + +std::vector SumVariantPayloadPattern::GetChildren( + bool want_types) const { + std::vector results = {constructor_ref_}; + for (NameDefTree* pattern : tuple_payload_patterns_) { + results.push_back(pattern); + } + for (const auto& [_, pattern] : struct_payload_field_patterns_) { + results.push_back(pattern); + } + return results; +} + +std::string SumVariantPayloadPattern::ToString() const { + if (is_tuple()) { + return absl::StrCat( + constructor_ref_->ToString(), "(", + absl::StrJoin(tuple_payload_patterns_, ", ", + [](std::string* out, NameDefTree* pattern) { + absl::StrAppend(out, pattern->ToString()); + }), + ")"); + } + if (struct_payload_field_patterns_.empty()) { + return absl::StrCat(constructor_ref_->ToString(), " { }"); + } + return absl::StrCat( + constructor_ref_->ToString(), " { ", + absl::StrJoin( + struct_payload_field_patterns_, ", ", + [](std::string* out, const StructPayloadFieldPattern& pattern) { + absl::StrAppendFormat(out, "%s: %s", pattern.first, + pattern.second->ToString()); + }), + " }"); +} + // -- class NameDefTree NameDefTree::~NameDefTree() = default; diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 317e4853eb..f2070b7a6e 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -50,7 +50,7 @@ // Higher-order macro for all the Expr node leaf types (non-abstract). #define XLS_DSLX_EXPR_NODE_EACH(X) \ - /* keep-sorted start */ \ + /* keep-sorted start */ \ X(AllOnesMacro) \ X(Array) \ X(Attr) \ @@ -75,6 +75,7 @@ X(StatementBlock) \ X(String) \ X(StructInstance) \ + X(SumInstance) \ X(TupleIndex) \ X(Unop) \ X(VerbatimNode) \ @@ -87,11 +88,12 @@ // (Note that this includes all the Expr node leaf kinds listed in // XLS_DSLX_EXPR_NODE_EACH). #define XLS_DSLX_AST_NODE_EACH(X) \ - /* keep-sorted start */ \ + /* keep-sorted start */ \ X(Attribute) \ X(BuiltinNameDef) \ X(ConstAssert) \ X(ConstantDef) \ + X(SumVariantPayloadPattern) \ X(EnumDef) \ X(Function) \ X(FuzzTestFunction) \ @@ -114,6 +116,8 @@ X(Statement) \ X(StructDef) \ X(StructMemberNode) \ + X(SumDef) \ + X(SumVariant) \ X(TestFunction) \ X(TestProc) \ X(Trait) \ @@ -123,9 +127,9 @@ X(UseTreeEntry) \ X(WidthSlice) \ X(WildcardPattern) \ - /* keep-sorted end */ \ + /* keep-sorted end */ \ /* type annotations */ \ - /* keep-sorted start */ \ + /* keep-sorted start */ \ X(AnyTypeAnnotation) \ X(ArrayTypeAnnotation) \ X(BuiltinTypeAnnotation) \ @@ -143,7 +147,7 @@ X(TupleTypeAnnotation) \ X(TypeRefTypeAnnotation) \ X(TypeVariableTypeAnnotation) \ - /* keep-sorted end */ \ + /* keep-sorted end */ \ XLS_DSLX_EXPR_NODE_EACH(X) namespace xls::dslx { @@ -1790,7 +1794,7 @@ class Array final : public Expr { // Several different AST nodes define types that can be referred to by a // TypeRef. using TypeDefinition = std::variant; + SumDef*, ColonRef*, UseTreeEntry*>; // Returns the name definition that (most locally) defined this type definition // AST node. @@ -2682,6 +2686,71 @@ class MatchArm : public AstNode { Expr* expr_; // Expression that is executed if one of the patterns matches. }; +// Represents a payload-carrying sum-constructor pattern such as: +// +// Option::Some(value) +// Message::Point { x: px, y: py } +// +// The constructor itself is always spelled as a `ColonRef`. +// +// `payload_shape()` is the source of truth for tuple-vs-struct spelling. Empty +// child vectors are valid for both `Case()` and `Case { }`, so callers must +// not recover the shape from vector emptiness alone. +class SumVariantPayloadPattern : public AstNode { + public: + enum class PayloadShape : uint8_t { + kTuple, + kStruct, + }; + + using StructPayloadFieldPattern = std::pair; + + SumVariantPayloadPattern( + Module* owner, Span span, ColonRef* constructor_ref, + PayloadShape payload_shape, + std::vector tuple_payload_patterns, + std::vector struct_payload_field_patterns); + + ~SumVariantPayloadPattern() override; + + AstNodeKind kind() const override { + return AstNodeKind::kSumVariantPayloadPattern; + } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleSumVariantPayloadPattern(this); + } + + std::string_view GetNodeTypeName() const override { + return "SumVariantPayloadPattern"; + } + + std::string ToString() const override; + + std::vector GetChildren(bool want_types) const override; + + ColonRef* constructor_ref() const { return constructor_ref_; } + const std::vector& tuple_payload_patterns() const { + return tuple_payload_patterns_; + } + const std::vector& struct_payload_field_patterns() + const { + return struct_payload_field_patterns_; + } + PayloadShape payload_shape() const { return payload_shape_; } + bool is_tuple() const { return payload_shape_ == PayloadShape::kTuple; } + bool is_struct() const { return payload_shape_ == PayloadShape::kStruct; } + const Span& span() const { return span_; } + std::optional GetSpan() const override { return span_; } + + private: + Span span_; + ColonRef* constructor_ref_; + PayloadShape payload_shape_; + std::vector tuple_payload_patterns_; + std::vector struct_payload_field_patterns_; +}; + // Represents a match (pattern match) expression. // // A match expression has zero or more *arms*. @@ -3225,6 +3294,147 @@ class EnumDef : public AstNode { std::optional extern_type_name_; }; +// Represents a single constructor inside a semantic sum declaration. +// +// `payload_shape()` is the source of truth for unit-vs-tuple-vs-struct +// spelling. Empty member vectors are valid for `Case()` and `Case { }`, so +// callers must not infer unit-ness from vector emptiness alone. +class SumVariant : public AstNode { + public: + enum class PayloadShape : uint8_t { + kUnit, + kTuple, + kStruct, + }; + + SumVariant(Module* owner, Span span, NameDef* name_def, + PayloadShape payload_shape, + std::vector tuple_members, + std::vector struct_members, + Expr* discriminant = nullptr, + std::optional payload_span = std::nullopt, + std::optional discriminant_equals_span = std::nullopt); + + ~SumVariant() override; + + AstNodeKind kind() const override { return AstNodeKind::kSumVariant; } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleSumVariant(this); + } + + std::string_view GetNodeTypeName() const override { return "SumVariant"; } + + std::string ToString() const override; + + std::vector GetChildren(bool want_types) const override; + + NameDef* name_def() const { return name_def_; } + const std::string& identifier() const { return name_def_->identifier(); } + const Span& span() const { return span_; } + std::optional GetSpan() const override { return span_; } + + PayloadShape payload_shape() const { return payload_shape_; } + bool is_unit() const { return payload_shape_ == PayloadShape::kUnit; } + bool is_tuple() const { return payload_shape_ == PayloadShape::kTuple; } + bool is_struct() const { return payload_shape_ == PayloadShape::kStruct; } + Expr* discriminant() const { return discriminant_; } + // These source-only spans are absent for synthesized variants. + const std::optional& payload_span() const { return payload_span_; } + const std::optional& discriminant_equals_span() const { + return discriminant_equals_span_; + } + + const std::vector& tuple_members() const { + return tuple_members_; + } + const std::vector& struct_members() const { + return struct_members_; + } + + private: + Span span_; + NameDef* name_def_; + PayloadShape payload_shape_; + Expr* discriminant_; + std::optional payload_span_; + std::optional discriminant_equals_span_; + std::vector tuple_members_; + std::vector struct_members_; +}; + +// Represents a semantic sum declaration; e.g. +// +// enum Option { +// None, +// Some(T), +// } +class SumDef : public AstNode { + public: + static std::string_view GetDebugTypeName() { return "sum"; } + + SumDef(Module* owner, Span span, NameDef* name_def, + std::vector parametric_bindings, + std::vector variants, bool is_public, + TypeAnnotation* tag_type_annotation = nullptr); + + ~SumDef() override; + + AstNodeKind kind() const override { return AstNodeKind::kSumDef; } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleSumDef(this); + } + + std::string_view GetNodeTypeName() const override { return "SumDef"; } + + std::string ToString() const override; + + std::vector GetChildren(bool want_types) const override; + + const std::string& identifier() const { return name_def_->identifier(); } + NameDef* name_def() const { return name_def_; } + const Span& span() const { return span_; } + std::optional GetSpan() const override { return span_; } + + bool is_public() const { return is_public_; } + bool IsParametric() const { return !parametric_bindings_.empty(); } + TypeAnnotation* tag_type_annotation() const { return tag_type_annotation_; } + + const std::vector& parametric_bindings() const { + return parametric_bindings_; + } + const std::vector& variants() const { return variants_; } + + absl::flat_hash_set ParametricKeys() const { + absl::flat_hash_set result; + for (const ParametricBinding* binding : parametric_bindings_) { + result.emplace(binding->name_def()->identifier()); + } + return result; + } + + bool HasVariant(std::string_view target) const; + std::optional GetVariant(std::string_view target); + std::optional GetVariant(std::string_view target) const; + + void set_extern_type_name(std::string_view n) { + extern_type_name_ = std::string(n); + } + const std::optional& extern_type_name() const { + return extern_type_name_; + } + + private: + Span span_; + NameDef* name_def_; + TypeAnnotation* tag_type_annotation_; + std::vector parametric_bindings_; + std::vector variants_; + bool is_public_; + std::optional extern_type_name_; +}; + // Helper struct for DSLX-struct items defined inside of DSLX-structs. struct StructMember { Span name_span; @@ -3709,6 +3919,75 @@ class StructInstance : public StructInstanceBase { std::string ToStringInternal() const final; }; +// Represents construction of a semantic sum value, such as: +// +// Option::None +// Option::Some(value) +// Message::Point { x: px, y: py } +// +// The constructor itself is always spelled as a `ColonRef`. +// +// `payload_shape()` is the source of truth for unit-vs-tuple-vs-struct +// spelling. Empty child vectors are valid for both `Case()` and `Case { }`, so +// callers must not recover the shape from vector emptiness alone. +class SumInstance : public Expr { + public: + enum class PayloadShape : uint8_t { + kUnit, + kTuple, + kStruct, + }; + + using StructPayloadFieldArg = std::pair; + + SumInstance(Module* owner, Span span, ColonRef* constructor_ref, + PayloadShape payload_shape, std::vector tuple_payload_args, + std::vector struct_payload_field_args, + bool in_parens = false); + + ~SumInstance() override; + + AstNodeKind kind() const override { return AstNodeKind::kSumInstance; } + + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleSumInstance(this); + } + + absl::Status AcceptExpr(ExprVisitor* v) const override { + return v->HandleSumInstance(this); + } + + std::string_view GetNodeTypeName() const override { return "SumInstance"; } + + std::vector GetChildren(bool want_types) const override; + + ColonRef* constructor_ref() const { return constructor_ref_; } + const std::vector& tuple_payload_args() const { + return tuple_payload_args_; + } + const std::vector& struct_payload_field_args() const { + return struct_payload_field_args_; + } + PayloadShape payload_shape() const { return payload_shape_; } + bool is_unit() const { return payload_shape_ == PayloadShape::kUnit; } + bool is_tuple() const { return payload_shape_ == PayloadShape::kTuple; } + bool is_struct() const { return payload_shape_ == PayloadShape::kStruct; } + + bool IsBlockedExprWithLeader() const override { return !is_unit(); } + + Precedence GetPrecedenceWithoutParens() const final { + return Precedence::kStrongest; + } + + private: + std::string ToStringInternal() const final; + + ColonRef* constructor_ref_; + PayloadShape payload_shape_; + std::vector tuple_payload_args_; + std::vector struct_payload_field_args_; +}; + // Represents a struct instantiation as a "delta" from a 'splatted' original; // e.g. // Point { y: new_y, ..orig_p } @@ -4375,8 +4654,9 @@ class ConstantDef : public AstNode { class NameDefTree : public AstNode { public: using Nodes = std::vector; - using Leaf = std::variant; + using Leaf = + std::variant; NameDefTree(Module* owner, Span span, std::variant tree) : AstNode(owner), span_(std::move(span)), tree_(std::move(tree)) {} diff --git a/xls/dslx/frontend/ast_cloner.cc b/xls/dslx/frontend/ast_cloner.cc index 8fe9c44180..6992185442 100644 --- a/xls/dslx/frontend/ast_cloner.cc +++ b/xls/dslx/frontend/ast_cloner.cc @@ -310,6 +310,34 @@ class AstCloner : public AstNodeVisitor { return absl::OkStatus(); } + absl::Status HandleSumVariantPayloadPattern( + const SumVariantPayloadPattern* n) override { + XLS_RETURN_IF_ERROR(VisitChildren(n)); + + std::vector new_tuple_payload_patterns; + new_tuple_payload_patterns.reserve(n->tuple_payload_patterns().size()); + for (const NameDefTree* pattern : n->tuple_payload_patterns()) { + new_tuple_payload_patterns.push_back( + absl::down_cast(old_to_new_.at(pattern))); + } + + std::vector + new_struct_payload_field_patterns; + new_struct_payload_field_patterns.reserve( + n->struct_payload_field_patterns().size()); + for (const auto& [name, pattern] : n->struct_payload_field_patterns()) { + new_struct_payload_field_patterns.push_back(std::make_pair( + name, absl::down_cast(old_to_new_.at(pattern)))); + } + + old_to_new_[n] = module(n)->Make( + n->span(), + absl::down_cast(old_to_new_.at(n->constructor_ref())), + n->payload_shape(), std::move(new_tuple_payload_patterns), + std::move(new_struct_payload_field_patterns)); + return absl::OkStatus(); + } + absl::Status HandleEnumDef(const EnumDef* n) override { XLS_RETURN_IF_ERROR(VisitChildren(n)); @@ -341,6 +369,68 @@ class AstCloner : public AstNodeVisitor { return absl::OkStatus(); } + absl::Status HandleSumVariant(const SumVariant* n) override { + XLS_RETURN_IF_ERROR(VisitChildren(n)); + + std::vector new_tuple_members; + new_tuple_members.reserve(n->tuple_members().size()); + for (const TypeAnnotation* member : n->tuple_members()) { + new_tuple_members.push_back( + absl::down_cast(old_to_new_.at(member))); + } + + std::vector new_struct_members; + new_struct_members.reserve(n->struct_members().size()); + for (const StructMemberNode* member : n->struct_members()) { + new_struct_members.push_back( + absl::down_cast(old_to_new_.at(member))); + } + + old_to_new_[n] = module(n)->Make( + n->span(), absl::down_cast(old_to_new_.at(n->name_def())), + n->payload_shape(), std::move(new_tuple_members), + std::move(new_struct_members), + n->discriminant() == nullptr + ? nullptr + : absl::down_cast(old_to_new_.at(n->discriminant())), + n->payload_span(), n->discriminant_equals_span()); + return absl::OkStatus(); + } + + absl::Status HandleSumDef(const SumDef* n) override { + XLS_RETURN_IF_ERROR(VisitChildren(n)); + + std::vector new_parametric_bindings; + new_parametric_bindings.reserve(n->parametric_bindings().size()); + for (const ParametricBinding* binding : n->parametric_bindings()) { + new_parametric_bindings.push_back( + absl::down_cast(old_to_new_.at(binding))); + } + + std::vector new_variants; + new_variants.reserve(n->variants().size()); + for (const SumVariant* variant : n->variants()) { + new_variants.push_back( + absl::down_cast(old_to_new_.at(variant))); + } + + auto* new_name_def = + absl::down_cast(old_to_new_.at(n->name_def())); + auto* new_sum_def = module(n)->Make( + n->span(), new_name_def, std::move(new_parametric_bindings), + std::move(new_variants), n->is_public(), + n->tag_type_annotation() == nullptr + ? nullptr + : absl::down_cast( + old_to_new_.at(n->tag_type_annotation()))); + if (n->extern_type_name().has_value()) { + new_sum_def->set_extern_type_name(*n->extern_type_name()); + } + new_name_def->set_definer(new_sum_def); + old_to_new_[n] = new_sum_def; + return absl::OkStatus(); + } + absl::Status HandleFor(const For* n) override { XLS_RETURN_IF_ERROR(VisitChildren(n)); @@ -1017,6 +1107,33 @@ class AstCloner : public AstNodeVisitor { return absl::OkStatus(); } + absl::Status HandleSumInstance(const SumInstance* n) override { + XLS_RETURN_IF_ERROR(VisitChildren(n)); + + std::vector new_tuple_payload_args; + new_tuple_payload_args.reserve(n->tuple_payload_args().size()); + for (const Expr* arg : n->tuple_payload_args()) { + new_tuple_payload_args.push_back( + absl::down_cast(old_to_new_.at(arg))); + } + + std::vector + new_struct_payload_field_args; + new_struct_payload_field_args.reserve( + n->struct_payload_field_args().size()); + for (const auto& [name, arg] : n->struct_payload_field_args()) { + new_struct_payload_field_args.push_back( + std::make_pair(name, absl::down_cast(old_to_new_.at(arg)))); + } + + old_to_new_[n] = module(n)->Make( + n->span(), + absl::down_cast(old_to_new_.at(n->constructor_ref())), + n->payload_shape(), std::move(new_tuple_payload_args), + std::move(new_struct_payload_field_args), n->in_parens()); + return absl::OkStatus(); + } + absl::Status HandleConditional(const Conditional* n) override { XLS_RETURN_IF_ERROR(VisitChildren(n)); diff --git a/xls/dslx/frontend/ast_cloner_test.cc b/xls/dslx/frontend/ast_cloner_test.cc index 91342e8444..65419982bd 100644 --- a/xls/dslx/frontend/ast_cloner_test.cc +++ b/xls/dslx/frontend/ast_cloner_test.cc @@ -2637,6 +2637,43 @@ fn make() -> Foo { StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Foo"))); } +TEST(AstClonerTest, CloneModuleWithSemanticSums) { + constexpr std::string_view kProgram = R"(enum Option { + None, + Some(u32), + Point { x: u32, y: u32 }, +} + +fn unwrap_or_sum(x: Option) -> u32 { + match x { + Option::Some(v) => v, + Option::Point { x: px, y: py } => px + py, + Option::None => u32:0, + } +})"; + constexpr std::string_view kExpected = R"(enum Option { + None, + Some(u32), + Point { x: u32, y: u32 }, +} +fn unwrap_or_sum(x: Option) -> u32 { + match x { + Option::Some(v) => v, + Option::Point { x: px, y: py } => px + py, + Option::None => u32:0, + } +})"; + + FileTable file_table; + XLS_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseModule(kProgram, "fake_path.x", "the_module", file_table)); + XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr clone, + CloneModule(*module.get())); + EXPECT_EQ(kExpected, clone->ToString()); + XLS_ASSERT_OK(VerifyClone(module.get(), clone.get(), file_table)); +} + TEST(AstClonerTest, ParametricFunctionClonesParamTypeDefiner) { constexpr std::string_view kProgram = R"( #![feature(generics)] diff --git a/xls/dslx/frontend/ast_node.h b/xls/dslx/frontend/ast_node.h index d1885f40f7..08311ad8ed 100644 --- a/xls/dslx/frontend/ast_node.h +++ b/xls/dslx/frontend/ast_node.h @@ -47,6 +47,7 @@ enum class AstNodeKind : uint8_t { kConstAssert, kConstFor, kConstantDef, + kSumVariantPayloadPattern, kEnumDef, kFor, kFormatMacro, @@ -92,6 +93,9 @@ enum class AstNodeKind : uint8_t { kStructDef, kStructInstance, kStructMember, + kSumDef, + kSumInstance, + kSumVariant, kTestFunction, kTestProc, kTrait, diff --git a/xls/dslx/frontend/ast_utils.cc b/xls/dslx/frontend/ast_utils.cc index 4676d42c18..e5b71978c2 100644 --- a/xls/dslx/frontend/ast_utils.cc +++ b/xls/dslx/frontend/ast_utils.cc @@ -203,6 +203,7 @@ absl::StatusOr ResolveLocalStructDef(TypeDefinition td) { [&](StructDef* n) -> absl::StatusOr { return n; }, [&](ProcDef* n) -> absl::StatusOr { return error(n); }, [&](EnumDef* n) -> absl::StatusOr { return error(n); }, + [&](SumDef* n) -> absl::StatusOr { return error(n); }, [&](ColonRef* n) -> absl::StatusOr { return error(n); }, [&](UseTreeEntry* n) -> absl::StatusOr { return error(n); @@ -240,6 +241,9 @@ absl::Status VerifyParentage(const Module* module) { if (std::holds_alternative(member)) { return VerifyParentage(std::get(member)); } + if (std::holds_alternative(member)) { + return VerifyParentage(std::get(member)); + } if (std::holds_alternative(member)) { return VerifyParentage(std::get(member)); } diff --git a/xls/dslx/frontend/bindings.cc b/xls/dslx/frontend/bindings.cc index 65fa068468..131cd8c836 100644 --- a/xls/dslx/frontend/bindings.cc +++ b/xls/dslx/frontend/bindings.cc @@ -123,6 +123,7 @@ std::string BoundNodeGetTypeString(const BoundNode& bn) { [](ConstantDef*) { return "ConstantDef"; }, [](StructDef*) { return "StructDef"; }, [](ProcDef*) { return "ProcDef"; }, + [](SumDef*) { return "SumDef"; }, [](NameDef*) { return "NameDef"; }, [](BuiltinNameDef*) { return "BuiltinNameDef"; }, [](Import*) { return "Import"; }, diff --git a/xls/dslx/frontend/bindings.h b/xls/dslx/frontend/bindings.h index 353cafbac5..adb8d33785 100644 --- a/xls/dslx/frontend/bindings.h +++ b/xls/dslx/frontend/bindings.h @@ -42,7 +42,7 @@ namespace xls::dslx { // definitions. using BoundNode = std::variant; + StructDef*, ProcDef*, SumDef*, Import*, UseTreeEntry*>; // Returns a string, useful for reporting in error messages, for the type of the // AST node contained inside of the given BoundNode variant; e.g. "Import". @@ -234,7 +234,8 @@ class Bindings { return std::holds_alternative(*bn) || std::holds_alternative(*bn) || std::holds_alternative(*bn) || - std::holds_alternative(*bn); + std::holds_alternative(*bn) || + std::holds_alternative(*bn); } // As above, but flags a ParseError() if the binding cannot be resolved, diff --git a/xls/dslx/frontend/module.cc b/xls/dslx/frontend/module.cc index 84d469b6eb..2b61399684 100644 --- a/xls/dslx/frontend/module.cc +++ b/xls/dslx/frontend/module.cc @@ -218,6 +218,10 @@ const EnumDef* Module::FindEnumDef(const Span& span) const { return absl::down_cast(FindNode(AstNodeKind::kEnumDef, span)); } +const SumDef* Module::FindSumDef(const Span& span) const { + return absl::down_cast(FindNode(AstNodeKind::kSumDef, span)); +} + bool Module::IsPublicMember(const AstNode& node) const { for (const ModuleMember& member : top_) { const AstNode* member_node = ToAstNode(member); @@ -287,6 +291,9 @@ Module::GetTypeDefinitionByName() const { } else if (std::holds_alternative(member)) { EnumDef* enum_def = std::get(member); result[enum_def->identifier()] = enum_def; + } else if (std::holds_alternative(member)) { + SumDef* sum_def = std::get(member); + result[sum_def->identifier()] = sum_def; } else if (std::holds_alternative(member)) { StructDef* struct_def = std::get(member); result[struct_def->identifier()] = struct_def; @@ -307,6 +314,9 @@ std::vector Module::GetTypeDefinitions() const { } else if (std::holds_alternative(member)) { EnumDef* enum_def = std::get(member); results.push_back(enum_def); + } else if (std::holds_alternative(member)) { + SumDef* sum_def = std::get(member); + results.push_back(sum_def); } else if (std::holds_alternative(member)) { StructDef* struct_def = std::get(member); results.push_back(struct_def); @@ -436,6 +446,7 @@ std::string_view GetModuleMemberTypeName(const ModuleMember& module_member) { [](Impl*) { return "impl"; }, [](ConstantDef*) { return "constant-definition"; }, [](EnumDef*) { return "enum-definition"; }, + [](SumDef*) { return "sum-definition"; }, [](Import*) { return "import"; }, [](Use*) { return "use"; }, [](Trait*) { return "trait"; }, diff --git a/xls/dslx/frontend/module.h b/xls/dslx/frontend/module.h index f1c33787f9..939c3ffaae 100644 --- a/xls/dslx/frontend/module.h +++ b/xls/dslx/frontend/module.h @@ -49,7 +49,7 @@ namespace xls::dslx { using ModuleMember = std::variant; // Returns all the NameDefs defined by the given module member. @@ -273,6 +273,7 @@ class Module : public AstNode { const ProcDef* FindProcDef(const Span& span) const; const EnumDef* FindEnumDef(const Span& span) const; + const SumDef* FindSumDef(const Span& span) const; // Obtains all the type definition nodes in the module; e.g. TypeAlias, // StructDef, EnumDef. @@ -304,6 +305,7 @@ class Module : public AstNode { std::vector GetStructDefs() const { return GetTopWithT(); } + std::vector GetSumDefs() const { return GetTopWithT(); } std::vector GetProcDefs() { return GetTopWithT(); } std::vector GetProcs() const { return GetTopWithT(); } diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index 005b31bd72..0a5de54768 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -169,6 +169,7 @@ absl::StatusOr BoundNodeToTypeDefinition(BoundNode bn) { if (auto* e = TryGet(bn)) { return TypeDefinition(e); } if (auto* e = TryGet(bn)) { return TypeDefinition(e); } if (auto* e = TryGet(bn)) { return TypeDefinition(e); } + if (auto* e = TryGet(bn)) { return TypeDefinition(e); } if (auto* e = TryGet(bn)) { return TypeDefinition(e); } // clang-format on @@ -176,6 +177,19 @@ absl::StatusOr BoundNodeToTypeDefinition(BoundNode bn) { ToAstNode(bn)->ToString()); } +bool IsLocalSumConstructorRef(const Expr* expr, const Bindings& bindings) { + auto* colon_ref = dynamic_cast(expr); + if (colon_ref == nullptr) { + return false; + } + if (!std::holds_alternative(colon_ref->subject())) { + return false; + } + NameRef* subject = std::get(colon_ref->subject()); + std::optional def = bindings.ResolveNode(subject->identifier()); + return def.has_value() && std::holds_alternative(*def); +} + bool IsComparisonBinopKind(const Expr* e) { auto* binop = dynamic_cast(e); if (binop == nullptr) { @@ -534,14 +548,14 @@ absl::StatusOr> Parser::ParseModule( absl::StrFormat("Expected start of top-level construct; got: '%s'", peek->ToString())); }; - if (peek->kind() != TokenKind::kKeyword) { - return top_level_error(); - } - XLS_RET_CHECK(bindings != nullptr); XLS_RET_CHECK(module_member_start_pos.has_value()); XLS_RETURN_IF_ERROR(CheckForDuplicateAttributes(pending_attributes)); + if (peek->kind() != TokenKind::kKeyword) { + return top_level_error(); + } + switch (peek->GetKeyword()) { case Keyword::kFn: { XLS_ASSIGN_OR_RETURN(Function * fn, @@ -622,11 +636,17 @@ absl::StatusOr> Parser::ParseModule( } case Keyword::kEnum: { XLS_ASSIGN_OR_RETURN( - EnumDef * enum_def, + auto enum_or_sum, ParseEnumDef(*module_member_start_pos, is_public, *bindings)); - XLS_RETURN_IF_ERROR( - ApplyTypeAttributes(enum_def, pending_attributes, *bindings)); - XLS_RETURN_IF_ERROR(module_->AddTop(enum_def, make_collision_error)); + XLS_RETURN_IF_ERROR(absl::visit( + [&](auto* node) -> absl::Status { + return ApplyTypeAttributes(node, pending_attributes, *bindings); + }, + enum_or_sum)); + XLS_RETURN_IF_ERROR(module_->AddTop( + absl::visit([](auto* node) -> ModuleMember { return node; }, + enum_or_sum), + make_collision_error)); break; } case Keyword::kConst: { @@ -1237,6 +1257,11 @@ absl::Status Parser::ApplyTypeAttributes(T* node, return absl::OkStatus(); } +template absl::Status Parser::ApplyTypeAttributes( + EnumDef* node, std::vector attributes, Bindings& bindings); +template absl::Status Parser::ApplyTypeAttributes( + SumDef* node, std::vector attributes, Bindings& bindings); + absl::StatusOr Parser::ApplyProcAttributes( Proc* p, std::vector attributes) { bool is_test = false; @@ -1501,7 +1526,7 @@ absl::StatusOr Parser::ParseTypeRef(Bindings& bindings, /*internal=*/false); } } - if (!IsOneOf( + if (!IsOneOf( ToAstNode(type_def))) { return ParseErrorStatus( tok.span(), @@ -1709,7 +1734,7 @@ absl::StatusOr Parser::ParseColonRef(Bindings& bindings, } absl::StatusOr Parser::ParseCastOrEnumRefOrStructInstanceOrToken( - Bindings& bindings) { + Bindings& bindings, ExprRestrictions restrictions) { VLOG(5) << "ParseCastOrEnumRefOrStructInstanceOrToken @ " << GetPos() << " peek: " << PeekToken().value()->ToString(); @@ -1718,11 +1743,20 @@ absl::StatusOr Parser::ParseCastOrEnumRefOrStructInstanceOrToken( PeekTokenIs(TokenKind::kDoubleColon)); if (peek_is_double_colon) { XLS_ASSIGN_OR_RETURN(NameRef * subject, ParseNameRef(bindings, &tok)); - // TODO: google/xls#1030 - This assumes it is calling an imported function, - // (e.g., `lib::function();`) not instantiating an imported struct (e.g., - // `lib::my_struct{};`). Instantiating both locally-defined and imported - // structs really should be unified in this method. - return ParseColonRef(bindings, subject, subject->span()); + XLS_ASSIGN_OR_RETURN(ColonRef * colon_ref, + ParseColonRef(bindings, subject, subject->span())); + XLS_ASSIGN_OR_RETURN(bool peek_is_obrace, PeekTokenIs(TokenKind::kOBrace)); + if (peek_is_obrace && + !IsExprRestrictionEnabled(restrictions, + ExprRestriction::kNoStructLiteral)) { + auto* type_ref = module_->Make(colon_ref->span(), colon_ref); + XLS_ASSIGN_OR_RETURN( + TypeAnnotation * type, + MakeTypeRefTypeAnnotation(colon_ref->span(), type_ref, + /*dims=*/{}, /*parametrics=*/{})); + return ParseStructInstance(bindings, type); + } + return colon_ref; } XLS_ASSIGN_OR_RETURN(TypeAnnotation * type, @@ -1739,7 +1773,26 @@ absl::StatusOr Parser::ParseCastOrEnumRefOrStructInstanceOrToken( } XLS_ASSIGN_OR_RETURN(bool peek_is_obrace, PeekTokenIs(TokenKind::kOBrace)); if (peek_is_obrace) { - return ParseStructInstance(bindings, type); + if (!IsExprRestrictionEnabled(restrictions, + ExprRestriction::kNoStructLiteral)) { + return ParseStructInstance(bindings, type); + } + + Transaction parse_struct_txn(this, &bindings); + absl::Cleanup rollback_by_default([&parse_struct_txn] { + if (!parse_struct_txn.completed()) { + parse_struct_txn.Rollback(); + } + }); + absl::StatusOr struct_instance = + ParseStructInstance(*parse_struct_txn.bindings(), type); + if (struct_instance.ok()) { + XLS_ASSIGN_OR_RETURN(bool peek_is_dot, PeekTokenIs(TokenKind::kDot)); + if (peek_is_dot) { + parse_struct_txn.Commit(); + return *struct_instance; + } + } } XLS_ASSIGN_OR_RETURN(bool peek_is_oparen, PeekTokenIs(TokenKind::kOParen)); if (peek_is_oparen) { @@ -2181,12 +2234,50 @@ absl::StatusOr Parser::ParsePattern(Bindings& bindings, // TODO: https://github.com/google/xls/issues/1459 - Handle rest-of-tuple. XLS_ASSIGN_OR_RETURN(bool peek_is_double_colon, PeekTokenIs(TokenKind::kDoubleColon)); - if (peek_is_double_colon) { // Mod or enum ref. + auto make_constructor_pattern = + [&](ColonRef* constructor_ref) -> absl::StatusOr { + XLS_ASSIGN_OR_RETURN(bool peek_is_oparen, + PeekTokenIs(TokenKind::kOParen)); + if (peek_is_oparen) { + XLS_ASSIGN_OR_RETURN( + SumVariantPayloadPattern * pattern, + ParseTupleSumVariantPayloadPattern(bindings, constructor_ref)); + return module_->Make(pattern->span(), pattern); + } + XLS_ASSIGN_OR_RETURN(bool peek_is_obrace, + PeekTokenIs(TokenKind::kOBrace)); + if (peek_is_obrace) { + XLS_ASSIGN_OR_RETURN( + SumVariantPayloadPattern * pattern, + ParseStructSumVariantPayloadPattern(bindings, constructor_ref)); + return module_->Make(pattern->span(), pattern); + } + return module_->Make(constructor_ref->span(), + constructor_ref); + }; + if (bindings.ResolveNodeIsTypeDefinition(*tok.GetValue()) && + !peek_is_double_colon) { + XLS_ASSIGN_OR_RETURN(TypeAnnotation * type, + ParseTypeAnnotation(bindings, tok)); + auto* type_ref = dynamic_cast(type); + XLS_ASSIGN_OR_RETURN(bool type_is_constructor_ref, + PeekTokenIs(TokenKind::kDoubleColon)); + if (type_ref != nullptr && type_is_constructor_ref) { + XLS_ASSIGN_OR_RETURN( + ColonRef * colon_ref, + ParseColonRef(bindings, type_ref, type_ref->span())); + return make_constructor_pattern(colon_ref); + } + return ParseErrorStatus( + type->span(), + absl::StrFormat("Expected constructor pattern after sum type; got %s", + type->ToString())); + } + if (peek_is_double_colon) { XLS_ASSIGN_OR_RETURN(NameRef * subject, ParseNameRef(bindings, &tok)); XLS_ASSIGN_OR_RETURN(ColonRef * colon_ref, ParseColonRef(bindings, subject, subject->span())); - Span span(tok.span().start(), colon_ref->span().limit()); - return module_->Make(span, colon_ref); + return make_constructor_pattern(colon_ref); } std::string identifier = tok.GetValue().value(); @@ -2266,32 +2357,52 @@ absl::StatusOr Parser::ParseMatch(Bindings& bindings, bool is_const) { "an additional match case.")); break; } + auto reject_pattern_bindings = + [this](const NameDefTree& pattern, + const Bindings& pattern_bindings) -> absl::Status { + std::vector locals = + MapKeysSorted(pattern_bindings.local_bindings()); + return ParseErrorStatus( + pattern.span(), + absl::StrFormat("Cannot bind names in a match arm with multiple " + "patterns; bound: %s", + absl::StrJoin(locals, ", "))); + }; + Bindings arm_bindings(&bindings); + Bindings first_pattern_bindings(&arm_bindings); XLS_ASSIGN_OR_RETURN( NameDefTree * first_pattern, - ParsePattern(arm_bindings, /*within_tuple_pattern=*/false)); + ParsePattern(first_pattern_bindings, /*within_tuple_pattern=*/false)); std::vector patterns = {first_pattern}; + bool has_multiple_patterns = false; while (true) { XLS_ASSIGN_OR_RETURN(bool dropped_bar, TryDropToken(TokenKind::kBar)); if (!dropped_bar) { break; } - if (arm_bindings.HasLocalBindings()) { - // TODO(leary): 2020-09-12 Loosen this restriction? They just have to - // bind the same exact set of names. - std::vector locals = - MapKeysSorted(arm_bindings.local_bindings()); - return ParseErrorStatus( - first_pattern->span(), - absl::StrFormat("Cannot have multiple patterns that bind names; " - "previously bound: %s", - absl::StrJoin(locals, ", "))); + has_multiple_patterns = true; + if (first_pattern_bindings.HasLocalBindings()) { + // TODO: Lift this Phase 1 restriction once OR-pattern arms have shared + // binding semantics. A multi-pattern arm should be allowed when every + // alternative binds the exact same set of names and those names have + // the same inferred types in every alternative. Until then, reject + // bindings in multi-pattern arms so an arm body cannot reference a name + // that exists only for some alternatives. + return reject_pattern_bindings(*first_pattern, first_pattern_bindings); } + Bindings pattern_bindings(&arm_bindings); XLS_ASSIGN_OR_RETURN( NameDefTree * pattern, - ParsePattern(arm_bindings, /*within_tuple_pattern=*/false)); + ParsePattern(pattern_bindings, /*within_tuple_pattern=*/false)); + if (pattern_bindings.HasLocalBindings()) { + return reject_pattern_bindings(*pattern, pattern_bindings); + } patterns.push_back(pattern); } + if (!has_multiple_patterns) { + arm_bindings.ConsumeChild(&first_pattern_bindings); + } XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kFatArrow)); XLS_ASSIGN_OR_RETURN(Expr * rhs, ParseExpression(arm_bindings)); @@ -2583,8 +2694,8 @@ absl::StatusOr Parser::ParseTermLhs(Bindings& outer_bindings, VLOG(5) << "ParseTerm, kind is identifier AND it it a known type"; // The "local" struct case goes into here because it recognized the // my_struct as a type - XLS_ASSIGN_OR_RETURN( - lhs, ParseCastOrEnumRefOrStructInstanceOrToken(outer_bindings)); + XLS_ASSIGN_OR_RETURN(lhs, ParseCastOrEnumRefOrStructInstanceOrToken( + outer_bindings, restrictions)); } else if (peek->kind() == TokenKind::kIdentifier || peek_is_kw_in || peek_is_kw_out || peek_is_kw_self) { VLOG(5) << "ParseTerm, kind is identifier but not a known type"; @@ -2677,8 +2788,8 @@ absl::StatusOr Parser::ParseTermLhs(Bindings& outer_bindings, lhs = ToExprNode(name_or_colon_ref); } } else if (peek->kind() == TokenKind::kOParen) { - XLS_ASSIGN_OR_RETURN( - lhs, ParseParentheticalOrCastLhs(outer_bindings, start_pos)); + XLS_ASSIGN_OR_RETURN(lhs, ParseParentheticalOrCastLhs( + outer_bindings, start_pos, restrictions)); } else if (peek->IsKeyword(Keyword::kMatch)) { // Match expression. XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings, /*is_const=*/false)); } else if (peek->kind() == TokenKind::kOBrack) { // Array expression. @@ -2715,7 +2826,8 @@ absl::StatusOr Parser::ParseTermLhs(Bindings& outer_bindings, } absl::StatusOr Parser::ParseParentheticalOrCastLhs( - Bindings& outer_bindings, const Pos& start_pos) { + Bindings& outer_bindings, const Pos& start_pos, + ExprRestrictions restrictions) { // This is potentially a cast to tuple type, e.g. `(u8, u16):(u8:5, u16:1)`. // Try to parse as a parenthesized expression first and fall back to tuple // cast. In the case that both fail, return the status for the parenthetical @@ -2732,7 +2844,7 @@ absl::StatusOr Parser::ParseParentheticalOrCastLhs( } parse_txn.Rollback(); absl::StatusOr cast = - ParseCastOrEnumRefOrStructInstanceOrToken(outer_bindings); + ParseCastOrEnumRefOrStructInstanceOrToken(outer_bindings, restrictions); if (cast.ok()) { return cast; } @@ -2878,6 +2990,13 @@ absl::StatusOr Parser::ParseTermRhs(Expr* lhs, Bindings& outer_bindings, goto done; } + if (IsLocalSumConstructorRef(lhs, outer_bindings)) { + return ParseErrorStatus( + Span(new_pos, GetPos()), + "Explicit parametrics belong on the sum type, not the " + "constructor; use `Name::Variant(...)`."); + } + XLS_ASSIGN_OR_RETURN(bool has_open_paren, PeekTokenIs(TokenKind::kOParen)); // Disallow explicit parametrics (even empty) for AST-node builtins that @@ -4154,39 +4273,243 @@ absl::StatusOr Parser::ParseFor(Bindings& bindings) { } } -absl::StatusOr Parser::ParseEnumDef(const Pos& start_pos, - bool is_public, - Bindings& bindings) { - XLS_ASSIGN_OR_RETURN(Token enum_tok, PopKeywordOrError(Keyword::kEnum)); +absl::StatusOr> Parser::ParseEnumDef( + const Pos& start_pos, bool is_public, Bindings& bindings) { + XLS_RETURN_IF_ERROR(PopKeywordOrError(Keyword::kEnum).status()); // We don't bind the enum's name until the end to prevent recursive references // to itself in its body. XLS_ASSIGN_OR_RETURN(NameDef * name_def, ParseNameDefNoBind()); + Bindings enum_bindings(&bindings); + + XLS_ASSIGN_OR_RETURN(bool dropped_oangle, TryDropToken(TokenKind::kOAngle)); + std::vector parametric_bindings; + if (dropped_oangle) { + XLS_ASSIGN_OR_RETURN(parametric_bindings, + ParseParametricBindings(enum_bindings)); + } + XLS_ASSIGN_OR_RETURN(bool saw_colon, TryDropToken(TokenKind::kColon)); TypeAnnotation* type_annotation = nullptr; if (saw_colon) { - XLS_ASSIGN_OR_RETURN(type_annotation, ParseTypeAnnotation(bindings)); + XLS_ASSIGN_OR_RETURN(type_annotation, ParseTypeAnnotation(enum_bindings)); } XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOBrace)); - Bindings enum_bindings(&bindings); - auto parse_enum_entry = [this, - &enum_bindings]() -> absl::StatusOr { - XLS_ASSIGN_OR_RETURN(NameDef * name_def, ParseNameDef(enum_bindings)); - XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kEquals)); - XLS_ASSIGN_OR_RETURN(Expr * expr, ParseExpression(enum_bindings)); - return EnumMember{.name_def = name_def, .value = expr}; + struct ParsedEntry { + Span span; + NameDef* name_def; + SumVariant::PayloadShape payload_shape; + std::vector tuple_members; + std::vector struct_members; + Expr* discriminant; + std::optional payload_span; + std::optional discriminant_equals_span; + }; + + auto parse_struct_member = + [this, &enum_bindings]() -> absl::StatusOr { + Pos node_start_pos = GetPos(); + XLS_ASSIGN_OR_RETURN(NameDef * member_name, ParseNameDefNoBind()); + XLS_ASSIGN_OR_RETURN( + Token colon, + PopTokenOrError( + TokenKind::kColon, /*start=*/nullptr, + "Expect type annotation on semantic-sum variant field")); + XLS_ASSIGN_OR_RETURN(TypeAnnotation * type, + ParseTypeAnnotation(enum_bindings)); + return module_->Make(Span(node_start_pos, GetPos()), + member_name, colon.span(), type); + }; + + auto parse_entry = [this, &enum_bindings, &parse_struct_member, + &name_def]() -> absl::StatusOr { + Pos variant_start = GetPos(); + XLS_ASSIGN_OR_RETURN(NameDef * variant_name, ParseNameDefNoBind()); + + XLS_ASSIGN_OR_RETURN(std::optional oparen, + TryPopToken(TokenKind::kOParen)); + if (oparen.has_value()) { + auto parse_payload = + [this, &enum_bindings]() -> absl::StatusOr { + return ParseTypeAnnotation(enum_bindings); + }; + Pos payload_limit; + XLS_ASSIGN_OR_RETURN( + std::vector tuple_members, + ParseCommaSeq(parse_payload, TokenKind::kCParen, + &payload_limit)); + const Span payload_span(oparen->span().start(), payload_limit); + XLS_ASSIGN_OR_RETURN(std::optional equals, + TryPopToken(TokenKind::kEquals)); + Expr* discriminant = nullptr; + std::optional discriminant_equals_span; + if (equals.has_value()) { + discriminant_equals_span = equals->span(); + XLS_ASSIGN_OR_RETURN(discriminant, ParseExpression(enum_bindings)); + } + return ParsedEntry{ + .span = Span(variant_start, GetPos()), + .name_def = variant_name, + .payload_shape = SumVariant::PayloadShape::kTuple, + .tuple_members = std::move(tuple_members), + .struct_members = {}, + .discriminant = discriminant, + .payload_span = payload_span, + .discriminant_equals_span = discriminant_equals_span, + }; + } + + XLS_ASSIGN_OR_RETURN(std::optional obrace, + TryPopToken(TokenKind::kOBrace)); + if (obrace.has_value()) { + Pos payload_limit; + XLS_ASSIGN_OR_RETURN( + std::vector struct_members, + ParseCommaSeq(parse_struct_member, + TokenKind::kCBrace, &payload_limit)); + absl::flat_hash_set seen_member_names; + for (StructMemberNode* struct_member : struct_members) { + if (!seen_member_names.insert(struct_member->name()).second) { + return ParseErrorStatus( + struct_member->name_def()->span(), + absl::StrFormat("Sum variant `%s` field `%s` is defined more " + "than once in sum `%s`", + variant_name->identifier(), struct_member->name(), + name_def->identifier())); + } + } + const Span payload_span(obrace->span().start(), payload_limit); + XLS_ASSIGN_OR_RETURN(std::optional equals, + TryPopToken(TokenKind::kEquals)); + Expr* discriminant = nullptr; + std::optional discriminant_equals_span; + if (equals.has_value()) { + discriminant_equals_span = equals->span(); + XLS_ASSIGN_OR_RETURN(discriminant, ParseExpression(enum_bindings)); + } + return ParsedEntry{ + .span = Span(variant_start, GetPos()), + .name_def = variant_name, + .payload_shape = SumVariant::PayloadShape::kStruct, + .tuple_members = {}, + .struct_members = std::move(struct_members), + .discriminant = discriminant, + .payload_span = payload_span, + .discriminant_equals_span = discriminant_equals_span, + }; + } + + XLS_ASSIGN_OR_RETURN(std::optional equals, + TryPopToken(TokenKind::kEquals)); + Expr* discriminant = nullptr; + std::optional discriminant_equals_span; + if (equals.has_value()) { + discriminant_equals_span = equals->span(); + XLS_ASSIGN_OR_RETURN(discriminant, ParseExpression(enum_bindings)); + } + return ParsedEntry{ + .span = Span(variant_start, GetPos()), + .name_def = variant_name, + .payload_shape = SumVariant::PayloadShape::kUnit, + .tuple_members = {}, + .struct_members = {}, + .discriminant = discriminant, + .payload_span = std::nullopt, + .discriminant_equals_span = discriminant_equals_span, + }; }; XLS_ASSIGN_OR_RETURN( - std::vector entries, - ParseCommaSeq(parse_enum_entry, TokenKind::kCBrace)); - auto* enum_def = module_->Make(Span(start_pos, GetPos()), name_def, - type_annotation, entries, is_public); - bindings.Add(name_def->identifier(), enum_def); - name_def->set_definer(enum_def); - return enum_def; + std::vector entries, + ParseCommaSeq(parse_entry, TokenKind::kCBrace)); + + const bool has_payload_syntax = + absl::c_any_of(entries, [](const ParsedEntry& entry) { + return entry.payload_shape != SumVariant::PayloadShape::kUnit; + }); + const bool is_semantic_sum = + has_payload_syntax || (entries.empty() && type_annotation == nullptr); + const bool has_discriminants = absl::c_any_of( + entries, + [](const ParsedEntry& entry) { return entry.discriminant != nullptr; }); + const bool all_have_discriminants = absl::c_all_of( + entries, + [](const ParsedEntry& entry) { return entry.discriminant != nullptr; }); + + if (!is_semantic_sum) { + if (!parametric_bindings.empty()) { + return ParseErrorStatus( + name_def->span(), + "Parametric bindings are only supported on semantic sums."); + } + if (!all_have_discriminants) { + return ParseErrorStatus( + name_def->span(), + absl::StrFormat( + "Numeric enum `%s` requires a value for every member.", + name_def->identifier())); + } + std::vector enum_members; + enum_members.reserve(entries.size()); + for (const ParsedEntry& entry : entries) { + enum_members.push_back(EnumMember{ + .name_def = entry.name_def, + .value = entry.discriminant, + }); + } + auto* enum_def = module_->Make(Span(start_pos, GetPos()), name_def, + type_annotation, + std::move(enum_members), is_public); + bindings.Add(name_def->identifier(), enum_def); + name_def->set_definer(enum_def); + return enum_def; + } + + if (!entries.empty() && has_discriminants != all_have_discriminants) { + return ParseErrorStatus( + name_def->span(), + absl::StrFormat( + "Semantic sum `%s` must use either all implicit or all explicit " + "discriminants.", + name_def->identifier())); + } else if (type_annotation != nullptr && !entries.empty() && + !all_have_discriminants) { + return ParseErrorStatus( + name_def->span(), + absl::StrFormat( + "Semantic sum `%s` with a tag type annotation requires explicit " + "discriminants on every variant.", + name_def->identifier())); + } + + std::vector variants; + variants.reserve(entries.size()); + for (ParsedEntry& entry : entries) { + variants.push_back(module_->Make( + entry.span, entry.name_def, entry.payload_shape, + std::move(entry.tuple_members), std::move(entry.struct_members), + entry.discriminant, entry.payload_span, + entry.discriminant_equals_span)); + } + absl::flat_hash_set seen_variant_names; + for (SumVariant* variant : variants) { + if (!seen_variant_names.insert(variant->identifier()).second) { + return ParseErrorStatus( + variant->name_def()->span(), + absl::StrFormat( + "Sum variant `%s` is defined more than once in sum `%s`", + variant->identifier(), name_def->identifier())); + } + } + + auto* sum_def = module_->Make( + Span(start_pos, GetPos()), name_def, std::move(parametric_bindings), + std::move(variants), is_public, type_annotation); + bindings.Add(name_def->identifier(), sum_def); + name_def->set_definer(sum_def); + return sum_def; } absl::StatusOr Parser::MakeBuiltinTypeAnnotation( @@ -4531,6 +4854,100 @@ absl::StatusOr Parser::ParseTuplePattern(const Pos& start_pos, return module_->Make(span, std::move(members)); } +absl::StatusOr +Parser::ParseTupleSumVariantPayloadPattern(Bindings& bindings, + ColonRef* constructor_ref) { + XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOParen)); + + std::vector tuple_payload_patterns; + bool must_end = false; + while (true) { + XLS_ASSIGN_OR_RETURN(bool dropped_cparen, TryDropToken(TokenKind::kCParen)); + if (dropped_cparen) { + break; + } + if (must_end) { + XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kCParen)); + break; + } + XLS_ASSIGN_OR_RETURN(NameDefTree * pattern, + ParsePattern(bindings, /*within_tuple_pattern=*/true)); + tuple_payload_patterns.push_back(pattern); + XLS_ASSIGN_OR_RETURN(bool dropped_comma, TryDropToken(TokenKind::kComma)); + must_end = !dropped_comma; + } + + return module_->Make( + Span(constructor_ref->span().start(), GetPos()), constructor_ref, + SumVariantPayloadPattern::PayloadShape::kTuple, + std::move(tuple_payload_patterns), + std::vector{}); +} + +absl::StatusOr +Parser::ParseStructSumVariantPayloadPattern(Bindings& bindings, + ColonRef* constructor_ref) { + XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOBrace)); + + auto make_identifier_pattern = + [&](const Token& tok) -> absl::StatusOr { + std::string identifier = tok.GetValue().value(); + if (std::optional resolved = bindings.ResolveNode(identifier); + resolved.has_value()) { + AnyNameDef any_name_def = + bindings.ResolveNameOrNullopt(identifier).value(); + NameRef* ref = + module_->Make(tok.span(), identifier, any_name_def); + return module_->Make(tok.span(), ref); + } + + XLS_ASSIGN_OR_RETURN(NameDef * name_def, TokenToNameDef(tok)); + bindings.Add(name_def->identifier(), name_def); + auto* result = module_->Make(tok.span(), name_def); + name_def->set_definer(result); + return result; + }; + + std::vector + struct_payload_field_patterns; + bool must_end = false; + while (true) { + XLS_ASSIGN_OR_RETURN(bool dropped_cbrace, TryDropToken(TokenKind::kCBrace)); + if (dropped_cbrace) { + break; + } + if (must_end) { + XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kCBrace)); + break; + } + + XLS_ASSIGN_OR_RETURN( + Token member_tok, + PopTokenOrError(TokenKind::kIdentifier, /*start=*/nullptr, + "Expected constructor-pattern field name")); + XLS_ASSIGN_OR_RETURN(bool dropped_colon, TryDropToken(TokenKind::kColon)); + + NameDefTree* member_pattern; + if (dropped_colon) { + XLS_ASSIGN_OR_RETURN( + member_pattern, + ParsePattern(bindings, /*within_tuple_pattern=*/false)); + } else { + XLS_ASSIGN_OR_RETURN(member_pattern, make_identifier_pattern(member_tok)); + } + struct_payload_field_patterns.push_back( + std::make_pair(*member_tok.GetValue(), member_pattern)); + + XLS_ASSIGN_OR_RETURN(bool dropped_comma, TryDropToken(TokenKind::kComma)); + must_end = !dropped_comma; + } + + return module_->Make( + Span(constructor_ref->span().start(), GetPos()), constructor_ref, + SumVariantPayloadPattern::PayloadShape::kStruct, + std::vector{}, std::move(struct_payload_field_patterns)); +} + absl::StatusOr Parser::ParseBlockExpression(Bindings& bindings, bool has_braces) { Bindings block_bindings(&bindings); @@ -4680,7 +5097,8 @@ absl::StatusOr Parser::ParseParametricArg(Bindings& bindings) { XLS_ASSIGN_OR_RETURN(auto nocr, ParseNameOrColonRef(bindings)); absl::StatusOr def = bindings.ResolveNodeOrError( ToAstNode(nocr)->ToString(), peek->span(), file_table()); - if (def.ok() && IsOneOf(ToAstNode(*def))) { + if (def.ok() && + IsOneOf(ToAstNode(*def))) { XLS_ASSIGN_OR_RETURN(TypeDefinition type_definition, BoundNodeToTypeDefinition(*def)); XLS_ASSIGN_OR_RETURN( diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index ad33d43739..66630d678c 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -328,7 +328,7 @@ class Parser : public TokenParser { const Span& subject_span); absl::StatusOr ParseCastOrEnumRefOrStructInstanceOrToken( - Bindings& bindings); + Bindings& bindings, ExprRestrictions restrictions); absl::StatusOr ParseStructInstance(Bindings& bindings, TypeAnnotation* type = nullptr); @@ -414,8 +414,9 @@ class Parser : public TokenParser { // Attempts to parse a parenthetical and falls back to parsing as cast on // failure. - absl::StatusOr ParseParentheticalOrCastLhs(Bindings& outer_bindings, - const Pos& start_pos); + absl::StatusOr ParseParentheticalOrCastLhs( + Bindings& outer_bindings, const Pos& start_pos, + ExprRestrictions restrictions); absl::StatusOr ParseTermLhsParenthesized(Bindings& bindings, const Pos& start_pos); @@ -544,6 +545,10 @@ class Parser : public TokenParser { absl::StatusOr ParseTuplePattern(const Pos& start_pos, Bindings& bindings); + absl::StatusOr ParseTupleSumVariantPayloadPattern( + Bindings& bindings, ColonRef* constructor_ref); + absl::StatusOr ParseStructSumVariantPayloadPattern( + Bindings& bindings, ColonRef* constructor_ref); // Returns a parsed pattern; e.g. one that would guard a match arm. // @@ -574,14 +579,20 @@ class Parser : public TokenParser { // ultimately the loop terminates and the final accum value is returned. absl::StatusOr ParseFor(Bindings& bindings); - // Parses an enum definition; e.g. + // Parses either a numeric enum definition or a semantic sum declaration using + // the shared `enum` introducer; e.g. // // enum Foo : u2 { // A = 0, // B = 1, // } - absl::StatusOr ParseEnumDef(const Pos& start_pos, bool is_public, - Bindings& bindings); + // + // enum Maybe { + // None, + // Some(u32), + // } + absl::StatusOr> ParseEnumDef( + const Pos& start_pos, bool is_public, Bindings& bindings); absl::StatusOr ParseStruct(const Pos& start_pos, bool is_public, Bindings& bindings); diff --git a/xls/dslx/frontend/parser_test.cc b/xls/dslx/frontend/parser_test.cc index 62763ba45b..155bdaee0d 100644 --- a/xls/dslx/frontend/parser_test.cc +++ b/xls/dslx/frontend/parser_test.cc @@ -22,9 +22,6 @@ #include #include -#include "gmock/gmock.h" -#include "gtest/gtest-spi.h" -#include "gtest/gtest.h" #include "absl/base/casts.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" @@ -32,6 +29,9 @@ #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "gmock/gmock.h" +#include "gtest/gtest-spi.h" +#include "gtest/gtest.h" #include "xls/common/attribute_data.h" #include "xls/common/file/filesystem.h" #include "xls/common/file/get_runfile_path.h" @@ -683,6 +683,13 @@ impl Foo { } fn foo() -> Foo { Foo { } +})", + /*target=*/R"(proc Foo { +} +impl Foo { +} +fn foo() -> Foo { + Foo { } })"); } @@ -2568,6 +2575,273 @@ fn f(x: u32) { "`..` patterns are not allowed outside of a tuple pattern"); } +TEST_F(ParserTest, MatchWithSemanticSumConstructors) { + std::unique_ptr module = RoundTrip(R"(#![feature(generics)] + +enum Option { + None, + Some(T), + Point { x: u32, y: u32 }, +} +fn f(x: Option) -> Option { + match x { + Option::Some(v) => Option::Some(v), + Option::Point { x: px, y: py } => Option::Point { x: px, y: py }, + Option::None => Option::None, + } +})"); + std::optional maybe_member = + module->FindMemberWithName("Option"); + ASSERT_TRUE(maybe_member.has_value()); + EXPECT_TRUE(std::holds_alternative(*maybe_member.value())); +} + +TEST_F(ParserTest, SemanticSumSupportsExplicitDiscriminants) { + std::unique_ptr module = RoundTrip(R"(enum Message : u3 { + Idle() = 0, + Request(u8) = 3, + Response { field: u8 } = 7, +})"); + std::optional maybe_member = + module->FindMemberWithName("Message"); + ASSERT_TRUE(maybe_member.has_value()); + ASSERT_TRUE(std::holds_alternative(*maybe_member.value())); + const auto* sum_def = std::get(*maybe_member.value()); + ASSERT_TRUE(sum_def->GetVariant("Idle").has_value()); + const SumVariant* idle = *sum_def->GetVariant("Idle"); + const SumVariant* request = *sum_def->GetVariant("Request"); + EXPECT_NE(idle->discriminant(), nullptr); + EXPECT_TRUE(idle->payload_span().has_value()); + EXPECT_TRUE(idle->discriminant_equals_span().has_value()); + EXPECT_TRUE(request->payload_span().has_value()); + EXPECT_TRUE(request->discriminant_equals_span().has_value()); + EXPECT_EQ(sum_def->tag_type_annotation()->ToString(), "u3"); +} + +TEST_F(ParserTest, SemanticSumRejectsMixedDiscriminants) { + constexpr std::string_view kProgram = R"(enum Message { + Idle(), + Request(u8) = 3, +})"; + EXPECT_THAT(Parse(kProgram), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must use either all implicit or all explicit " + "discriminants"))); +} + +TEST_F(ParserTest, SemanticSumRejectsImplicitDiscriminantsWithTagAnnotation) { + constexpr std::string_view kProgram = R"(enum Option : u1 { + None(), + Some(u8), +})"; + EXPECT_THAT(Parse(kProgram), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("with a tag type annotation requires explicit " + "discriminants on every variant"))); +} + +TEST_F(ParserTest, BareUnitEnumWithoutValuesIsStillRejected) { + EXPECT_THAT(Parse("enum E { A, B }"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("requires a value for every member"))); +} + +TEST_F(ParserTest, MatchOrPatternRejectsBindingInLaterAlternative) { + constexpr std::string_view kProgram = R"(enum Option { + None, + Some(u32), +} + +fn f(x: Option) -> u32 { + match x { + Option::None | Option::Some(v) => v, + _ => u32:0, + } +})"; + EXPECT_THAT(Parse(kProgram), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::AllOf(HasSubstr("Cannot bind names in a match " + "arm with multiple patterns"), + HasSubstr("bound: v")))); +} + +TEST_F(ParserTest, MatchOrPatternRejectsBindingInFirstAlternative) { + constexpr std::string_view kProgram = R"(enum Option { + None, + Some(u32), +} + +fn f(x: Option) -> u32 { + match x { + Option::Some(v) | Option::None => v, + _ => u32:0, + } +})"; + EXPECT_THAT(Parse(kProgram), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::AllOf(HasSubstr("Cannot bind names in a match " + "arm with multiple patterns"), + HasSubstr("bound: v")))); +} + +TEST_F(ParserTest, MatchOrPatternAllowsBindingFreeAlternatives) { + constexpr std::string_view kProgram = R"(enum Option { + None, + Some(u32), +} + +fn f(x: Option) -> u32 { + match x { + Option::None | Option::Some(_) => u32:1, + } +})"; + XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Parse(kProgram)); + EXPECT_TRUE(module->GetFunction("f").has_value()); +} + +TEST_F(ParserTest, SumRemainsValidIdentifierOutsideTypeDeclarations) { + std::unique_ptr module = RoundTrip(R"(fn f(x: u32) -> u32 { + let sum = x + u32:1; + sum +})"); + EXPECT_TRUE(module->GetFunction("f").has_value()); +} + +TEST_F(ParserTest, EmptyEnumWithTagAnnotationIsNumericEnum) { + std::unique_ptr module = RoundTrip(R"(enum Never : u3 { +})"); + std::optional maybe_member = + module->FindMemberWithName("Never"); + ASSERT_TRUE(maybe_member.has_value()); + EXPECT_TRUE(std::holds_alternative(*maybe_member.value())); +} + +TEST_F(ParserTest, EmptyEnumWithoutTagAnnotationIsSemanticSum) { + std::unique_ptr module = RoundTrip(R"(enum Never { +})"); + std::optional maybe_member = + module->FindMemberWithName("Never"); + ASSERT_TRUE(maybe_member.has_value()); + EXPECT_TRUE(std::holds_alternative(*maybe_member.value())); +} + +TEST_F(ParserTest, PreserveEmptySemanticSumPayloadShapes) { + std::unique_ptr module = RoundTrip(R"(enum Option { + None, + EmptyTuple(), + EmptyStruct { }, + Some(u32), + Point { x: u32 }, +} +fn f(x: Option) -> Option { + match x { + Option::EmptyTuple() => Option::EmptyTuple(), + Option::EmptyStruct { } => Option::EmptyStruct { }, + Option::Some(v) => Option::Some(v), + Option::Point { x: px } => Option::Point { x: px }, + Option::None => Option::None, + } +})"); + std::optional maybe_member = + module->FindMemberWithName("Option"); + ASSERT_TRUE(maybe_member.has_value()); + ASSERT_TRUE(std::holds_alternative(*maybe_member.value())); + auto* sum_def = std::get(*maybe_member.value()); + + ASSERT_TRUE(sum_def->GetVariant("None").has_value()); + ASSERT_TRUE(sum_def->GetVariant("EmptyTuple").has_value()); + ASSERT_TRUE(sum_def->GetVariant("EmptyStruct").has_value()); + EXPECT_TRUE((*sum_def->GetVariant("None"))->is_unit()); + EXPECT_TRUE((*sum_def->GetVariant("EmptyTuple"))->is_tuple()); + EXPECT_TRUE((*sum_def->GetVariant("EmptyTuple"))->tuple_members().empty()); + EXPECT_TRUE((*sum_def->GetVariant("EmptyStruct"))->is_struct()); + EXPECT_TRUE((*sum_def->GetVariant("EmptyStruct"))->struct_members().empty()); + + std::optional maybe_f = module->GetFunction("f"); + ASSERT_TRUE(maybe_f.has_value()); + StatementBlock* body = maybe_f.value()->body(); + ASSERT_EQ(body->statements().size(), 1); + auto* match = dynamic_cast( + std::get(body->statements().at(0)->wrapped())); + ASSERT_NE(match, nullptr); + + ASSERT_EQ(match->arms().size(), 5); + + const NameDefTree* tuple_tree = match->arms().at(0)->patterns().at(0); + ASSERT_TRUE(tuple_tree->is_leaf()); + ASSERT_TRUE( + std::holds_alternative(tuple_tree->leaf())); + const auto* tuple_pattern = + std::get(tuple_tree->leaf()); + EXPECT_TRUE(tuple_pattern->is_tuple()); + EXPECT_TRUE(tuple_pattern->tuple_payload_patterns().empty()); + + const NameDefTree* struct_tree = match->arms().at(1)->patterns().at(0); + ASSERT_TRUE(struct_tree->is_leaf()); + ASSERT_TRUE( + std::holds_alternative(struct_tree->leaf())); + const auto* struct_pattern = + std::get(struct_tree->leaf()); + EXPECT_TRUE(struct_pattern->is_struct()); + EXPECT_TRUE(struct_pattern->struct_payload_field_patterns().empty()); +} + +TEST_F(ParserTest, RejectsConstructorLevelParametricsOnSemanticSums) { + constexpr std::string_view kExplicitOnConstructor = R"(#![feature(generics)] + +enum Option { + None, + Some(T), +} + +fn f(x: u32) -> Option { + Option::Some(x) +} +)"; + EXPECT_THAT( + Parse(kExplicitOnConstructor), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("ParseError"))); + + constexpr std::string_view kTurbofishOnConstructor = R"(#![feature(generics)] + +enum Option { + None, + Some(T), +} + +fn f(x: u32) -> Option { + Option::::Some(x) +} +)"; + EXPECT_THAT( + Parse(kTurbofishOnConstructor), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("ParseError"))); +} + +TEST_F(ParserTest, RejectsDuplicateSemanticSumConstructors) { + constexpr std::string_view kProgram = R"(enum Option { + Some, + Some(u32), + })"; + EXPECT_THAT( + Parse(kProgram), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "Sum variant `Some` is defined more than once in sum `Option`"))); +} + +TEST_F(ParserTest, RejectsDuplicateSemanticSumStructPayloadFields) { + constexpr std::string_view kProgram = R"(enum S { + V { x: u32, x: u64 }, + })"; + EXPECT_THAT(Parse(kProgram), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("Sum variant `V`"), + HasSubstr("field `x` is defined more than once"), + HasSubstr("sum `S`")))); +} + TEST_F(ParserTest, ArrayTypeAnnotation) { std::string s = "u8[2]"; scanner_.emplace(file_table_, Fileno(0), s); @@ -3201,6 +3475,23 @@ TEST_F(ParserTest, TernaryWithComparisonToColonRefTest) { RoundTripExpr("if a <= m::b { u32:42 } else { u32:24 }", {"a", "m"}); } +TEST_F(ParserTest, ConditionalWithEnumColonRefTest) { + constexpr std::string_view kProgram = R"(enum Foo : u32 { + THING = 0, + OTHER = 1, +} + +fn f(x: Foo) -> Foo { + if x == Foo::THING { Foo::OTHER } else { Foo::THING } +})"; + std::unique_ptr module = ExpectParsesSuccessfully(kProgram); + ASSERT_TRUE(module != nullptr); + ASSERT_TRUE(module->GetFunction("f").has_value()); + EXPECT_THAT( + (*module->GetFunction("f"))->body()->ToString(), + HasSubstr("if x == Foo::THING { Foo::OTHER } else { Foo::THING }")); +} + TEST_F(ParserTest, ForInWithColonRefAsRangeLimit) { RoundTripExpr(R"(for (x, s) in u32:0..m::SOME_CONST { x @@ -4092,10 +4383,11 @@ TEST(ParserErrorTest, BadTestTarget) { Scanner s{file_table, Fileno(0), std::string(kProgram)}; Parser parser{"test", &s}; absl::StatusOr> module = parser.ParseModule(); - EXPECT_THAT(module.status(), - IsPosError("ParseError", - HasSubstr("Attributes are only supported for a " - "function, proc, struct, enum, or type."))); + EXPECT_THAT( + module.status(), + IsPosError("ParseError", HasSubstr("Attributes are only supported for a " + "function, proc, struct, enum, or " + "type."))); } TEST(ParserErrorTest, BadAttributeTokenType) { diff --git a/xls/dslx/frontend/scanner_test.cc b/xls/dslx/frontend/scanner_test.cc index 103a28e4c4..34e32868af 100644 --- a/xls/dslx/frontend/scanner_test.cc +++ b/xls/dslx/frontend/scanner_test.cc @@ -141,6 +141,12 @@ TEST(ScannerTest, ScanKeyword) { EXPECT_TRUE(tokens[0].IsKeyword(Keyword::kFn)); } +TEST(ScannerTest, ScanSumIdentifier) { + XLS_ASSERT_OK_AND_ASSIGN(std::vector tokens, ToTokens("sum")); + ASSERT_EQ(tokens.size(), 1); + EXPECT_TRUE(tokens[0].IsIdentifier("sum")); +} + TEST(ScannerTest, CarryIsValidIdentifier) { XLS_ASSERT_OK_AND_ASSIGN(std::vector tokens, ToTokens("fn carry() {}")); diff --git a/xls/dslx/frontend/semantics_analysis.h b/xls/dslx/frontend/semantics_analysis.h index b0be8afe11..232f5f99ea 100644 --- a/xls/dslx/frontend/semantics_analysis.h +++ b/xls/dslx/frontend/semantics_analysis.h @@ -46,6 +46,8 @@ class SemanticsAnalysis { void SetNameDefType(const NameDef* def, const Type* type); + bool suppress_warnings() const { return suppress_warnings_; } + private: // Used by kUnusedDefinition. We cannot completely determine whether a // definition is truly unused at RunPreTypeCheckPass, because (1) tokens are diff --git a/xls/dslx/ir_convert/extract_conversion_order.cc b/xls/dslx/ir_convert/extract_conversion_order.cc index 251821b1a6..1a79bbfa72 100644 --- a/xls/dslx/ir_convert/extract_conversion_order.cc +++ b/xls/dslx/ir_convert/extract_conversion_order.cc @@ -346,6 +346,16 @@ class InvocationVisitor : public ExprVisitor { return absl::OkStatus(); } + absl::Status HandleSumInstance(const SumInstance* expr) override { + for (const Expr* arg : expr->tuple_payload_args()) { + XLS_RETURN_IF_ERROR(arg->AcceptExpr(this)); + } + for (const auto& member : expr->struct_payload_field_args()) { + XLS_RETURN_IF_ERROR(member.second->AcceptExpr(this)); + } + return absl::OkStatus(); + } + absl::Status HandleConditional(const Conditional* expr) override { // constexpr if selects only one branch of the if to handle if (expr->IsConst()) { @@ -762,6 +772,7 @@ absl::StatusOr> GetOrder(Module* module, [](ProcAlias*) { return absl::OkStatus(); }, [](Impl*) { return absl::OkStatus(); }, [](EnumDef*) { return absl::OkStatus(); }, + [](SumDef*) { return absl::OkStatus(); }, [](Import*) { return absl::OkStatus(); }, [](Trait*) { return absl::OkStatus(); }, [](Use*) { return absl::OkStatus(); }, diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index 05b4d4b178..1433d69416 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -401,6 +401,7 @@ class FunctionConverterVisitor : public AstNodeVisitor { // from their parent nodes). // keep-sorted start INVALID(Attribute) + INVALID(SumVariantPayloadPattern) INVALID(FunctionRef) INVALID(FuzzTestFunction) INVALID(MatchArm) @@ -452,6 +453,9 @@ class FunctionConverterVisitor : public AstNodeVisitor { INVALID(QuickCheck) INVALID(StructDef) INVALID(StructMemberNode) + INVALID(SumDef) + INVALID(SumInstance) + INVALID(SumVariant) INVALID(Trait) INVALID(TypeAlias) INVALID(Use) @@ -1558,6 +1562,10 @@ absl::StatusOr FunctionConverter::HandleRangedForInductionVariable( return absl::InternalError( "Induction variable cannot be a colon-reference"); }, + [&](SumVariantPayloadPattern*) -> absl::StatusOr { + return absl::InternalError( + "Induction variable cannot be a constructor pattern"); + }, [&](NameRef*) -> absl::StatusOr { return absl::InternalError( "Induction variable cannot be a name-reference"); @@ -1907,6 +1915,10 @@ absl::StatusOr FunctionConverter::HandleMatcher( return function_builder_->Literal(UBits(1, 1), loc); }); }, + [&](SumVariantPayloadPattern*) -> absl::StatusOr { + return absl::UnimplementedError( + "Semantic sum patterns are not supported by IR conversion."); + }, [&](RestOfTuple* n) -> absl::StatusOr { return Def(matcher, [&](const SourceInfo& loc) { return function_builder_->Literal(UBits(1, 1), loc); diff --git a/xls/dslx/lsp/document_symbols.cc b/xls/dslx/lsp/document_symbols.cc index b3ae82e57d..cf08b1c649 100644 --- a/xls/dslx/lsp/document_symbols.cc +++ b/xls/dslx/lsp/document_symbols.cc @@ -61,6 +61,16 @@ std::vector ToDocumentSymbols(const EnumDef& e) { return {std::move(ds)}; } +std::vector ToDocumentSymbols(const SumDef& s) { + verible::lsp::DocumentSymbol ds = { + .name = s.identifier(), + .kind = verible::lsp::SymbolKind::kEnum, + .range = ConvertSpanToLspRange(s.span()), + .selectionRange = ConvertSpanToLspRange(s.name_def()->span()), + }; + return {std::move(ds)}; +} + std::vector ToDocumentSymbols( const ConstantDef& c) { verible::lsp::DocumentSymbol ds = { diff --git a/xls/dslx/tests/errors/error_modules_test.py b/xls/dslx/tests/errors/error_modules_test.py index ff183622d3..a4504c75df 100644 --- a/xls/dslx/tests/errors/error_modules_test.py +++ b/xls/dslx/tests/errors/error_modules_test.py @@ -295,7 +295,9 @@ def test_match_multi_pattern_with_bindings(self): 'xls/dslx/tests/errors/match_multi_pattern_with_bindings.x:17:5-17:7', stderr, ) - self.assertIn('Cannot have multiple patterns that bind names', stderr) + self.assertIn( + 'Cannot bind names in a match arm with multiple patterns', stderr + ) def test_co_recursion(self): stderr = self._run( diff --git a/xls/dslx/translators/dslx_to_verilog.cc b/xls/dslx/translators/dslx_to_verilog.cc index 822de5b472..aa5f8f53f1 100644 --- a/xls/dslx/translators/dslx_to_verilog.cc +++ b/xls/dslx/translators/dslx_to_verilog.cc @@ -113,6 +113,9 @@ std::optional TypeDefinitionIdentifier( [](EnumDef* enum_def) -> std::optional { return enum_def->name_def()->identifier(); }, + [](SumDef* sum_def) -> std::optional { + return sum_def->name_def()->identifier(); + }, }, resolved_type_definition.definition); } @@ -342,6 +345,12 @@ DslxTypeToVerilogManager::TypeDefinitionToVastType( "DslxTypeToVerilogManager", proc_def->ToString())); }, + [&](SumDef* sum_def) -> absl::StatusOr { + return absl::UnimplementedError(absl::StrFormat( + "TypeAnnotation SumDef %s not supported by " + "DslxTypeToVerilogManager", + sum_def->ToString())); + }, [&](UseTreeEntry* use_tree_entry) -> absl::StatusOr { return absl::UnimplementedError(absl::StrFormat( diff --git a/xls/dslx/translators/dslx_to_verilog_main.cc b/xls/dslx/translators/dslx_to_verilog_main.cc index eb25838374..90181aa257 100644 --- a/xls/dslx/translators/dslx_to_verilog_main.cc +++ b/xls/dslx/translators/dslx_to_verilog_main.cc @@ -63,6 +63,7 @@ bool TypeDefinitionSourceIsPublic(const TypeInfo::TypeSource& def_source) { [](const TypeAlias* type_alias) { return type_alias->is_public(); }, [](const ProcDef* proc_def) { return proc_def->is_public(); }, [](const EnumDef* enum_def) { return enum_def->is_public(); }, + [](const SumDef* sum_def) { return sum_def->is_public(); }, [](const StructDef* struct_def) { return struct_def->is_public(); }, }, def_source.definition); diff --git a/xls/dslx/type_system/type_info.cc b/xls/dslx/type_system/type_info.cc index 066e9a8659..dfad8d9f7b 100644 --- a/xls/dslx/type_system/type_info.cc +++ b/xls/dslx/type_system/type_info.cc @@ -423,6 +423,9 @@ absl::StatusOr TypeInfo::ResolveTypeDefinition( [this](EnumDef* sd) -> absl::StatusOr { return TypeInfo::TypeSource{.type_info = this, .definition = sd}; }, + [this](SumDef* sd) -> absl::StatusOr { + return TypeInfo::TypeSource{.type_info = this, .definition = sd}; + }, [this](ColonRef* sd) -> absl::StatusOr { return ResolveTypeDefinition(sd); }, @@ -472,6 +475,7 @@ absl::StatusOr> TypeInfo::FindSvType( [](StructDef* sd) -> Res { return sd->extern_type_name(); }, [](ProcDef* pd) -> Res { return std::nullopt; }, [](EnumDef* sd) -> Res { return sd->extern_type_name(); }, + [](SumDef* sd) -> Res { return sd->extern_type_name(); }, [src](TypeAlias* ta) -> Res { std::optional sv_type = ta->extern_type_name(); if (sv_type) { diff --git a/xls/dslx/type_system/type_info.h b/xls/dslx/type_system/type_info.h index 21242d4c5e..51a80b9b53 100644 --- a/xls/dslx/type_system/type_info.h +++ b/xls/dslx/type_system/type_info.h @@ -357,7 +357,8 @@ class TypeInfo { struct TypeSource { TypeInfo* type_info; - std::variant definition; + std::variant + definition; }; // Get the actual instructions which provided the given type-definition with diff --git a/xls/dslx/type_system/type_info.proto b/xls/dslx/type_system/type_info.proto index c1195e93ce..ae2eaef689 100644 --- a/xls/dslx/type_system/type_info.proto +++ b/xls/dslx/type_system/type_info.proto @@ -98,6 +98,10 @@ enum AstNodeKindProto { AST_NODE_KIND_TRAIT = 73; AST_NODE_KIND_ATTRIBUTE = 74; AST_NODE_KIND_FUZZ_TEST_FUNCTION = 75; + AST_NODE_KIND_SUM_VARIANT_PAYLOAD_PATTERN = 76; + AST_NODE_KIND_SUM_DEF = 77; + AST_NODE_KIND_SUM_VARIANT = 78; + AST_NODE_KIND_SUM_INSTANCE = 79; } message BitsValueProto { diff --git a/xls/dslx/type_system/type_info_to_proto.cc b/xls/dslx/type_system/type_info_to_proto.cc index 06b8057493..bb176dfb77 100644 --- a/xls/dslx/type_system/type_info_to_proto.cc +++ b/xls/dslx/type_system/type_info_to_proto.cc @@ -99,6 +99,8 @@ AstNodeKindProto ToProto(AstNodeKind kind) { return AST_NODE_KIND_STRING; case AstNodeKind::kStructInstance: return AST_NODE_KIND_STRUCT_INSTANCE; + case AstNodeKind::kSumInstance: + return AST_NODE_KIND_SUM_INSTANCE; case AstNodeKind::kStructMember: return AST_NODE_KIND_STRUCT_MEMBER; case AstNodeKind::kNameDefTree: @@ -153,6 +155,10 @@ AstNodeKindProto ToProto(AstNodeKind kind) { return AST_NODE_KIND_SLICE; case AstNodeKind::kEnumDef: return AST_NODE_KIND_ENUM_DEF; + case AstNodeKind::kSumDef: + return AST_NODE_KIND_SUM_DEF; + case AstNodeKind::kSumVariant: + return AST_NODE_KIND_SUM_VARIANT; case AstNodeKind::kStructDef: return AST_NODE_KIND_STRUCT_DEF; case AstNodeKind::kProcDef: @@ -187,6 +193,8 @@ AstNodeKindProto ToProto(AstNodeKind kind) { return AST_NODE_KIND_PROC_MEMBER; case AstNodeKind::kRestOfTuple: return AST_NODE_KIND_REST_OF_TUPLE; + case AstNodeKind::kSumVariantPayloadPattern: + return AST_NODE_KIND_SUM_VARIANT_PAYLOAD_PATTERN; case AstNodeKind::kImpl: return AST_NODE_KIND_IMPL; case AstNodeKind::kVerbatimNode: @@ -715,6 +723,8 @@ absl::StatusOr FromProto(AstNodeKindProto p) { return AstNodeKind::kString; case AST_NODE_KIND_STRUCT_INSTANCE: return AstNodeKind::kStructInstance; + case AST_NODE_KIND_SUM_INSTANCE: + return AstNodeKind::kSumInstance; case AST_NODE_KIND_STRUCT_MEMBER: return AstNodeKind::kStructMember; case AST_NODE_KIND_NAME_DEF_TREE: @@ -765,6 +775,10 @@ absl::StatusOr FromProto(AstNodeKindProto p) { return AstNodeKind::kSlice; case AST_NODE_KIND_ENUM_DEF: return AstNodeKind::kEnumDef; + case AST_NODE_KIND_SUM_DEF: + return AstNodeKind::kSumDef; + case AST_NODE_KIND_SUM_VARIANT: + return AstNodeKind::kSumVariant; case AST_NODE_KIND_STRUCT_DEF: return AstNodeKind::kStructDef; case AST_NODE_KIND_PROC_DEF: @@ -805,6 +819,8 @@ absl::StatusOr FromProto(AstNodeKindProto p) { return AstNodeKind::kProcMember; case AST_NODE_KIND_REST_OF_TUPLE: return AstNodeKind::kRestOfTuple; + case AST_NODE_KIND_SUM_VARIANT_PAYLOAD_PATTERN: + return AstNodeKind::kSumVariantPayloadPattern; case AST_NODE_KIND_IMPL: return AstNodeKind::kImpl; case AST_NODE_KIND_VERBATIM_NODE: diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_enum_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_enum_test.cc index 34a9c71850..e52afe602a 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_enum_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_enum_test.cc @@ -14,9 +14,9 @@ #include +#include "absl/status/status_matchers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "absl/status/status_matchers.h" #include "xls/common/status/matchers.h" #include "xls/dslx/create_import_data.h" #include "xls/dslx/import_data.h" @@ -139,13 +139,13 @@ enum MyEnum : u8 { TypecheckFails(HasTypeMismatch("u8[u32:1]", "u8"))); } -TEST(TypecheckV2Test, EnumNoTypeOrValue) { +TEST(TypecheckV2Test, EmptyEnumIsSemanticSum) { EXPECT_THAT( R"( -enum MyEnum { +enum Never { } )", - TypecheckFails(HasSubstr("has no type annotation and no value"))); + TypecheckSucceeds(::testing::_)); } TEST(TypecheckV2Test, EnumTypeAlias) { @@ -201,6 +201,19 @@ const_assert!(y as u8 == 2); HasNodeWithType("y", "MyEnum")))); } +TEST(TypecheckV2Test, SumVariantPayloadPatternOnNumericEnumReturnsDiagnostic) { + EXPECT_THAT( + R"( +enum E : u1 { A = 0 } +fn f(x: E) -> u32 { + match x { + E::A() => u32:0, + } +} +)", + TypecheckFails(HasSubstr("Constructor pattern"))); +} + TEST(TypecheckV2Test, EnumInvalidNameRef) { EXPECT_THAT( R"( diff --git a/xls/dslx/type_system_v2/validate_concrete_type.cc b/xls/dslx/type_system_v2/validate_concrete_type.cc index 9c76f14c8f..91da914f27 100644 --- a/xls/dslx/type_system_v2/validate_concrete_type.cc +++ b/xls/dslx/type_system_v2/validate_concrete_type.cc @@ -378,12 +378,27 @@ class TypeValidator : public AstNodeVisitorWithDefault { for (MatchArm* arm : node->arms()) { for (NameDefTree* pattern : arm->patterns()) { - bool exhaustive_before = exhaustiveness_checker.IsExhaustive(); - exhaustiveness_checker.AddPattern(*pattern); - if (exhaustive_before) { - warning_collector_.Add( - pattern->span(), WarningKind::kAlreadyExhaustiveMatch, - "Match is already exhaustive before this pattern"); + bool has_constructor_pattern = false; + for (NameDefTree::Leaf leaf : pattern->Flatten()) { + if (std::holds_alternative(leaf)) { + has_constructor_pattern = true; + break; + } + } + if (has_constructor_pattern) { + return TypeInferenceErrorStatus( + pattern->span(), matched, + "Constructor patterns are not supported before semantic-sum " + "pattern typechecking is enabled.", + file_table_); + } else { + bool exhaustive_before = exhaustiveness_checker.IsExhaustive(); + exhaustiveness_checker.AddPattern(*pattern); + if (exhaustive_before) { + warning_collector_.Add( + pattern->span(), WarningKind::kAlreadyExhaustiveMatch, + "Match is already exhaustive before this pattern"); + } } } } diff --git a/xls/fuzzer/value_generator.cc b/xls/fuzzer/value_generator.cc index 338c7a88c3..dd99e4fdfd 100644 --- a/xls/fuzzer/value_generator.cc +++ b/xls/fuzzer/value_generator.cc @@ -312,6 +312,10 @@ absl::StatusOr GenerateDslxConstant(absl::BitGenRef bit_gen, return module->Make(fake_span, name_ref, value.name_def->identifier()); }, + [&](dslx::SumDef* sum_def) -> absl::StatusOr { + return absl::UnimplementedError( + "Generating constants of sum types isn't yet supported."); + }, [&](dslx::ColonRef* colon_ref) -> absl::StatusOr { return absl::UnimplementedError( "Generating constants of ColonRef types isn't yet supported.");