
bitcast_convert of packed constant U8 (uint8) to "packed" types (S4, S2, U4, U2) fails #38964

@janpfeifer

Description


The following minimal StableHLO program reproduces the problem:

  • The U8 value 0xE1 is fed both as an input parameter and as a constant. Bitcast to S4, the low nibble 0x1 should become 1 and the high nibble 0xE should become -2.
  • Both values (presumably identical) are converted to Int4 with stablehlo.bitcast_convert.
  • The results are then converted to Int8 for output/printing.

But the outputs are different!?

module @TestSubByteDTypes_Int4_AsLiteral {
  func.func @main(%x: tensor<1xui8>) -> (tensor<1x2xi8>, tensor<1x2xi8>, tensor<1xi1>) {
    %0 = "stablehlo.constant"() { value = dense<[225]> : tensor<1xui8> } : () -> tensor<1xui8>
    %1 = "stablehlo.compare"(%0, %x) {
      compare_type = #stablehlo<comparison_type UNSIGNED>,
      comparison_direction = #stablehlo<comparison_direction EQ>
    } : (tensor<1xui8>, tensor<1xui8>) -> tensor<1xi1>
    %2 = "stablehlo.bitcast_convert"(%0) : (tensor<1xui8>) -> tensor<1x2xi4>
    %3 = "stablehlo.convert"(%2) : (tensor<1x2xi4>) -> tensor<1x2xi8>
    %4 = "stablehlo.bitcast_convert"(%x) : (tensor<1xui8>) -> tensor<1x2xi4>
    %5 = "stablehlo.convert"(%4) : (tensor<1x2xi4>) -> tensor<1x2xi8>
    "stablehlo.return"(%3, %5, %1) : (tensor<1x2xi8>, tensor<1x2xi8>, tensor<1xi1>) -> ()
  }
}

If the function is fed the same value as the constant, 0xE1 (225 in decimal), one would expect both to be bitcast to the Int4 values [1, -2]. But the outputs of the StableHLO program above are:

  1. [[1, 0]] --> Fed as a constant, the upper nibble (upper 4 bits) is ignored.
  2. [[1, -2]] --> The same value fed as an input parameter is bitcast correctly by stablehlo.bitcast_convert.
  3. [true] --> An (in)sanity check: the input parameter and the constant, both set to [0xE1], compare equal, even though they bitcast_convert to different values.
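To make the expected values concrete, here is a small pure-Python sketch of the bitcast this issue expects for one U8 byte. It assumes elements are taken from the least-significant bits first (the ordering the input-parameter path produced, [1, -2]) and that signed types use two's complement. `unpack_byte` is a hypothetical helper for illustration, not part of XLA or StableHLO.

```python
def unpack_byte(byte: int, bits: int, signed: bool) -> list[int]:
    """Split one uint8 into 8 // bits packed values, least-significant bits first.

    Hypothetical helper: the low-bits-first ordering is an assumption taken
    from the [1, -2] result the input-parameter path produced in this issue.
    """
    if bits not in (2, 4):
        raise ValueError("bits must be 2 or 4")
    if not 0 <= byte <= 0xFF:
        raise ValueError("byte must be in [0, 255]")
    mask = (1 << bits) - 1
    # Extract each bits-wide field, starting from the low end of the byte.
    values = [(byte >> shift) & mask for shift in range(0, 8, bits)]
    if signed:
        # Two's-complement sign extension: fields >= 2**(bits-1) are negative.
        half = 1 << (bits - 1)
        values = [v - (1 << bits) if v >= half else v for v in values]
    return values


print(unpack_byte(0xE1, 4, signed=True))   # S4: [1, -2]
print(unpack_byte(0xE1, 4, signed=False))  # U4: [1, 14]
print(unpack_byte(0xE1, 2, signed=True))   # S2: [1, 0, -2, -1]
```

Under these assumptions the constant path's [1, 0] cannot come from any correct reading of 0xE1: the high nibble 0xE is nonzero, so the second S4 element must be nonzero too.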

P.S.: Note that PJRT's PJRT_Client_BufferFromHostBuffer() doesn't seem to accept buffers of packed S4 as inputs, so they have to be fed as U8 and then converted to S4 with stablehlo.bitcast_convert.
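The U8 workaround means packing S4 values into bytes on the host first. A minimal sketch under the same assumptions (low nibble first, two's complement), plus one more: an odd-length input is padded with a trailing 0. `pack_s4` is a hypothetical helper, and handing the resulting bytes to PJRT as a U8 buffer is not shown.

```python
def pack_s4(values: list[int]) -> bytes:
    """Pack signed 4-bit values two per byte, first value in the low nibble.

    Hypothetical helper; nibble order and zero-padding of odd lengths are
    assumptions, not documented XLA/PJRT behavior.
    """
    if any(not -8 <= v <= 7 for v in values):
        raise ValueError("S4 values must be in [-8, 7]")
    # Pad to an even count so every byte holds exactly two nibbles.
    padded = list(values) + [0] * (len(values) % 2)
    out = bytearray()
    for low, high in zip(padded[0::2], padded[1::2]):
        # & 0x0F keeps the 4-bit two's-complement pattern (e.g. -2 -> 0xE).
        out.append((low & 0x0F) | ((high & 0x0F) << 4))
    return bytes(out)


print(pack_s4([1, -2]).hex())  # e1, i.e. the 0xE1 used in this issue
```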

Labels: bug (Something isn't working), stat:awaiting openxla-eng (Awaiting response from openxla-eng)
