Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions lib/Transforms/HandshakeOptimizeBitwidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,14 +373,51 @@ static ExtWidth addWidth(ExtWidth lhs, ExtWidth rhs) {

/// Transfer function for sub operations or alike.
static ExtWidth subWidth(ExtWidth lhs, ExtWidth rhs) {
// Subtraction logic can be broken down using other operations as follows:
// sub(a, b) = add(a, ~b + 1) = add(a, xor(b, sext(-1)) + 1)
// Applying the forward function from 'xor' we can conclude that rhs is
// always sign-extended if reduced in bitwidth.

// We apply the logic from 'add' here, but with the assumption that 'rhs' is
// SEXT.
return addWidth({lhs.extType, lhs.bitWidth}, {ExtType::SEXT, rhs.bitWidth});
// When the extension types match we have the following property:
// All bits beyond 'std::max(lhs.bitWidth, rhs.bitWidth) + 1' are duplicates
// of 'std::max(lhs.bitWidth, rhs.bitWidth) + 1' and can therefore be computed
// using just sign-extension.
//
// Proof for both operands being ZEXT:
// Note that subtraction is defined as 'add(a, ~b + 1)'
// There are two cases:
// 1) 'rhs' is 0, in which case negating it keeps the value as 0. Bits
// std::max(lhs.bitWidth, rhs.bitWidth) + 1 and above are guaranteed to be
// zero from the zero-extension of lhs. We can therefore perform the
// computation using just 'std::max(lhs.bitWidth, rhs.bitWidth) + 1' and
// zero-extend. (Note: For this subcase, we can even perform the
// computation in 'lhs.bitWidth' many bits and zero-extend).
// 2) 'rhs' is any other value. In that case there must exist at least one
// bit that is set and that when negated catches any carry-over from the
// negation (~operand + 1) operation. Bits
// 'max(lhs.bitWidth, rhs.bitWidth) + 1' onwards of the right operand
// are guaranteed to be all 1. The corresponding bits of the left operand
// are guaranteed to be all 0 from the zero-extension.
// Due to the carry logic of addition, all bits beyond
// 'max(lhs.bitWidth, rhs.bitWidth) + 1' are therefore guaranteed to be
// equal to bit 'max(lhs.bitWidth, rhs.bitWidth) + 1'.
//
// Proof when both operands are SEXT:
// After negation, all bits beyond 'max(lhs.bitWidth, rhs.bitWidth)' are
// guaranteed to be either all 0s or all 1s, in each operand respectively.
// For the case of all bits being all 0 in one operand and all 1s in the
// other see the ZEXT case above (case 2).
// In the case of all 0s of both operands, bit
// 'max(lhs.bitWidth, rhs.bitWidth) + 1' of the result is guaranteed to be 0
// and so are all bits beyond. The reason is that bits
// 'max(lhs.bitWidth, rhs.bitWidth)' of both operands must have been 0,
// otherwise the sign-extension wouldn't have made all subsequent bits 0.
// The final case of all bits beyond 'max(lhs.bitWidth, rhs.bitWidth)' being
// 1 in both operands follows similar logic as the 0 case. The only way they
// could have been all ones is if 'max(lhs.bitWidth, rhs.bitWidth)' is also
// one. Their addition leads to a carry to the next bit which propagates
// through the entire adder. All bits in the result from bits
// 'max(lhs.bitWidth, rhs.bitWidth) + 1' onwards are guaranteed to be 1.
if (lhs.extType == rhs.extType)
return {ExtType::SEXT, std::max(lhs.bitWidth, rhs.bitWidth) + 1};

// Otherwise, conservative add logic applies.
return addWidth(lhs, rhs);
}

/// Transfer function for mul operations or alike.
Expand Down
75 changes: 56 additions & 19 deletions test/Transforms/HandshakeOptimizeBitwidths/arith-forward.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,62 @@ handshake.func @subiFW(%arg0: !handshake.channel<i8>, %arg1: !handshake.channel<

// -----

// CHECK-LABEL: handshake.func @subiFW_ui_si(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i8>, %[[VAL_1:.*]]: !handshake.channel<i8>,
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : <i8> to <i10>
// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : <i8> to <i10>
// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : <i10>
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i10> to <i32>
// CHECK: end %[[VAL_6]] : <i32>
// CHECK: }
handshake.func @subiFW_ui_si(%arg0: !handshake.channel<i8>, %arg1: !handshake.channel<i8>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%ext0 = extsi %arg0 : <i8> to <i32>
%ext1 = extui %arg1 : <i8> to <i32>
%res = subi %ext0, %ext1 : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @subiFW_si_ui(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : <i16> to <i18>
// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : <i8> to <i18>
// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : <i18>
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i18> to <i32>
// CHECK: end %[[VAL_6]] : <i32>
// CHECK: }
handshake.func @subiFW_si_ui(%arg0: !handshake.channel<i8>, %arg1: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%ext0 = extsi %arg0 : <i8> to <i32>
%ext1 = extui %arg1 : <i16> to <i32>
%res = subi %ext0, %ext1 : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @subiFW_ui_ui(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : <i16> to <i17>
// CHECK: %[[VAL_4:.*]] = extui %[[VAL_0]] {handshake.bb = 0 : ui32} : <i8> to <i17>
// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : <i17>
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i17> to <i32>
// CHECK: end %[[VAL_6]] : <i32>
// CHECK: }
handshake.func @subiFW_ui_ui(%arg0: !handshake.channel<i8>, %arg1: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%ext0 = extui %arg0 : <i8> to <i32>
%ext1 = extui %arg1 : <i16> to <i32>
%res = subi %ext0, %ext1 : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @muliFW(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
Expand Down Expand Up @@ -505,25 +561,6 @@ handshake.func @cmpi_ui_si2(%arg0: !handshake.channel<i8>, %arg1: !handshake.cha

// -----

// CHECK-LABEL: handshake.func @subiFW_extui(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: !handshake.channel<i16>,
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel<i32> attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} {
// CHECK: %[[VAL_3:.*]] = extui %[[VAL_1]] {handshake.bb = 0 : ui32} : <i16> to <i17>
// CHECK: %[[VAL_4:.*]] = extsi %[[VAL_0]] {handshake.bb = 0 : ui32} : <i8> to <i17>
// CHECK: %[[VAL_5:.*]] = subi %[[VAL_4]], %[[VAL_3]] : <i17>
// CHECK: %[[VAL_6:.*]] = extsi %[[VAL_5]] : <i17> to <i32>
// CHECK: end %[[VAL_6]] : <i32>
// CHECK: }
handshake.func @subiFW_extui(%arg0: !handshake.channel<i8>, %arg1: !handshake.channel<i16>, %start: !handshake.control<>) -> !handshake.channel<i32> {
%ext0 = extsi %arg0 : <i8> to <i32>
%ext1 = extui %arg1 : <i16> to <i32>
%res = subi %ext0, %ext1 : <i32>
end %res : <i32>
}

// -----

// CHECK-LABEL: handshake.func @cmpi_sge_ui_ui(
// CHECK-SAME: %[[VAL_0:.*]]: !handshake.channel<i8>, %[[VAL_1:.*]]: !handshake.channel<i8>,
// CHECK-SAME: %[[VAL_2:.*]]: !handshake.control<>, ...) -> !handshake.channel<i1> attributes {argNames = ["arg0", "arg1", "start"], resNames = ["out0"]} {
Expand Down
Loading