From 73c783b0753809fe0841ef8587636446af6dbadd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Thu, 21 May 2026 13:55:42 +0200 Subject: [PATCH] [HandshakeOptimizeBitwidths] Fix `subi` bitwidth calculation The previous logic was incorrect in the case of mixed `extui` and `extsi` operands, using one bit too little to perfrom the calculation. This PR tries to fix these holes in the `subi` logic while also keeping it as optimized as possible by reusing as much of the logic of `addi` when it applies and special casing the more optimized case when both extension types match. Quite a lot of time was spent writing down somewhat of a proof of correctness in addition to testing in alive2. Fixes https://github.com/EPFL-LAP/dynamatic/issues/927 --- lib/Transforms/HandshakeOptimizeBitwidths.cpp | 53 +++++++++++-- .../arith-forward.mlir | 75 ++++++++++++++----- 2 files changed, 101 insertions(+), 27 deletions(-) diff --git a/lib/Transforms/HandshakeOptimizeBitwidths.cpp b/lib/Transforms/HandshakeOptimizeBitwidths.cpp index d80c94e45..f8deb6303 100644 --- a/lib/Transforms/HandshakeOptimizeBitwidths.cpp +++ b/lib/Transforms/HandshakeOptimizeBitwidths.cpp @@ -373,14 +373,51 @@ static ExtWidth addWidth(ExtWidth lhs, ExtWidth rhs) { /// Transfer function for sub operations or alike. static ExtWidth subWidth(ExtWidth lhs, ExtWidth rhs) { - // Subtraction logic can be broken down using other operations as follows: - // sub(a, b) = add(a, ~b + 1) = add(a, xor(b, sext(-1)) + 1) - // Applying the forward function from 'xor' we can conclude that rhs is - // always sign-extended if reduced in bitwidth. - - // We apply the logic from 'add' here, but with the assumption that 'rhs' is - // SEXT. - return addWidth({lhs.extType, lhs.bitWidth}, {ExtType::SEXT, rhs.bitWidth}); + // When the extension types match we have the following property: + // All bits beyond 'std::max(lhs.bitWidth, rhs.bitWidth) + 1' are duplicates + // of 'std::max(lhs.bitWidth, rhs.bitWidth) + 1' and can therefore be computed + // using just sign-extension. + // + // Proof for both operands being ZEXT: + // Note that subtraction is defined as 'add(a, ~b + 1)' + // There are two cases: + // 1) 'rhs' is 0, in which case negating it keeps the value as 0. Bits + // std::max(lhs.bitWidth, rhs.bitWidth) + 1 and above are guaranteed to be + // zero from the zero-extension of lhs. We can therefore perform the + // computation using just 'std::max(lhs.bitWidth, rhs.bitWidth) + 1' and + // zero-extend. (Note: For this subcase, we can even perform the + // computation in 'lhs.bitWidth' many bits and zero-extend). + // 2) 'rhs' is any other value. In that case there must exist at least one + // bit that is set and that when negated catches any carry-over from the + // negation (~operand + 1) operation. Bits + // 'max(lhs.bitWidth, rhs.bitWidth) + 1' onwards of the right operand + // are guaranteed to be all 1. The corresponding bits of the left operand + // are guaranteed to be all 0 from the zero-extension. + // Due to the carry logic of addition, all bits beyond + // 'max(lhs.bitWidth, rhs.bitWidth) + 1' are therefore guaranteed to be + // equal to bit 'max(lhs.bitWidth, rhs.bitWidth) + 1'. + // + // Proof when both operands are SEXT: + // After negation, all bits beyond 'max(lhs.bitWidth, rhs.bitWidth)' are + // guaranteed to be either all 0s or all 1s, in each operand respectively. + // For the case of all bits being all 0 in one operand and all 1s in the + // other see the ZEXT case above (case 2). + // In the case of all 0s of both operands, bit + // 'max(lhs.bitWidth, rhs.bitWidth) + 1' of the result is guaranteed to be 0 + // and so are all bits beyond. The reason is that bits + // 'max(lhs.bitWidth, rhs.bitWidth)' of both operands must have been 0, + // otherwise the sign-extension wouldn't have made all subsequent bits 0. + // The final case of all bits beyond 'max(lhs.bitWidth, rhs.bitWidth)' being + // 1 in both operands follows similar logic as the 0 case. The only way they + // could have been all ones is if 'max(lhs.bitWidth, rhs.bitWidth)' is also + // one. Their addition leads to a carry to the next bit which propagates + // through the entire adder. All bits in the result from bits + // 'max(lhs.bitWidth, rhs.bitWidth) + 1' onwards are guaranteed to be 1. + if (lhs.extType == rhs.extType) + return {ExtType::SEXT, std::max(lhs.bitWidth, rhs.bitWidth) + 1}; + + // Otherwise, conservative add logic applies. + return addWidth(lhs, rhs); } /// Transfer function for mul operations or alike. diff --git a/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir b/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir index 6d5aa181c..8e769f5b1 100644 --- a/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir +++ b/test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir @@ -96,6 +96,62 @@ handshake.func @subiFW(%arg0: !handshake.channel, %arg1: !handshake.channel< // ----- +// CHECK-LABEL: handshake.func @subiFW_ui_si( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, %[[VAL_1:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : +// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to +// CHECK: end %[[VAL_6]] : +// CHECK: } +handshake.func @subiFW_ui_si(%arg0: !handshake.channel, %arg1: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %ext0 = extsi %arg0 : to + %ext1 = extui %arg1 : to + %res = subi %ext0, %ext1 : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @subiFW_si_ui( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : +// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to +// CHECK: end %[[VAL_6]] : +// CHECK: } +handshake.func @subiFW_si_ui(%arg0: !handshake.channel, %arg1: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %ext0 = extsi %arg0 : to + %ext1 = extui %arg1 : to + %res = subi %ext0, %ext1 : + end %res : +} + +// ----- + +// CHECK-LABEL: handshake.func @subiFW_ui_ui( +// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel, +// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} { +// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_4:.*]] = extui %[[VAL_0]] {handshake.bb = 0 : ui32} : to +// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : +// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to +// CHECK: end %[[VAL_6]] : +// CHECK: } +handshake.func @subiFW_ui_ui(%arg0: !handshake.channel, %arg1: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { + %ext0 = extui %arg0 : to + %ext1 = extui %arg1 : to + %res = subi %ext0, %ext1 : + end %res : +} + +// ----- + // CHECK-LABEL: handshake.func @muliFW( // CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel, @@ -505,25 +561,6 @@ handshake.func @cmpi_ui_si2(%arg0: !handshake.channel, %arg1: !handshake.cha // ----- -// CHECK-LABEL: handshake.func @subiFW_extui( -// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, -// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel, -// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} { -// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : to -// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : to -// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : -// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : to -// CHECK: end %[[VAL_6]] : -// CHECK: } -handshake.func @subiFW_extui(%arg0: !handshake.channel, %arg1: !handshake.channel, %start: !handshake.control<>) -> !handshake.channel { - %ext0 = extsi %arg0 : to - %ext1 = extui %arg1 : to - %res = subi %ext0, %ext1 : - end %res : -} - -// ----- - // CHECK-LABEL: handshake.func @cmpi_sge_ui_ui( // CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel, %[[VAL_1:.*]]: !handshake.channel, // CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} {