Conversation
6b8a952 to
6d0d300
Compare
|
I like the idea! Can we check for the presence of |
6d0d300 to
a24fe51
Compare
|
@rdyro Are you suggesting adding a flag |
e206e87 to
7956e45
Compare
|
I was thinking of using Python's I'm not sure about the name of this function now, |
4d48682 to
4bc029b
Compare
|
Hmmm, Can you test this on the following case: fn = lambda: tree_bytes(jnp.ones((1024,), dtype=jnp.int32))
fn()
jax.jit(fn)()It should report 512 in both cases |
|
@rdyro Shouldn't it be 4096? An int32 has 32 bits / 8 bits per byte = 4 bytes. And 1024 * 4 = 4096. I get 4096 in both cases. |
4bc029b to
780739d
Compare
|
Should have been fn = lambda: tree_bytes(jnp.ones((1024,), dtype=jnp.int4))
fn()
jax.jit(fn)() |
780739d to
85ef02b
Compare
|
For |
|
96697d4 to
53e365f
Compare
|
This still works for the most common dtypes, so let's go with the I added a warning to the function's docstring. Feel free to reword this warning. |
|
Let’s wait until we have a solution on the JAX side, I’ll keep the PR open for now. |
|
Will you be opening an issue for that? |
|
I'll follow up on this internally. For the JAX issue you opened, can you explicitly ask about the use case of getting int4 byte size under jit? |
53e365f to
44714c9
Compare
|
@rdyro In your opinion, what would be the ideal output of Perhaps this suggests that, more generally, we ought to be counting in bits rather than bytes? |
Currently on CPU, GPU and TPU the byte size of However, it's possible that a platform doesn't guarantee packing for int4, I don't think it's possible to have a jit-compatible function counting bytes currently. I believe users interested in RAM/VRAM size should use a custom lambda with I believe fp4 will suffer from the same problem as int4, I'm not sure there's a difference between integer or floating point representations. We'd typically make an assumption that 1 byte is 8 bits, so it shouldn't change the calculation and it doesn't solve the packing representation problem of fp4/int4. Perhaps this function could be I'd prefer not to merge this function into optax. I find the haiku version actively confusing when working with int4 quantized models. |
|
What does @vroulet think? A potential alternative is to say that we're interested only in how much information there is in a pytree (not how it will be packed or laid out on devices, which is hardware-dependent). We can do this by counting bits. A |
89d298f to
f3944d4
Compare
f3944d4 to
cbbd7ca
Compare
|
@rdyro I fixed the issue by switching to a |
cbbd7ca to
ad15aa0
Compare
ad15aa0 to
ebe94c1
Compare
Add a
tree_bytesfunction to the tree utilities, analogous to Haiku's.For context, see #1321 (comment).