Skip to content

Commit 090660b

Browse files
lukebaumanncopybara-github
authored andcommitted
Refactor resharding logic into helper functions.
This change introduces `_reshard_with_sidechannel` and `_reshard_with_ifrt` to encapsulate the different resharding mechanisms used by the `reshard` function. These are internal APIs and should not be depended on. PiperOrigin-RevId: 856749135
1 parent dad55e8 commit 090660b

2 files changed

Lines changed: 106 additions & 4 deletions

File tree

pathwaysutils/experimental/reshard.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,42 @@ def _ifrt_jax_array_reshard(
198198
)
199199

200200

201+
def _reshard_with_sidechannel(
202+
x: Any,
203+
sharding: jax.sharding.Sharding | Any,
204+
*,
205+
donate: bool = False,
206+
may_alias: bool | None = None,
207+
cache_resharding_plans: bool = False,
208+
) -> Any:
209+
"""Reshards `x` to `sharding` using sidechannel."""
210+
return _reshard(
211+
x,
212+
sharding,
213+
donate=donate,
214+
may_alias=may_alias,
215+
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
216+
cache_resharding_plans=cache_resharding_plans,
217+
)
218+
219+
220+
def _reshard_with_ifrt(
221+
x: Any,
222+
sharding: jax.sharding.Sharding | Any,
223+
*,
224+
donate: bool = False,
225+
may_alias: bool | None = None,
226+
) -> Any:
227+
"""Reshards `x` to `sharding` using IFRT."""
228+
return _reshard(
229+
x,
230+
sharding,
231+
donate=donate,
232+
may_alias=may_alias,
233+
jax_array_reshard_fn=_ifrt_jax_array_reshard,
234+
)
235+
236+
201237
def reshard(
202238
x: Any,
203239
sharding: jax.sharding.Sharding | Any,
@@ -231,19 +267,17 @@ def reshard(
231267
A copy of `x` whose sharding is `sharding`.
232268
"""
233269
if pw_jax.ifrt_reshard_available():
234-
return _reshard(
270+
return _reshard_with_ifrt(
235271
x,
236272
sharding,
237273
donate=donate,
238274
may_alias=may_alias,
239-
jax_array_reshard_fn=_ifrt_jax_array_reshard,
240275
)
241276
else:
242-
return _reshard(
277+
return _reshard_with_sidechannel(
243278
x,
244279
sharding,
245280
donate=donate,
246281
may_alias=may_alias,
247-
jax_array_reshard_fn=_sidechannel_jax_array_reshard,
248282
cache_resharding_plans=cache_resharding_plans,
249283
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the resharding utilities in `pathwaysutils.experimental.reshard`."""
16+
17+
import json
18+
from unittest import mock
19+
20+
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
import jax
23+
import jax.numpy as jnp
24+
from pathwaysutils import jax as pw_jax
25+
from pathwaysutils import plugin_executable
26+
from pathwaysutils.experimental import reshard
27+
28+
29+
class ReshardTest(parameterized.TestCase):
30+
31+
@parameterized.parameters(True, False)
32+
def test_reshard_plan_wrapper_donate(self, donate):
33+
avals = [jax.core.ShapedArray((2, 2), jnp.float32)]
34+
devices = jax.devices()
35+
sharding = jax.sharding.SingleDeviceSharding(devices[0])
36+
37+
mock_pe = self.enter_context(
38+
mock.patch.object(plugin_executable, 'PluginExecutable')
39+
)
40+
41+
reshard.ReshardingPlanWrapper(avals, [sharding], [sharding], donate=donate)
42+
43+
mock_pe.assert_called_once()
44+
args, _ = mock_pe.call_args
45+
request = json.loads(args[0])
46+
self.assertEqual(request['reshardRequest']['donateInput'], donate)
47+
48+
@parameterized.parameters(True, False)
49+
def test_ifrt_reshard_donate(self, donate):
50+
x = jnp.array([1, 2])
51+
devices = jax.devices()
52+
sharding = jax.sharding.SingleDeviceSharding(devices[0])
53+
54+
mock_transfer = self.enter_context(
55+
mock.patch.object(pw_jax, 'transfer_to_shardings')
56+
)
57+
self.enter_context(
58+
mock.patch.object(pw_jax, 'ifrt_reshard_available', return_value=True)
59+
)
60+
61+
reshard.reshard(x, sharding, donate=donate)
62+
# Check that transfer_to_shardings was called with donate=True
63+
# Signature: transfer_to_shardings(arrays, shardings, donate)
64+
mock_transfer.assert_called_with(mock.ANY, mock.ANY, donate)
65+
66+
67+
if __name__ == '__main__':
68+
absltest.main()

0 commit comments

Comments
 (0)