@eightysteele
On the `gen3d` branch, we are running into issues upgrading JAX to 0.4.31. We want this upgrade because, per the top answer here, starting with JAX 0.4.31 the function `jax.lax.map` supports a `batch_size` argument, and we would like to pass this `batch_size` parameter at this line. One way to test that this works is to check that all the tests in `tests/gen3d/` pass after adding `batch_size=1000` at the indicated line. More specifically, if this line runs without failing, we should be good to go. (In this blob it looks like we commented that test out, oops! But it should be uncommented and should pass.)
When we change this line to require `jaxlib ==0.4.31` and run `pixi install`, we get:
(gpu) georgematheos@pixi-vm-2:~/b3d$ pixi install
WARN Defined custom mapping channel https://conda.anaconda.org/conda-forge/ is missing from project channels
× failed to solve the conda requirements of 'gpu' 'linux-64'
╰─▶ Cannot solve the request because of: The following packages are incompatible
├─ pytorch ==2.3.0 cuda12* can be installed with any of the following options:
│ └─ pytorch 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 would require
│ └─ cudnn >=8.9.7.29,<9.0a0, which can be installed with any of the following options:
│ └─ cudnn 8.9.7.29
└─ jaxlib ==0.4.31 cuda12* cannot be installed because there are no viable options:
└─ jaxlib 0.4.31 | 0.4.31 | 0.4.31 would require
└─ cudnn >=9.2.1.18,<10.0a0, which cannot be installed because there are no viable options:
└─ cudnn 9.2.1.18 | 9.2.1.18, which conflicts with the versions reported above.