diff --git a/xls/ir/ir_matcher.cc b/xls/ir/ir_matcher.cc index 3955802d99..7e60e010c8 100644 --- a/xls/ir/ir_matcher.cc +++ b/xls/ir/ir_matcher.cc @@ -654,8 +654,18 @@ bool NextMatcher::MatchAndExplain( if (!NodeMatcher::MatchAndExplain(node, listener)) { return false; } - if (label_.has_value() && - !label_->MatchAndExplain(node->As()->label(), listener)) { + const xls::Next* next = node->As(); + if (state_element_.has_value()) { + if (next->has_state_read()) { + *listener << " is a coupled Next node, but expected decoupled"; + return false; + } + if (!state_element_->MatchAndExplain(next->state_element(), listener)) { + *listener << " has incorrect state element"; + return false; + } + } + if (label_.has_value() && !label_->MatchAndExplain(next->label(), listener)) { *listener << " has incorrect label"; return false; } @@ -664,6 +674,12 @@ bool NextMatcher::MatchAndExplain( void NextMatcher::DescribeTo(::std::ostream* os) const { std::vector additional_fields; + if (state_element_.has_value()) { + std::stringstream ss; + ss << "state_element="; + state_element_->DescribeTo(&ss); + additional_fields.push_back(ss.str()); + } if (label_.has_value()) { std::stringstream ss; ss << "label="; diff --git a/xls/ir/ir_matcher.h b/xls/ir/ir_matcher.h index 497fb2d8f7..b5d405c39b 100644 --- a/xls/ir/ir_matcher.h +++ b/xls/ir/ir_matcher.h @@ -1328,17 +1328,28 @@ inline ::testing::Matcher StateRead() { // // EXPECT_THAT(x, m::Next()); // EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1))); -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1))) -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1), "some_label")) +// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), m::Literal(1))); +// EXPECT_THAT(x, m::NextWithLabel(m::StateRead("foo"), m::Literal(1), +// "label")); +// +// Decoupled forms (asserts that the next node has no StateRead operand): +// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1))); +// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1), m::Literal(1))); +// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1), +// "label")); +// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1), +// m::Literal(1), "label")); class NextMatcher : public NodeMatcher { public: explicit NextMatcher( absl::Span> operands = {}, std::optional<::testing::Matcher&>> - label = std::nullopt) - : NodeMatcher(Op::kNext, operands), label_(std::move(label)) {} + label = std::nullopt, + std::optional<::testing::Matcher> + state_element = std::nullopt) + : NodeMatcher(Op::kNext, operands), + label_(std::move(label)), + state_element_(std::move(state_element)) {} bool MatchAndExplain(const Node* node, ::testing::MatchResultListener* listener) const override; @@ -1346,6 +1357,7 @@ class NextMatcher : public NodeMatcher { private: std::optional<::testing::Matcher&>> label_; + std::optional<::testing::Matcher> state_element_; }; inline ::testing::Matcher Next() { return NextMatcher(); } @@ -1374,6 +1386,31 @@ inline ::testing::Matcher NextWithLabel( return NextMatcher({state_read, value, predicate}, label); } +inline ::testing::Matcher NextWithStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value) { + return NextMatcher({value}, std::nullopt, state_element); +} +inline ::testing::Matcher NextWithStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate) { + return NextMatcher({value, predicate}, std::nullopt, state_element); +} +inline ::testing::Matcher NextWithStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher&> label) { + return NextMatcher({value}, label, state_element); +} +inline ::testing::Matcher NextWithStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate, + ::testing::Matcher&> label) { + return NextMatcher({value, predicate}, label, state_element); +} + // RegisterRead matcher. Matches register name only. Supported forms: // // EXPECT_THAT(x, m::RegisterRead()); diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 4d88219dd8..09a3b12801 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -1002,22 +1002,42 @@ absl::StatusOr Proc::TransformStateElement( // Identity-ify the old next nodes and create new ones. for (const NextTransformation& nt : transforms) { - // Make the next - XLS_ASSIGN_OR_RETURN( - Next * nxt, - MakeNodeWithName(nt.old_next->loc(), new_state_read, nt.new_value, - nt.new_predicate, nt.old_next->label(), - nt.old_next->GetName())); + Next* nxt; + if (nt.old_next->has_state_read()) { + // Coupled: use the matched new_state_read + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_read, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } else { + // Decoupled: use new_state_element directly + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_element, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } to_replace.push_back({nt.old_next, nxt}); // Identity-ify the old next. - XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( - Next::kValueOperand, nt.old_next->state_read())); + if (nt.old_next->has_state_read()) { + XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( + Next::kValueOperand, nt.old_next->state_read())); + } else { + XLS_ASSIGN_OR_RETURN( + Node * placeholder, + MakeNode(nt.old_next->loc(), + ZeroOfType(old_state_element->type()))); + XLS_RET_CHECK( + nt.old_next->ReplaceOperand(nt.old_next->value(), placeholder)); + } } for (const auto& [old_n, new_n] : to_replace) { XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith( new_n, [&](Node* n) { - if (n->Is() && n->As()->state_read() == old_n) { + if (n->Is() && n->As()->has_state_read() && + n->As()->state_read() == old_n) { return false; } return true; diff --git a/xls/ir/proc.h b/xls/ir/proc.h index d6dd5d9b55..6a15cce9ca 100644 --- a/xls/ir/proc.h +++ b/xls/ir/proc.h @@ -200,8 +200,8 @@ class Proc : public FunctionBase { // switch users of the old param to the new one. // // The old state element will continue to exist with a new name and all - // identity next nodes. It should be cleaned up using the - // NextValueOptimizationPass. + // identity next nodes with all users removed. It should be cleaned up by + // ProcStateOptimizationPass in RemoveUnobservableStateElements. // // The proc must only use 'next' nodes to call this function. absl::StatusOr TransformStateElement( diff --git a/xls/ir/proc_test.cc b/xls/ir/proc_test.cc index 16b68917f4..687b943f11 100644 --- a/xls/ir/proc_test.cc +++ b/xls/ir/proc_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -57,7 +58,36 @@ using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; -class ProcTest : public IrTestBase {}; +class ProcTest : public IrTestBase { + protected: + struct TestTransformer : public Proc::StateElementTransformer { + public: + absl::StatusOr TransformStateRead( + Proc* proc, StateRead* new_state_read, + StateRead* old_state_read) override { + return proc->MakeNode(new_state_read->loc(), new_state_read, + Op::kNeg); + } + absl::StatusOr TransformNextValue(Proc* proc, + StateRead* new_state_read, + Next* old_next) override { + return proc->MakeNode(old_next->value()->loc(), old_next->value(), + Op::kNeg); + } + absl::StatusOr> TransformNextPredicate( + Proc* proc, StateRead* new_state_read, Next* old_next) override { + XLS_ASSIGN_OR_RETURN( + Node * true_const, + proc->MakeNode(old_next->loc(), Value::Bool(true))); + if (old_next->predicate()) { + return proc->MakeNode( + old_next->predicate().value()->loc(), + std::array{true_const, *old_next->predicate()}, Op::kAnd); + } + return true_const; + } + }; +}; TEST_F(ProcTest, SimpleProc) { auto p = CreatePackage(); @@ -508,33 +538,6 @@ TEST_F(ProcTest, TransformStateElement) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); // Test transformer that inverts the param. - struct TestTransformer : public Proc::StateElementTransformer { - public: - absl::StatusOr TransformStateRead( - Proc* proc, StateRead* new_state_read, - StateRead* old_state_read) override { - return proc->MakeNode(new_state_read->loc(), new_state_read, - Op::kNeg); - } - absl::StatusOr TransformNextValue(Proc* proc, - StateRead* new_state_read, - Next* old_next) override { - return proc->MakeNode(old_next->value()->loc(), old_next->value(), - Op::kNeg); - } - absl::StatusOr> TransformNextPredicate( - Proc* proc, StateRead* new_state_read, Next* old_next) override { - XLS_ASSIGN_OR_RETURN( - Node * true_const, - proc->MakeNode(old_next->loc(), Value::Bool(true))); - if (old_next->predicate()) { - return proc->MakeNode( - old_next->predicate().value()->loc(), - std::array{true_const, *old_next->predicate()}, Op::kAnd); - } - return true_const; - } - }; TestTransformer tt; ScopedRecordIr sri(p.get()); XLS_ASSERT_OK_AND_ASSIGN( @@ -568,6 +571,83 @@ TEST_F(ProcTest, TransformStateElement) { EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st))); } +TEST_F(ProcTest, TransformStateElementDecoupled) { + auto p = CreatePackage(); + TokenlessProcBuilder pb(TestName(), "tkn", p.get()); + auto cond = pb.StateElement("cond", UBits(0, 1)); + + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * state_element, + pb.UnreadStateElement("st", Value(UBits(0b1010, 4)))); + + BValue st_read = pb.StateRead(state_element, std::nullopt, "my_read_label"); + + // Labeled next node + BValue add_st = + pb.Next(state_element, pb.Add(st_read, pb.Literal(UBits(1, 4))), cond, + /*label=*/"my_next_label"); + // Unlabeled next node + BValue sub_st = + pb.Next(state_element, pb.Subtract(st_read, pb.Literal(UBits(1, 4))), + pb.Not(cond)); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + // Test transformer that inverts the param. + TestTransformer tt; + ScopedRecordIr sri(p.get()); + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * new_st_element, + proc->TransformStateElement(state_element, Value(UBits(0b0101, 4)), tt)); + StateRead* new_st = proc->GetStateReadByStateElement(new_st_element); + + // Make sure the st next has been identity-ified (dummy value Zero is set + // for decoupled) + EXPECT_THAT(st_read.node(), m::StateRead(testing::Not("st"))); + EXPECT_THAT(st_read.node()->users(), ::testing::IsEmpty()); + + // Verify old next nodes are identity-ified (labeled and unlabeled) + EXPECT_THAT(add_st.node(), m::NextWithStateElementWithLabel( + state_element, m::Literal(0), cond.node(), + std::optional("my_next_label"))); + EXPECT_THAT(sub_st.node(), + m::NextWithStateElement(state_element, m::Literal(0), + m::Not(cond.node()))); + + // Make sure that 'new_state_read' takes over the name and everything. + EXPECT_THAT(new_st, + m::StateRead("st", std::optional("my_read_label"))); + EXPECT_THAT(new_st->users(), UnorderedElementsAre(m::Neg(new_st))); + + // Verify new next nodes (labeled and unlabeled) + EXPECT_THAT(proc->next_values(new_st_element), + UnorderedElementsAre( + m::NextWithStateElementWithLabel( + new_st_element, + m::Neg(m::Add(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), cond.node()), + std::optional("my_next_label")), + m::NextWithStateElement( + new_st_element, + m::Neg(m::Sub(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), m::Not(cond.node()))))); + + XLS_ASSERT_OK_AND_ASSIGN(int64_t old_idx, + proc->GetStateElementIndex(state_element)); + + // Remove the old Next nodes. + std::vector old_nexts(proc->next_values(state_element).begin(), + proc->next_values(state_element).end()); + for (Next* next : old_nexts) { + XLS_ASSERT_OK(proc->RemoveNode(next)); + } + + // Remove the old state element (and its StateRead) with no users. + XLS_EXPECT_OK(proc->RemoveStateElement(old_idx)); + + // Verify only 'cond' and 'new_st' remain. + EXPECT_EQ(proc->GetStateElementCount(), 2); +} + class ScheduledProcTest : public IrTestBase { protected: absl::StatusOr CreateScheduledProc(Package* p) { diff --git a/xls/passes/array_untuple_pass.cc b/xls/passes/array_untuple_pass.cc index 7dd041342c..f599ecd43a 100644 --- a/xls/passes/array_untuple_pass.cc +++ b/xls/passes/array_untuple_pass.cc @@ -134,11 +134,16 @@ absl::StatusOr> FindExternalGroups( if (absl::c_all_of(state_read->users(), [&](Node* n) -> bool { if (n->Is()) { Next* nxt = n->As(); - return nxt->state_read() == nxt->value() && - nxt->state_read() == state_read; + if (nxt->has_state_read()) { + return nxt->state_read() == nxt->value() && + nxt->state_read() == state_read; + } else { + return nxt->value() == state_read && + nxt->state_element() == state_read->state_element(); + } } // TODO(nelsonliang): Handle identity state elements by retrieving - // all state reads and verifying all reds are identity updates. + // all state reads and verifying all reads are identity updates. return false; })) { excluded.insert(groups.Find(state_read)); @@ -351,10 +356,19 @@ class UntupleVisitor : public DfsVisitorWithDefault { iter::zip(iter::count(), state_read_values, update_values)) { XLS_RET_CHECK(state_read_node->Is()); StateRead* state_read = state_read_node->As(); - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - n->loc(), state_read, value, n->predicate(), - n->label(), IdxName(n, idx)) - .status()); + if (n->has_state_read()) { + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + n->loc(), state_read, value, n->predicate(), + n->label(), IdxName(n, idx)) + .status()); + } else { + StateElement* state_element = state_read->state_element(); + XLS_RET_CHECK(state_element != nullptr); + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + n->loc(), state_element, value, + n->predicate(), n->label(), IdxName(n, idx)) + .status()); + } } // Remove this next from consideration. if (n->value() != state_read) { diff --git a/xls/passes/array_untuple_pass_test.cc b/xls/passes/array_untuple_pass_test.cc index 26fbda8db8..a656c303d0 100644 --- a/xls/passes/array_untuple_pass_test.cc +++ b/xls/passes/array_untuple_pass_test.cc @@ -525,6 +525,31 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayWithInvoke) { EXPECT_EQ(val2_before, val2_after); } +TEST_F(ArrayUntuplePassTest, ProcStateArrayIdentityNextWithStateElement) { + auto p = CreatePackage(); + ProcBuilder pb(TestName(), p.get()); + + XLS_ASSERT_OK_AND_ASSIGN( + Value init_val, + ValueBuilder::Array( + {ValueBuilder::Tuple({Value(UBits(0, 4)), Value(UBits(0, 8))}), + ValueBuilder::Tuple({Value(UBits(0, 4)), Value(UBits(0, 8))})}) + .Build()); + + XLS_ASSERT_OK_AND_ASSIGN(StateElement * state_elem, + pb.UnreadStateElement("my_state", init_val)); + BValue state_read = pb.StateRead(state_elem); + + // Identity next (decoupled) + pb.Next(state_elem, state_read); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * pr, pb.Build()); + ScopedRecordIr sri(p.get()); + EXPECT_THAT(RunPass(p.get()), IsOkAndHolds(false)); + EXPECT_THAT(pr->StateElements(), + UnorderedElementsAre(m::StateElement("my_state"))); +} + TEST_F(ArrayUntuplePassTest, ProcStateArrayNextWithStateElement) { auto p = CreatePackage(); ProcBuilder pb(TestName(), p.get()); @@ -555,6 +580,17 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayNextWithStateElement) { EXPECT_THAT(pr->StateElements(), IsSupersetOf({m::StateElement(_, m::Type("bits[4][2]")), m::StateElement(_, m::Type("bits[8][2]"))})); + + XLS_ASSERT_OK_AND_ASSIGN(StateElement * se0, pr->GetStateElementByName( + "my_state_tuple_element_0")); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * se1, pr->GetStateElementByName( + "my_state_tuple_element_1")); + + // Verify next values are decoupled (no state_read operand) + EXPECT_THAT(pr->next_values(se0), + UnorderedElementsAre(m::NextWithStateElement(se0, _))); + EXPECT_THAT(pr->next_values(se1), + UnorderedElementsAre(m::NextWithStateElement(se1, _))); } void IrFuzzArrayUntuple(FuzzPackageWithArgs fuzz_package_with_args) { diff --git a/xls/passes/proc_state_tuple_flattening_pass.cc b/xls/passes/proc_state_tuple_flattening_pass.cc index db529c032e..7f854755f3 100644 --- a/xls/passes/proc_state_tuple_flattening_pass.cc +++ b/xls/passes/proc_state_tuple_flattening_pass.cc @@ -139,6 +139,7 @@ struct NextValue { Node* value; std::optional predicate; std::optional label; + bool has_state_read; }; struct AbstractStateElement { std::string name; @@ -178,13 +179,23 @@ absl::Status ReplaceProcState(Proc* proc, read->state_element()->SetNonSynthesizable(); } for (const NextValue& next_value : element.next_values) { - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - next_value.loc, - /*state_read=*/read, - /*value=*/next_value.value, - /*predicate=*/next_value.predicate, - /*label=*/next_value.label, next_value.name) - .status()); + if (next_value.has_state_read) { + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + next_value.loc, + /*state_read=*/read, + /*value=*/next_value.value, + /*predicate=*/next_value.predicate, + /*label=*/next_value.label, next_value.name) + .status()); + } else { + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + next_value.loc, + /*state_element=*/read->state_element(), + /*value=*/next_value.value, + /*predicate=*/next_value.predicate, + /*label=*/next_value.label, next_value.name) + .status()); + } XLS_RETURN_IF_ERROR(element.placeholder->ReplaceUsesWith(read)); } } @@ -271,6 +282,7 @@ absl::Status FlattenState(Proc* proc) { .value = value, .predicate = predicate, .label = next->label(), + .has_state_read = next->has_state_read(), }); } } diff --git a/xls/passes/proc_state_tuple_flattening_pass_test.cc b/xls/passes/proc_state_tuple_flattening_pass_test.cc index ba7295e936..190f683af6 100644 --- a/xls/passes/proc_state_tuple_flattening_pass_test.cc +++ b/xls/passes/proc_state_tuple_flattening_pass_test.cc @@ -433,7 +433,7 @@ TEST_F(ProcStateFlatteningPassTest, BValue next1 = pb.Add(elem1, pb.Literal(UBits(2, 32))); BValue next_val = pb.Tuple({next0, next1}); - pb.Next(read, next_val, /*predicate=*/std::nullopt, + pb.Next(state, next_val, /*predicate=*/std::nullopt, /*label=*/"my_write_label"); XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); @@ -445,10 +445,9 @@ TEST_F(ProcStateFlatteningPassTest, EXPECT_THAT(proc->nodes(), Contains(m::StateRead( "state_0", std::optional("my_read_label")))); - EXPECT_THAT( - proc->nodes(), - Contains(m::NextWithLabel(m::StateRead(), _, - std::optional("my_write_label")))); + EXPECT_THAT(proc->nodes(), + Contains(m::NextWithStateElementWithLabel( + _, _, std::optional("my_write_label")))); } INSTANTIATE_TEST_SUITE_P(NextValueTypes, ProcStateFlatteningPassTest,