From 92e5cb0accb58909ccc7ed2f686bc43a86950171 Mon Sep 17 00:00:00 2001 From: Nelson Liang Date: Tue, 23 Jun 2026 12:01:37 -0700 Subject: [PATCH] [Explicit State Access] Support decoupled nodes in proc_state_legalization. Specifically: * Ensure default next predicate is nor(all other next_value predicates) instead of read_predicate && nor(all other write predicates) * Bypass restriction that every write is preceded by a read for decoupled next nodes. * No longer ensure a state_read read is triggered by every write operation if the proc is decoupled. PiperOrigin-RevId: 936810475 --- xls/ir/ir_matcher.cc | 20 +- xls/ir/ir_matcher.h | 49 ++++- xls/ir/nodes.cc | 9 +- xls/ir/proc.cc | 38 +++- xls/ir/proc.h | 4 +- xls/ir/proc_test.cc | 120 +++++++++--- xls/passes/BUILD | 1 + xls/passes/array_untuple_pass.cc | 28 ++- xls/passes/array_untuple_pass_test.cc | 37 ++++ xls/passes/next_value_optimization_pass.cc | 183 +++++++++++++----- .../next_value_optimization_pass_test.cc | 101 ++++++++++ .../proc_state_optimization_pass_test.cc | 31 +++ .../proc_state_tuple_flattening_pass.cc | 37 +++- .../proc_state_tuple_flattening_pass_test.cc | 9 +- xls/scheduling/BUILD | 4 + .../proc_state_legalization_pass.cc | 54 ++++-- .../proc_state_legalization_pass_test.cc | 32 +++ xls/scheduling/sdc_scheduler.cc | 9 +- xls/scheduling/sdc_scheduler_test.cc | 55 ++++++ 19 files changed, 691 insertions(+), 130 deletions(-) diff --git a/xls/ir/ir_matcher.cc b/xls/ir/ir_matcher.cc index 3955802d99..7e60e010c8 100644 --- a/xls/ir/ir_matcher.cc +++ b/xls/ir/ir_matcher.cc @@ -654,8 +654,18 @@ bool NextMatcher::MatchAndExplain( if (!NodeMatcher::MatchAndExplain(node, listener)) { return false; } - if (label_.has_value() && - !label_->MatchAndExplain(node->As()->label(), listener)) { + const xls::Next* next = node->As(); + if (state_element_.has_value()) { + if (next->has_state_read()) { + *listener << " is a coupled Next node, but expected decoupled"; + return false; + } + if (!state_element_->MatchAndExplain(next->state_element(), listener)) { + *listener << " has incorrect state element"; + return false; + } + } + if (label_.has_value() && !label_->MatchAndExplain(next->label(), listener)) { *listener << " has incorrect label"; return false; } @@ -664,6 +674,12 @@ bool NextMatcher::MatchAndExplain( void NextMatcher::DescribeTo(::std::ostream* os) const { std::vector additional_fields; + if (state_element_.has_value()) { + std::stringstream ss; + ss << "state_element="; + state_element_->DescribeTo(&ss); + additional_fields.push_back(ss.str()); + } if (label_.has_value()) { std::stringstream ss; ss << "label="; diff --git a/xls/ir/ir_matcher.h b/xls/ir/ir_matcher.h index 497fb2d8f7..b5d405c39b 100644 --- a/xls/ir/ir_matcher.h +++ b/xls/ir/ir_matcher.h @@ -1328,17 +1328,28 @@ inline ::testing::Matcher StateRead() { // // EXPECT_THAT(x, m::Next()); // EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1))); -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1))) -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1), "some_label")) +// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), m::Literal(1))); +// EXPECT_THAT(x, m::NextWithLabel(m::StateRead("foo"), m::Literal(1), +// "label")); +// +// Decoupled forms (asserts that the next node has no StateRead operand): +// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1))); +// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1), m::Literal(1))); +// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1), +// "label")); +// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1), +// m::Literal(1), "label")); class NextMatcher : public NodeMatcher { public: explicit NextMatcher( absl::Span> operands = {}, std::optional<::testing::Matcher&>> - label = std::nullopt) - : NodeMatcher(Op::kNext, operands), label_(std::move(label)) {} + label = std::nullopt, + std::optional<::testing::Matcher> + state_element = std::nullopt) + : NodeMatcher(Op::kNext, operands), + label_(std::move(label)), + state_element_(std::move(state_element)) {} bool MatchAndExplain(const Node* node, ::testing::MatchResultListener* listener) const override; @@ -1346,6 +1357,7 @@ class NextMatcher : public NodeMatcher { private: std::optional<::testing::Matcher&>> label_; + std::optional<::testing::Matcher> state_element_; }; inline ::testing::Matcher Next() { return NextMatcher(); } @@ -1374,6 +1386,31 @@ inline ::testing::Matcher NextWithLabel( return NextMatcher({state_read, value, predicate}, label); } +inline ::testing::Matcher NextWithStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value) { + return NextMatcher({value}, std::nullopt, state_element); +} +inline ::testing::Matcher NextWithStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate) { + return NextMatcher({value, predicate}, std::nullopt, state_element); +} +inline ::testing::Matcher NextWithStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher&> label) { + return NextMatcher({value}, label, state_element); +} +inline ::testing::Matcher NextWithStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate, + ::testing::Matcher&> label) { + return NextMatcher({value, predicate}, label, state_element); +} + // RegisterRead matcher. Matches register name only. Supported forms: // // EXPECT_THAT(x, m::RegisterRead()); diff --git a/xls/ir/nodes.cc b/xls/ir/nodes.cc index a34aa926d7..223e0be2f3 100755 --- a/xls/ir/nodes.cc +++ b/xls/ir/nodes.cc @@ -1480,8 +1480,15 @@ absl::StatusOr StateRead::CloneInNewFunction( absl::StatusOr Next::CloneInNewFunction( absl::Span new_operands, FunctionBase* new_function) const { if (state_read_ == nullptr) { + XLS_RET_CHECK(new_function->IsProc()) + << this << " cloning into " << new_function; + XLS_ASSIGN_OR_RETURN( + int64_t idx, + function_base()->AsProcOrDie()->GetStateElementIndex(state_element())); + XLS_RET_CHECK_LT(idx, new_function->AsProcOrDie()->GetStateElementCount()); return new_function->MakeNodeWithName( - loc(), state_element_, new_operands[0], + loc(), new_function->AsProcOrDie()->GetStateElement(idx), + new_operands[0], new_operands.size() > 1 ? std::make_optional(new_operands[1]) : std::nullopt, label(), GetNameView()); diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index fd4ee6beb6..9be717eb7e 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -1004,22 +1004,42 @@ absl::StatusOr Proc::TransformStateElement( // Identity-ify the old next nodes and create new ones. for (const NextTransformation& nt : transforms) { - // Make the next - XLS_ASSIGN_OR_RETURN( - Next * nxt, - MakeNodeWithName(nt.old_next->loc(), new_state_read, nt.new_value, - nt.new_predicate, nt.old_next->label(), - nt.old_next->GetName())); + Next* nxt; + if (nt.old_next->has_state_read()) { + // Coupled: use the matched new_state_read + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_read, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } else { + // Decoupled: use new_state_element directly + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_element, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } to_replace.push_back({nt.old_next, nxt}); // Identity-ify the old next. - XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( - Next::kValueOperand, nt.old_next->state_read())); + if (nt.old_next->has_state_read()) { + XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( + Next::kValueOperand, nt.old_next->state_read())); + } else { + XLS_ASSIGN_OR_RETURN( + Node * placeholder, + MakeNode(nt.old_next->loc(), + ZeroOfType(old_state_element->type()))); + XLS_RET_CHECK( + nt.old_next->ReplaceOperand(nt.old_next->value(), placeholder)); + } } for (const auto& [old_n, new_n] : to_replace) { XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith( new_n, [&](Node* n) { - if (n->Is() && n->As()->state_read() == old_n) { + if (n->Is() && n->As()->has_state_read() && + n->As()->state_read() == old_n) { return false; } return true; diff --git a/xls/ir/proc.h b/xls/ir/proc.h index f55cb8c587..2898cbf316 100644 --- a/xls/ir/proc.h +++ b/xls/ir/proc.h @@ -199,8 +199,8 @@ class Proc : public FunctionBase { // switch users of the old param to the new one. // // The old state element will continue to exist with a new name and all - // identity next nodes. It should be cleaned up using the - // NextValueOptimizationPass. + // identity next nodes with all users removed. It should be cleaned up by + // ProcStateOptimizationPass in RemoveUnobservableStateElements. // // The proc must only use 'next' nodes to call this function. absl::StatusOr TransformStateElement( diff --git a/xls/ir/proc_test.cc b/xls/ir/proc_test.cc index f55514e68b..6ed81ad7ad 100644 --- a/xls/ir/proc_test.cc +++ b/xls/ir/proc_test.cc @@ -15,7 +15,6 @@ #include "xls/ir/proc.h" #include -#include #include #include #include @@ -57,7 +56,36 @@ using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; -class ProcTest : public IrTestBase {}; +class ProcTest : public IrTestBase { + protected: + struct TestTransformer : public Proc::StateElementTransformer { + public: + absl::StatusOr TransformStateRead( + Proc* proc, StateRead* new_state_read, + StateRead* old_state_read) override { + return proc->MakeNode(new_state_read->loc(), new_state_read, + Op::kNeg); + } + absl::StatusOr TransformNextValue(Proc* proc, + StateRead* new_state_read, + Next* old_next) override { + return proc->MakeNode(old_next->value()->loc(), old_next->value(), + Op::kNeg); + } + absl::StatusOr> TransformNextPredicate( + Proc* proc, StateRead* new_state_read, Next* old_next) override { + XLS_ASSIGN_OR_RETURN( + Node * true_const, + proc->MakeNode(old_next->loc(), Value::Bool(true))); + if (old_next->predicate()) { + return proc->MakeNode( + old_next->predicate().value()->loc(), + std::array{true_const, *old_next->predicate()}, Op::kAnd); + } + return true_const; + } + }; +}; TEST_F(ProcTest, SimpleProc) { auto p = CreatePackage(); @@ -512,33 +540,6 @@ TEST_F(ProcTest, TransformStateElement) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); // Test transformer that inverts the param. - struct TestTransformer : public Proc::StateElementTransformer { - public: - absl::StatusOr TransformStateRead( - Proc* proc, StateRead* new_state_read, - StateRead* old_state_read) override { - return proc->MakeNode(new_state_read->loc(), new_state_read, - Op::kNeg); - } - absl::StatusOr TransformNextValue(Proc* proc, - StateRead* new_state_read, - Next* old_next) override { - return proc->MakeNode(old_next->value()->loc(), old_next->value(), - Op::kNeg); - } - absl::StatusOr> TransformNextPredicate( - Proc* proc, StateRead* new_state_read, Next* old_next) override { - XLS_ASSIGN_OR_RETURN( - Node * true_const, - proc->MakeNode(old_next->loc(), Value::Bool(true))); - if (old_next->predicate()) { - return proc->MakeNode( - old_next->predicate().value()->loc(), - std::array{true_const, *old_next->predicate()}, Op::kAnd); - } - return true_const; - } - }; TestTransformer tt; ScopedRecordIr sri(p.get()); XLS_ASSERT_OK_AND_ASSIGN( @@ -572,6 +573,67 @@ TEST_F(ProcTest, TransformStateElement) { EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st))); } +TEST_F(ProcTest, TransformStateElementDecoupled) { + auto p = CreatePackage(); + TokenlessProcBuilder pb(TestName(), "tkn", p.get()); + auto cond = pb.StateElement("cond", UBits(0, 1)); + + XLS_ASSERT_OK_AND_ASSIGN(StateElement * state_element, + pb.UnreadStateElement("st", Value(UBits(0b1010, 4)), + /*non_synthesizable=*/false)); + + BValue st_read = pb.StateRead(state_element, std::nullopt, "my_read_label"); + + // Labeled next node + BValue add_st = + pb.Next(state_element, pb.Add(st_read, pb.Literal(UBits(1, 4))), cond, + /*label=*/"my_next_label"); + // Unlabeled next node + BValue sub_st = + pb.Next(state_element, pb.Subtract(st_read, pb.Literal(UBits(1, 4))), + pb.Not(cond)); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + // Test transformer that inverts the param. + TestTransformer tt; + ScopedRecordIr sri(p.get()); + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * new_st_element, + proc->TransformStateElement(state_element, Value(UBits(0b0101, 4)), tt)); + StateRead* new_st = proc->GetStateReadByStateElement(new_st_element); + + // Make sure the st next has been identity-ified (dummy value Zero is set + // for decoupled) + EXPECT_THAT(st_read.node(), m::StateRead(testing::Not("st"))); + EXPECT_THAT(st_read.node()->users(), ::testing::IsEmpty()); + + // Verify old next nodes are identity-ified (labeled and unlabeled) + EXPECT_THAT(add_st.node(), m::NextWithStateElementWithLabel( + state_element, m::Literal(0), cond.node(), + std::optional("my_next_label"))); + EXPECT_THAT(sub_st.node(), + m::NextWithStateElement(state_element, m::Literal(0), + m::Not(cond.node()))); + + // Make sure that 'new_state_read' takes over the name and everything. + EXPECT_THAT(new_st, + m::StateRead("st", std::optional("my_read_label"))); + EXPECT_THAT(new_st->users(), UnorderedElementsAre(m::Neg(new_st))); + + // Verify new next nodes (labeled and unlabeled) + EXPECT_THAT(proc->next_values(new_st_element), + UnorderedElementsAre( + m::NextWithStateElementWithLabel( + new_st_element, + m::Neg(m::Add(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), cond.node()), + std::optional("my_next_label")), + m::NextWithStateElement( + new_st_element, + m::Neg(m::Sub(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), m::Not(cond.node()))))); +} + class ScheduledProcTest : public IrTestBase { protected: absl::StatusOr CreateScheduledProc(Package* p) { diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 9f66d70a51..6e85204b44 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -2569,6 +2569,7 @@ xls_pass( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) diff --git a/xls/passes/array_untuple_pass.cc b/xls/passes/array_untuple_pass.cc index 80daaa685a..7c009c8d70 100644 --- a/xls/passes/array_untuple_pass.cc +++ b/xls/passes/array_untuple_pass.cc @@ -134,11 +134,16 @@ absl::StatusOr> FindExternalGroups( if (absl::c_all_of(state_read->users(), [&](Node* n) -> bool { if (n->Is()) { Next* nxt = n->As(); - return nxt->state_read() == nxt->value() && - nxt->state_read() == state_read; + if (nxt->has_state_read()) { + return nxt->state_read() == nxt->value() && + nxt->state_read() == state_read; + } else { + return nxt->value() == state_read && + nxt->state_element() == state_read->state_element(); + } } // TODO(nelsonliang): Handle identity state elements by retrieving - // all state reads and verifying all reds are identity updates. + // all state reads and verifying all reads are identity updates. return false; })) { excluded.insert(groups.Find(state_read)); @@ -352,10 +357,19 @@ class UntupleVisitor : public DfsVisitorWithDefault { iter::zip(iter::count(), state_read_values, update_values)) { XLS_RET_CHECK(state_read_node->Is()); StateRead* state_read = state_read_node->As(); - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - n->loc(), state_read, value, n->predicate(), - n->label(), IdxName(n, idx)) - .status()); + if (n->has_state_read()) { + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + n->loc(), state_read, value, n->predicate(), + n->label(), IdxName(n, idx)) + .status()); + } else { + StateElement* state_element = state_read->state_element(); + XLS_RET_CHECK(state_element != nullptr); + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + n->loc(), state_element, value, + n->predicate(), n->label(), IdxName(n, idx)) + .status()); + } } // Remove this next from consideration. if (n->value() != state_read) { diff --git a/xls/passes/array_untuple_pass_test.cc b/xls/passes/array_untuple_pass_test.cc index e27d01f18e..9f38454c30 100644 --- a/xls/passes/array_untuple_pass_test.cc +++ b/xls/passes/array_untuple_pass_test.cc @@ -525,6 +525,32 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayWithInvoke) { EXPECT_EQ(val2_before, val2_after); } +TEST_F(ArrayUntuplePassTest, ProcStateArrayIdentityNextWithStateElement) { + auto p = CreatePackage(); + ProcBuilder pb(TestName(), p.get()); + + XLS_ASSERT_OK_AND_ASSIGN( + Value init_val, + ValueBuilder::Array( + {ValueBuilder::Tuple({Value(UBits(0, 4)), Value(UBits(0, 8))}), + ValueBuilder::Tuple({Value(UBits(0, 4)), Value(UBits(0, 8))})}) + .Build()); + + XLS_ASSERT_OK_AND_ASSIGN(StateElement * state_elem, + pb.UnreadStateElement("my_state", init_val, + /*non_synthesizable=*/false)); + BValue state_read = pb.StateRead(state_elem); + + // Identity next (decoupled) + pb.Next(state_elem, state_read); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * pr, pb.Build()); + ScopedRecordIr sri(p.get()); + EXPECT_THAT(RunPass(p.get()), IsOkAndHolds(false)); + EXPECT_THAT(pr->StateElements(), + UnorderedElementsAre(m::StateElement("my_state"))); +} + TEST_F(ArrayUntuplePassTest, ProcStateArrayNextWithStateElement) { auto p = CreatePackage(); ProcBuilder pb(TestName(), p.get()); @@ -556,6 +582,17 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayNextWithStateElement) { EXPECT_THAT(pr->StateElements(), IsSupersetOf({m::StateElement(_, m::Type("bits[4][2]")), m::StateElement(_, m::Type("bits[8][2]"))})); + + XLS_ASSERT_OK_AND_ASSIGN(StateElement * se0, pr->GetStateElementByName( + "my_state_tuple_element_0")); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * se1, pr->GetStateElementByName( + "my_state_tuple_element_1")); + + // Verify next values are decoupled (no state_read operand) + EXPECT_THAT(pr->next_values(se0), + UnorderedElementsAre(m::NextWithStateElement(se0, _))); + EXPECT_THAT(pr->next_values(se1), + UnorderedElementsAre(m::NextWithStateElement(se1, _))); } void IrFuzzArrayUntuple(FuzzPackageWithArgs fuzz_package_with_args) { diff --git a/xls/passes/next_value_optimization_pass.cc b/xls/passes/next_value_optimization_pass.cc index 62ba160add..8f48946490 100644 --- a/xls/passes/next_value_optimization_pass.cc +++ b/xls/passes/next_value_optimization_pass.cc @@ -124,20 +124,42 @@ RemoveConstantPredicate( } VLOG(2) << "Identified node as always live; removing predicate: " << next; IdenticalNexts new_next; - XLS_ASSIGN_OR_RETURN(new_next.main, - next.main->ReplaceUsesWithNew( - /*state_read=*/next.main->state_read(), - /*value=*/next.main->value(), - /*predicate=*/std::nullopt, - /*label=*/next.main->label())); + if (next.main->has_state_read()) { + XLS_ASSIGN_OR_RETURN(new_next.main, + next.main->ReplaceUsesWithNew( + /*state_read=*/next.main->state_read(), + /*value=*/next.main->value(), + /*predicate=*/std::nullopt, + /*label=*/next.main->label())); + } else { + XLS_ASSIGN_OR_RETURN(new_next.main, + next.main->ReplaceUsesWithNew( + /*state_element=*/next.main->state_element(), + /*value=*/next.main->value(), + /*predicate=*/std::nullopt, + /*label=*/next.main->label())); + } + if (next.non_synth) { - XLS_ASSIGN_OR_RETURN(new_next.non_synth, - (*next.non_synth) - ->ReplaceUsesWithNew( - /*state_read=*/(*next.non_synth)->state_read(), - /*value=*/(*next.non_synth)->value(), - /*predicate=*/std::nullopt, - /*label=*/(*next.non_synth)->label())); + if ((*next.non_synth)->has_state_read()) { + XLS_ASSIGN_OR_RETURN( + new_next.non_synth, + (*next.non_synth) + ->ReplaceUsesWithNew( + /*state_read=*/(*next.non_synth)->state_read(), + /*value=*/(*next.non_synth)->value(), + /*predicate=*/std::nullopt, + /*label=*/(*next.non_synth)->label())); + } else { + XLS_ASSIGN_OR_RETURN( + new_next.non_synth, + (*next.non_synth) + ->ReplaceUsesWithNew( + /*state_element=*/(*next.non_synth)->state_element(), + /*value=*/(*next.non_synth)->value(), + /*predicate=*/std::nullopt, + /*label=*/(*next.non_synth)->label())); + } } if (split_depth.contains(next)) { split_depth[new_next] = split_depth[next]; @@ -197,27 +219,58 @@ absl::StatusOr>> SplitSelect( Op::kAnd)); } std::string name = NodeNameFormat("%s_case_%d", next.main, i); - XLS_ASSIGN_OR_RETURN(new_next.main, - proc->MakeNodeWithName( - next.main->loc(), - /*state_read=*/next.main->state_read(), - /*value=*/selected_value.cases()[i], predicate, - /*label=*/next.main->label(), name)); - + if (next.main->has_state_read()) { + XLS_ASSIGN_OR_RETURN(new_next.main, + proc->MakeNodeWithName( + next.main->loc(), + /*state_read=*/next.main->state_read(), + /*value=*/selected_value.cases()[i], predicate, + /*label=*/next.main->label(), name)); + } else { + XLS_ASSIGN_OR_RETURN( + new_next.main, + proc->MakeNodeWithName(next.main->loc(), + /*state_element=*/ + next.main->state_element(), + /*value=*/selected_value.cases()[i], + /*predicate=*/predicate, + /*label=*/next.main->label(), name)); + } if (next.non_synth) { std::string non_synth_name = NodeNameFormat("%s_case_%d", *next.non_synth, i); // Change main pass-through updates to pass-through on the non-synth one // too. - Node* case_val = selected_value.cases()[i] == next.main->state_read() - ? (*next.non_synth)->state_read() - : selected_value.cases()[i]; - XLS_ASSIGN_OR_RETURN(new_next.non_synth, - proc->MakeNodeWithName( - (*next.non_synth)->loc(), - /*state_read=*/(*next.non_synth)->state_read(), - /*value=*/case_val, predicate, - (*next.non_synth)->label(), non_synth_name)); + Node* case_val; + if (next.main->has_state_read()) { + case_val = selected_value.cases()[i] == next.main->state_read() + ? (*next.non_synth)->state_read() + : selected_value.cases()[i]; + } else { + bool is_passthrough = + selected_value.cases()[i] == + proc->GetStateReadByStateElement(next.main->state_element()); + case_val = is_passthrough ? proc->GetStateReadByStateElement( + (*next.non_synth)->state_element()) + : selected_value.cases()[i]; + } + if ((*next.non_synth)->has_state_read()) { + XLS_ASSIGN_OR_RETURN(new_next.non_synth, + proc->MakeNodeWithName( + (*next.non_synth)->loc(), + /*state_read=*/(*next.non_synth)->state_read(), + /*value=*/case_val, predicate, + (*next.non_synth)->label(), non_synth_name)); + } else { + XLS_ASSIGN_OR_RETURN( + new_next.non_synth, + proc->MakeNodeWithName((*next.non_synth)->loc(), + /*state_element=*/ + (*next.non_synth)->state_element(), + /*value=*/case_val, predicate, + /*label=*/(*next.non_synth)->label(), + non_synth_name)); + } } new_next_values.push_back(new_next); split_depth[new_next] = depth; @@ -236,24 +289,56 @@ absl::StatusOr>> SplitSelect( IdenticalNexts new_next; std::string name = NodeNameConcat(next.main, "_default_case"); - XLS_ASSIGN_OR_RETURN( - new_next.main, - proc->MakeNodeWithName(next.main->loc(), - /*state_read=*/next.main->state_read(), - /*value=*/*selected_value.default_value(), - predicate, next.main->label(), name)); + if (next.main->has_state_read()) { + XLS_ASSIGN_OR_RETURN(new_next.main, + proc->MakeNodeWithName( + next.main->loc(), + /*state_read=*/next.main->state_read(), + /*value=*/*selected_value.default_value(), + predicate, next.main->label(), name)); + } else { + XLS_ASSIGN_OR_RETURN(new_next.main, + proc->MakeNodeWithName( + next.main->loc(), + /*state_element=*/ + next.main->state_element(), + /*value=*/*selected_value.default_value(), + /*predicate=*/predicate, + /*label=*/next.main->label(), name)); + } if (next.non_synth) { std::string non_synth_name = NodeNameConcat(*next.non_synth, "_default_case"); - Node* value = *selected_value.default_value() == next.main->state_read() - ? (*next.non_synth)->state_read() - : *selected_value.default_value(); - XLS_ASSIGN_OR_RETURN( - new_next.non_synth, - proc->MakeNodeWithName( - (*next.non_synth)->loc(), - /*state_read=*/(*next.non_synth)->state_read(), value, predicate, - (*next.non_synth)->label(), non_synth_name)); + Node* value; + if (next.main->has_state_read()) { + value = *selected_value.default_value() == next.main->state_read() + ? (*next.non_synth)->state_read() + : *selected_value.default_value(); + } else { + bool is_passthrough = + *selected_value.default_value() == + proc->GetStateReadByStateElement(next.main->state_element()); + value = is_passthrough ? proc->GetStateReadByStateElement( + (*next.non_synth)->state_element()) + : *selected_value.default_value(); + } + if ((*next.non_synth)->has_state_read()) { + XLS_ASSIGN_OR_RETURN( + new_next.non_synth, + proc->MakeNodeWithName( + (*next.non_synth)->loc(), + /*state_read=*/(*next.non_synth)->state_read(), value, + predicate, (*next.non_synth)->label(), non_synth_name)); + } else { + XLS_ASSIGN_OR_RETURN( + new_next.non_synth, + proc->MakeNodeWithName( + (*next.non_synth)->loc(), + /*state_element=*/ + (*next.non_synth)->state_element(), + /*value=*/value, /*predicate=*/predicate, + /*label=*/(*next.non_synth)->label(), non_synth_name)); + } } new_next_values.push_back(new_next); split_depth[new_next] = depth; @@ -400,10 +485,16 @@ absl::StatusOr NextValueOptimizationPass::RunOnProcInternal( } else { absl::flat_hash_map, Next*> nonsynth_nexts; for (Next* next : proc->next_values(*non_synth)) { - nonsynth_nexts[next->operands().subspan(1)] = next; + if (next->has_state_read()) { + nonsynth_nexts[next->operands().subspan(1)] = next; + } else { + nonsynth_nexts[next->operands()] = next; + } } for (Next* next : proc->next_values(elem)) { - auto it = nonsynth_nexts.find(next->operands().subspan(1)); + auto it = next->has_state_read() + ? nonsynth_nexts.find(next->operands().subspan(1)) + : nonsynth_nexts.find(next->operands()); XLS_RET_CHECK(it != nonsynth_nexts.end()) << "Unable to find corresponding non-synth next for " << next; worklist.push_back({.main = next, .non_synth = it->second}); diff --git a/xls/passes/next_value_optimization_pass_test.cc b/xls/passes/next_value_optimization_pass_test.cc index ca04e72f17..147b7f7863 100644 --- a/xls/passes/next_value_optimization_pass_test.cc +++ b/xls/passes/next_value_optimization_pass_test.cc @@ -96,6 +96,21 @@ TEST_F(NextValueOptimizationPassTest, DeadNextValue) { EXPECT_THAT(proc->next_values(), IsEmpty()); } +TEST_F(NextValueOptimizationPassTest, DecoupledDeadNextValue) { + auto p = CreatePackage(); + ProcBuilder pb("pb", p.get()); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 32)), + /*non_synthesizable=*/false)); + pb.StateRead(x_element); + pb.Next(/*state_element=*/x_element, + /*value=*/pb.Literal(UBits(5, 32)), + /*pred=*/pb.Literal(UBits(0, 1))); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT(proc->next_values(), IsEmpty()); +} + TEST_F(NextValueOptimizationPassTest, NextValuesWithLiteralPredicates) { auto p = CreatePackage(); ProcBuilder pb("p", p.get()); @@ -113,6 +128,28 @@ TEST_F(NextValueOptimizationPassTest, NextValuesWithLiteralPredicates) { ElementsAre(m::Next(m::StateRead(), m::Literal(3)))); } +TEST_F(NextValueOptimizationPassTest, + DecoupledNextValuesWithLiteralPredicates) { + auto p = CreatePackage(); + ProcBuilder pb("pb", p.get()); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 32)), + /*non_synthesizable=*/false)); + pb.StateRead(x_element); + pb.Next(/*state_element=*/x_element, + /*value=*/pb.Literal(UBits(5, 32)), + /*pred=*/pb.Literal(UBits(0, 1))); + pb.Next(/*state_element=*/x_element, + /*value=*/pb.Literal(UBits(3, 32)), + /*pred=*/pb.Literal(UBits(1, 1))); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + solvers::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/3, + /*include_state=*/true); + EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT(proc->next_values(), + ElementsAre(m::NextWithStateElement(x_element, m::Literal(3)))); +} + // Clarify that the label should be propagated through priority select. TEST_F(NextValueOptimizationPassTest, NextValuesWithLabels) { auto p = CreatePackage(); @@ -178,6 +215,39 @@ TEST_F(NextValueOptimizationPassTest, PrioritySelectNextValue) { m::Eq(m::StateRead(), m::Literal(0))))); } +TEST_F(NextValueOptimizationPassTest, DecoupledPrioritySelectNextValue) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 3)), + /*non_synthesizable=*/false)); + BValue x_read = pb.StateRead(x_element); + BValue priority_select = pb.PrioritySelect( + x_read, + std::vector({pb.Literal(UBits(2, 3)), pb.Literal(UBits(1, 3)), + pb.Literal(UBits(2, 3))}), + pb.Literal(UBits(0, 3))); + pb.Next(/*state_element=*/x_element, /*value=*/priority_select); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + solvers::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/3, true); + + EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT( + proc->next_values(), + UnorderedElementsAre( + m::NextWithStateElement( + x_element, m::Literal(2), + m::Eq(m::BitSlice(m::StateRead(), 0, 1), m::Literal(0b1))), + m::NextWithStateElement( + x_element, m::Literal(1), + m::Eq(m::BitSlice(m::StateRead(), 0, 2), m::Literal(0b10))), + m::NextWithStateElement( + x_element, m::Literal(2), + m::Eq(m::BitSlice(m::StateRead(), 0, 3), m::Literal(0b100))), + m::NextWithStateElement(x_element, m::Literal(0), + m::Eq(m::StateRead(), m::Literal(0))))); +} + TEST_F(NextValueOptimizationPassTest, OneHotSelectNextValue) { auto p = CreatePackage(); ProcBuilder pb("p", p.get()); @@ -256,6 +326,37 @@ TEST_F(NextValueOptimizationPassTest, SmallSelectNextValueWithDefault) { m::UGt(m::StateRead(), m::Literal(2))))); } +TEST_F(NextValueOptimizationPassTest, + DecoupledSmallSelectNextValueWithDefault) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 2)), + /*non_synthesizable=*/false)); + BValue x_read = pb.StateRead(x_element); + BValue select = + pb.Select(x_read, + std::vector{pb.Literal(UBits(2, 2)), pb.Literal(UBits(1, 2)), + pb.Literal(UBits(2, 2))}, + /*default_value=*/pb.Literal(UBits(3, 2))); + pb.Next(/*state_element=*/x_element, /*value=*/select); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + solvers::ScopedVerifyProcEquivalence svpe(proc, /*activation_count=*/3, + /*include_state=*/true); + EXPECT_THAT(Run(p.get(), /*split_next_value_selects=*/4), IsOkAndHolds(true)); + EXPECT_THAT( + proc->next_values(), + UnorderedElementsAre( + m::NextWithStateElement(x_element, m::Literal(2), + m::Eq(m::StateRead(), m::Literal(0))), + m::NextWithStateElement(x_element, m::Literal(1), + m::Eq(m::StateRead(), m::Literal(1))), + m::NextWithStateElement(x_element, m::Literal(2), + m::Eq(m::StateRead(), m::Literal(2))), + m::NextWithStateElement(x_element, m::Literal(3), + m::UGt(m::StateRead(), m::Literal(2))))); +} + TEST_F(NextValueOptimizationPassTest, BigSelectNextValue) { auto p = CreatePackage(); ProcBuilder pb("p", p.get()); diff --git a/xls/passes/proc_state_optimization_pass_test.cc b/xls/passes/proc_state_optimization_pass_test.cc index b8b28ebba9..1e3b6c8ea7 100644 --- a/xls/passes/proc_state_optimization_pass_test.cc +++ b/xls/passes/proc_state_optimization_pass_test.cc @@ -180,6 +180,37 @@ TEST_P(ProcStateOptimizationPassTest, ProcWithDeadElements) { EXPECT_EQ(proc->GetStateElement(0)->name(), "x"); } +TEST_F(BaseProcStateOptimizationPassTest, DecoupledDeadElements) { + auto p = CreatePackage(); + XLS_ASSERT_OK_AND_ASSIGN( + Channel * out, p->CreateStreamingChannel("out", ChannelOps::kSendOnly, + p->GetBitsType(32))); + + TokenlessProcBuilder pb("p", "tkn", p.get()); + + // Register 'x': Live (read is sent to channel) + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 32)), + /*non_synthesizable=*/false)); + BValue x_read = pb.StateRead(x_element); + pb.Send(out, x_read); + pb.Next(x_element, pb.Not(x_read)); + + // Register 'y': Dead (has 1 read but it is unused, write is a constant) + XLS_ASSERT_OK_AND_ASSIGN(StateElement * y_element, + pb.UnreadStateElement("y", Value(UBits(0, 32)), + /*non_synthesizable=*/false)); + pb.StateRead(y_element); + pb.Next(y_element, pb.Literal(UBits(5, 32))); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + EXPECT_EQ(proc->GetStateElementCount(), 2); + EXPECT_THAT(Run(p.get()), IsOkAndHolds(true)); + // State element y gets cleaned up + EXPECT_EQ(proc->GetStateElementCount(), 1); + EXPECT_EQ(proc->GetStateElement(0)->name(), "x"); +} + TEST_P(ProcStateOptimizationPassTest, CrissCrossDeadElements) { auto p = CreatePackage(); TokenlessProcBuilder pb("p", "tkn", p.get()); diff --git a/xls/passes/proc_state_tuple_flattening_pass.cc b/xls/passes/proc_state_tuple_flattening_pass.cc index 5283361884..963bc55ada 100644 --- a/xls/passes/proc_state_tuple_flattening_pass.cc +++ b/xls/passes/proc_state_tuple_flattening_pass.cc @@ -25,6 +25,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" @@ -139,6 +140,7 @@ struct NextValue { Node* value; std::optional predicate; std::optional label; + bool has_state_read; }; struct AbstractStateElement { std::string name; @@ -175,13 +177,23 @@ absl::Status ReplaceProcState(Proc* proc, read->set_label(element.read_label); } for (const NextValue& next_value : element.next_values) { - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - next_value.loc, - /*state_read=*/read, - /*value=*/next_value.value, - /*predicate=*/next_value.predicate, - /*label=*/next_value.label, next_value.name) - .status()); + if (next_value.has_state_read) { + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + next_value.loc, + /*state_read=*/read, + /*value=*/next_value.value, + /*predicate=*/next_value.predicate, + /*label=*/next_value.label, next_value.name) + .status()); + } else { + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + next_value.loc, + /*state_element=*/read->state_element(), + /*value=*/next_value.value, + /*predicate=*/next_value.predicate, + /*label=*/next_value.label, next_value.name) + .status()); + } XLS_RETURN_IF_ERROR(element.placeholder->ReplaceUsesWith(read)); } } @@ -192,6 +204,16 @@ absl::Status ReplaceProcState(Proc* proc, // array typed elements) in the proc state are flattened into their constituent // components. absl::Status FlattenState(Proc* proc) { + // Currently this pass assumes exactly one state read per state element. In + // the future, we could support multiple reads per state element. + for (StateElement* state_element : proc->StateElements()) { + XLS_RET_CHECK_EQ(proc->GetStateReadsByStateElement(state_element).size(), 1) + << absl::StreamFormat( + "State element '%s' has %d reads, but this pass only supports " + "exactly one read per state element.", + state_element->name(), + proc->GetStateReadsByStateElement(state_element).size()); + } std::vector identities; std::vector elements; @@ -268,6 +290,7 @@ absl::Status FlattenState(Proc* proc) { .value = value, .predicate = predicate, .label = next->label(), + .has_state_read = next->has_state_read(), }); } } diff --git a/xls/passes/proc_state_tuple_flattening_pass_test.cc b/xls/passes/proc_state_tuple_flattening_pass_test.cc index cb6b7e1bcb..fc3cc081f7 100644 --- a/xls/passes/proc_state_tuple_flattening_pass_test.cc +++ b/xls/passes/proc_state_tuple_flattening_pass_test.cc @@ -434,7 +434,7 @@ TEST_F(ProcStateFlatteningPassTest, BValue next1 = pb.Add(elem1, pb.Literal(UBits(2, 32))); BValue next_val = pb.Tuple({next0, next1}); - pb.Next(read, next_val, /*predicate=*/std::nullopt, + pb.Next(state, next_val, /*predicate=*/std::nullopt, /*label=*/"my_write_label"); XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); @@ -446,10 +446,9 @@ TEST_F(ProcStateFlatteningPassTest, EXPECT_THAT(proc->nodes(), Contains(m::StateRead( "state_0", std::optional("my_read_label")))); - EXPECT_THAT( - proc->nodes(), - Contains(m::NextWithLabel(m::StateRead(), _, - std::optional("my_write_label")))); + EXPECT_THAT(proc->nodes(), + Contains(m::NextWithStateElementWithLabel( + _, _, std::optional("my_write_label")))); } INSTANTIATE_TEST_SUITE_P(NextValueTypes, ProcStateFlatteningPassTest, diff --git a/xls/scheduling/BUILD b/xls/scheduling/BUILD index c3b177c41d..71b0286aa3 100644 --- a/xls/scheduling/BUILD +++ b/xls/scheduling/BUILD @@ -231,6 +231,8 @@ cc_test( "//xls/ir:function_builder", "//xls/ir:ir_test_base", "//xls/ir:value", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@googletest//:gtest", ], @@ -536,6 +538,8 @@ cc_test( "//xls/common/status:status_macros", "//xls/ir", "//xls/ir:bits", + "//xls/ir:channel", + "//xls/ir:channel_ops", "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", diff --git a/xls/scheduling/proc_state_legalization_pass.cc b/xls/scheduling/proc_state_legalization_pass.cc index 88f9e9535a..2d2320a6d3 100644 --- a/xls/scheduling/proc_state_legalization_pass.cc +++ b/xls/scheduling/proc_state_legalization_pass.cc @@ -49,6 +49,13 @@ namespace xls { namespace { +// Checks if a state element is decoupled, meaning it has exactly one next_value +// and no state reads. +bool IsDecoupled(Proc* proc, StateElement* state_element) { + const auto& next_values = proc->next_values(state_element); + return next_values.size() == 1 && !(*next_values.begin())->has_state_read(); +} + // Ensure that `state_read` is either unconditional or has a predicate that is // true whenever any of its corresponding `next_value`s are active. absl::StatusOr LegalizeStateReadPredicate( @@ -61,6 +68,9 @@ absl::StatusOr LegalizeStateReadPredicate( // Already unconditional, or no explicit `next_value`s; nothing to do. return false; } + if (IsDecoupled(proc, state_element)) { + return false; + } std::vector predicates; absl::flat_hash_set predicates_set; @@ -221,6 +231,10 @@ absl::StatusOr AddWriteWithoutReadAsserts( return false; } + if (IsDecoupled(proc, state_element)) { + return false; + } + const absl::btree_set& next_values = proc->next_values(state_element); if (next_values.empty()) { @@ -395,21 +409,33 @@ absl::StatusOr AddDefaultNextValue(Proc* proc, Node * default_predicate, NaryNorIfNeeded(proc, std::vector(predicates.begin(), predicates.end()), /*name=*/"", state_read->loc())); - if (state_read->predicate().has_value()) { - XLS_ASSIGN_OR_RETURN( - default_predicate, - proc->MakeNode( - state_read->loc(), - absl::MakeConstSpan({*state_read->predicate(), default_predicate}), - Op::kAnd)); + if (IsDecoupled(proc, state_element)) { + // Decoupled default next node predicate is only gated by other + // next_value predicates + XLS_RETURN_IF_ERROR( + proc->MakeNodeWithName( + state_read->loc(), /*state_element=*/state_element, + /*value=*/state_read, + /*predicate=*/default_predicate, /*label=*/std::nullopt, + /*name=*/absl::StrCat(state_element->name(), "_default")) + .status()); + } else { + if (state_read->predicate().has_value()) { + XLS_ASSIGN_OR_RETURN( + default_predicate, + proc->MakeNode(state_read->loc(), + absl::MakeConstSpan({*state_read->predicate(), + default_predicate}), + Op::kAnd)); + } + XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( + state_read->loc(), /*state_read=*/state_read, + /*value=*/state_read, + /*predicate=*/default_predicate, + /*label=*/std::nullopt, + absl::StrCat(state_element->name(), "_default")) + .status()); } - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - state_read->loc(), /*state_read=*/state_read, - /*value=*/state_read, - /*predicate=*/default_predicate, - /*label=*/std::nullopt, - absl::StrCat(state_element->name(), "_default")) - .status()); return true; } diff --git a/xls/scheduling/proc_state_legalization_pass_test.cc b/xls/scheduling/proc_state_legalization_pass_test.cc index 2bab5504a8..57e57b852e 100644 --- a/xls/scheduling/proc_state_legalization_pass_test.cc +++ b/xls/scheduling/proc_state_legalization_pass_test.cc @@ -27,6 +27,8 @@ #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" +#include "xls/ir/channel.h" +#include "xls/ir/channel_ops.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" @@ -552,6 +554,36 @@ TEST_P(ProcStateLegalizationPassTest, m::Literal(0)))))))); } +TEST_P(ProcStateLegalizationPassTest, DecoupledPredicatedNextValue) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + XLS_ASSERT_OK_AND_ASSIGN( + Channel * in_read_pred, + p->CreateStreamingChannel("in_read_pred", ChannelOps::kReceiveOnly, + p->GetBitsType(1))); + BValue read_pred = pb.Receive(in_read_pred, pb.Literal(Value::Token())); + BValue read_pred_val = pb.TupleIndex(read_pred, 1); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 32)), + /*non_synthesizable=*/false)); + BValue x_read = pb.StateRead(x_element, read_pred_val); + BValue incremented = pb.Add(x_read, pb.Literal(UBits(1, 32))); + BValue write_pred = pb.Eq(x_read, pb.Literal(UBits(0, 32))); + pb.Next(x_element, incremented, write_pred); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + XLS_ASSERT_OK(p->SetTop(proc)); + ASSERT_THAT(Run(proc), IsOkAndHolds(true)); + StateRead* read_node = proc->GetStateReadByStateElement(x_element); + EXPECT_EQ(*read_node->predicate(), read_pred_val.node()); + EXPECT_THAT(proc->next_values(), + UnorderedElementsAre( + m::NextWithStateElement(x_element, incremented.node(), + write_pred.node()), + m::NextWithStateElement(x_element, read_node, + m::Not(write_pred.node())))); + EXPECT_THAT(proc->nodes(), Each(Not(m::Assert()))); +} + INSTANTIATE_TEST_SUITE_P(ProcStateLegalizationPassTestSuite, ProcStateLegalizationPassTest, testing::Values(false, true)); diff --git a/xls/scheduling/sdc_scheduler.cc b/xls/scheduling/sdc_scheduler.cc index a97c3c751e..846a84d813 100644 --- a/xls/scheduling/sdc_scheduler.cc +++ b/xls/scheduling/sdc_scheduler.cc @@ -327,6 +327,10 @@ absl::Status SDCSchedulingModel::AddAllDefUseConstraints() { absl::StrFormat("state_read_before_write_%s_%s", read->GetName(), next->GetName())); + if (!next->has_state_read()) { + XLS_RETURN_IF_ERROR(AddThroughputConstraint(read, next)); + } + VLOG(2) << "Setting state read-before-write constraint: " << absl::StrFormat("cycle[%s] - cycle[%s] >= 0", next->GetName(), read->GetName()); @@ -347,10 +351,11 @@ absl::Status SDCSchedulingModel::AddDefUseConstraints( if (node->Is() && user.has_value() && user.value()->Is()) { Next* next = user.value()->As(); - if (next->state_read() == node) { + if (next->has_state_read() && next->state_read() == node) { XLS_RETURN_IF_ERROR(AddThroughputConstraint(node->As(), next)); } - if (next->value() != node && next->predicate() != node) { + if (next->has_state_read() && next->value() != node && + next->predicate() != node) { XLS_RET_CHECK_EQ(next->state_read(), node); // We don't need to keep the param's value alive to this user, so no need // for a lifetime constraint. diff --git a/xls/scheduling/sdc_scheduler_test.cc b/xls/scheduling/sdc_scheduler_test.cc index b2bea20c64..25f509a535 100644 --- a/xls/scheduling/sdc_scheduler_test.cc +++ b/xls/scheduling/sdc_scheduler_test.cc @@ -17,7 +17,10 @@ #include #include +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "xls/common/status/matchers.h" #include "xls/estimators/delay_model/delay_estimator.h" @@ -147,5 +150,57 @@ TEST_F(SDCSchedulerTest, WithIOConstraint) { EXPECT_EQ(cycle_map.at(send.node()), 1); } +TEST_F(SDCSchedulerTest, DecoupledFeedbackLoopInfeasible) { + auto p = CreatePackage(); + ProcBuilder pb(TestName(), p.get()); + XLS_ASSERT_OK_AND_ASSIGN(StateElement * x_element, + pb.UnreadStateElement("x", Value(UBits(0, 32)), + /*non_synthesizable=*/false)); + BValue x_read = pb.StateRead(x_element); + // Create a long chain of adds to increase the critical path length + // to 5. This means that for worst_case_throughput = 1, the constraint will be + // violated This test will fail for worst_case_throughput = 1, but pass for + // worst_case_throughput = 4 + + BValue add1 = pb.Add(x_read, pb.Literal(UBits(1, 32))); + BValue add2 = pb.Add(add1, pb.Literal(UBits(1, 32))); + BValue add3 = pb.Add(add2, pb.Literal(UBits(1, 32))); + BValue add4 = pb.Add(add3, pb.Literal(UBits(1, 32))); + BValue add5 = pb.Add(add4, pb.Literal(UBits(1, 32))); + pb.Next(x_element, add5); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + XLS_ASSERT_OK_AND_ASSIGN(ScheduleGraph graph, + ScheduleGraph::Create(proc, {})); + TestDelayEstimator delay_estimator; + SchedulingOptions options; + + // 1. Verify scheduling fails under tight worst_case_throughput = 1 (requires + // cycle difference <= 0, but min is 3) + XLS_ASSERT_OK_AND_ASSIGN( + auto infeasible_scheduler, + SDCScheduler::Create(graph, delay_estimator, options)); + auto infeasible_result = infeasible_scheduler->Schedule( + std::nullopt, 2, + SchedulingFailureBehavior{.explain_infeasibility = false}, 1); + EXPECT_THAT(infeasible_result, + absl_testing::StatusIs( + absl::StatusCode::kInternal, + testing::HasSubstr("does not have an optimal solution"))); + + // 2. Verify scheduling succeeds under relaxed worst_case_throughput = 4 + // (allows cycle difference <= 3, which matches min of 3) + XLS_ASSERT_OK_AND_ASSIGN( + auto feasible_scheduler, + SDCScheduler::Create(graph, delay_estimator, options)); + XLS_ASSERT_OK_AND_ASSIGN( + ScheduleCycleMap feasible_cycle_map, + feasible_scheduler->Schedule( + std::nullopt, 2, + SchedulingFailureBehavior{.explain_infeasibility = false}, 4)); + EXPECT_EQ(feasible_cycle_map.at(*proc->next_values().begin()) - + feasible_cycle_map.at(x_read.node()), + 3); +} + } // namespace } // namespace xls