The following minimal StableHLO program reproduces it:
- The U8 value 0xE1 is fed both as an input parameter and as a constant. Bitcast to S4, the low nibble 0x1 should become 1 and the high nibble 0xE should become -2.
- Both values (which should be identical) are converted to S4 (tensor<1x2xi4>) with stablehlo.bitcast_convert.
- The S4 results are then converted to S8 (tensor<1x2xi8>) for output/printing.
But the outputs are different!?
module @TestSubByteDTypes_Int4_AsLiteral {
  func.func @main(%x: tensor<1xui8>) -> (tensor<1x2xi8>, tensor<1x2xi8>, tensor<1xi1>) {
    // Constant 0xE1 (225), the same value the caller feeds through %x.
    %0 = "stablehlo.constant"() { value = dense<[225]> : tensor<1xui8> } : () -> tensor<1xui8>
    // Sanity check: does the constant equal the parameter?
    %1 = "stablehlo.compare"(%0, %x) {
      compare_type = #stablehlo<comparison_type UNSIGNED>,
      comparison_direction = #stablehlo<comparison_direction EQ>
    } : (tensor<1xui8>, tensor<1xui8>) -> tensor<1xi1>
    // Constant path: U8 -> 2 x S4 -> S8.
    %2 = "stablehlo.bitcast_convert"(%0) : (tensor<1xui8>) -> tensor<1x2xi4>
    %3 = "stablehlo.convert"(%2) : (tensor<1x2xi4>) -> tensor<1x2xi8>
    // Parameter path: U8 -> 2 x S4 -> S8.
    %4 = "stablehlo.bitcast_convert"(%x) : (tensor<1xui8>) -> tensor<1x2xi4>
    %5 = "stablehlo.convert"(%4) : (tensor<1x2xi4>) -> tensor<1x2xi8>
    "stablehlo.return"(%3, %5, %1) : (tensor<1x2xi8>, tensor<1x2xi8>, tensor<1xi1>) -> ()
  }
}
If the function is fed the same value as the constant, 0xE1 (225 in decimal), one would expect both to bitcast to the S4 values [1, -2]. But the outputs of the StableHLO program above are:
[[1, 0]] --> Fed as a constant, the upper nibble (upper 4 bits) is ignored, giving 0 instead of -2.
[[1, -2]] --> The same value fed as an input parameter is bitcast correctly by stablehlo.bitcast_convert.
[true] --> An (in)sanity check: the input parameter and the constant, both [0xE1], compare equal, yet they bitcast_convert to different values.
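For reference, the expected values can be reproduced on the host with a small pure-Python sketch. It assumes that bitcast_convert places the low nibble first, which matches the [[1, -2]] output of the input-parameter path on this backend; the helper names s4_from_nibble and unpack_s4 are hypothetical.

```python
def s4_from_nibble(nibble: int) -> int:
    """Interpret a 4-bit value (0..15) as two's-complement S4 (-8..7)."""
    nibble &= 0xF
    return nibble - 16 if nibble >= 8 else nibble


def unpack_s4(byte: int) -> list[int]:
    """Split one U8 byte into two S4 values, low nibble first (assumed order)."""
    low = byte & 0xF
    high = (byte >> 4) & 0xF
    return [s4_from_nibble(low), s4_from_nibble(high)]


# 0xE1: low nibble 0x1 -> 1, high nibble 0xE (0b1110) -> -2.
expected = unpack_s4(0xE1)
# What the constant path returned: as if the high nibble were zero.
constant_path = unpack_s4(0xE1 & 0x0F)

print(expected)       # [1, -2]
print(constant_path)  # [1, 0]
```

Masking the byte with 0x0F before unpacking reproduces the [[1, 0]] result, which is consistent with the upper nibble being dropped on the constant path.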
P.S.: The PJRT method PJRT_Client_BufferFromHostBuffer() does not seem to accept packed S4 buffers as input, so S4 data has to be fed as U8 and then converted to S4 with stablehlo.bitcast_convert.
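The U8 workaround means packing S4 values into bytes on the host before handing them to PJRT. A minimal sketch, again assuming low-nibble-first order; pack_s4 is a hypothetical helper and the PJRT call itself is not shown.

```python
def pack_s4(values: list[int]) -> bytes:
    """Pack signed 4-bit values (-8..7) two per byte, low nibble first (assumed order)."""
    if len(values) % 2:
        values = list(values) + [0]  # pad odd-length input with a zero nibble
    out = bytearray()
    for low, high in zip(values[0::2], values[1::2]):
        for v in (low, high):
            if not -8 <= v <= 7:
                raise ValueError(f"{v} does not fit in S4")
        # Masking with 0xF turns a negative int into its 4-bit two's-complement form.
        out.append((low & 0xF) | ((high & 0xF) << 4))
    return bytes(out)


packed = pack_s4([1, -2])
print(packed.hex())  # e1
```

The resulting bytes would be passed as a U8 buffer and bitcast to S4 inside the program, as is done with %x above.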