-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfix_XLA_PYTHON_CLIENT_PREALLOCATE_parsing.patch
More file actions
42 lines (38 loc) · 1.71 KB
/
fix_XLA_PYTHON_CLIENT_PREALLOCATE_parsing.patch
File metadata and controls
42 lines (38 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
From 45c41c43c86addba13f8782cfb0e72135d8821e0 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Mon, 18 Dec 2023 19:13:07 -0800
Subject: [PATCH] [XLA:Python] Fix bug that mean XLA_PYTHON_CLIENT_PREALLOCATE
was no longer parsed correctly.
Fixes https://github.com/google/jax/issues/19035
PiperOrigin-RevId: 592074504
---
xla/python/xla_client.py | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py
index 3aeafd1af7d3..57b6f771ed99 100644
--- a/xla/python/xla_client.py
+++ b/xla/python/xla_client.py
@@ -90,10 +90,6 @@ def make_gpu_client(
"""Returns a GPU client. BFC allocator is used by default."""
options = generate_pjrt_gpu_plugin_options()
allocator = options['allocator']
- memory_fraction = (
- options['memory_fraction'] if 'memory_fraction' in options else None
- )
- preallocate = options['preallocate'] if 'preallocate' in options else None
config = _xla.GpuAllocatorConfig()
if allocator == 'default':
config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
@@ -103,9 +99,10 @@ def make_gpu_client(
config.kind = _xla.GpuAllocatorConfig.Kind.BFC
if allocator == 'cuda_async':
config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC
- if memory_fraction:
- config.memory_fraction = float(memory_fraction)
- config.preallocate = preallocate not in ('0', 'false', 'False')
+ if 'memory_fraction' in options:
+ config.memory_fraction = options['memory_fraction']
+ if 'preallocate' in options:
+ config.preallocate = options['preallocate']
register_custom_call_handler('CUDA', _xla.register_custom_call_target)
register_custom_call_handler('ROCM', _xla.register_custom_call_target)