Commit da54044

feat: add torch.compile to further reduce the reference model's GPU usage in non-sharded and sharded computation

1 parent 00ba99c

File tree

1 file changed: 2 additions, 0 deletions


src/forge/util/ops.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from torch.distributed.tensor.placement_types import Shard


+@torch.compile
 def compute_logprobs(
     logits: torch.Tensor,
     input_ids: torch.Tensor,
@@ -100,6 +101,7 @@ def compute_logprobs(
     return logprobs.reshape(batch_size, seq_len)


+@torch.compile
 def compute_logprobs_parallel(
     logits: DTensor,
     target_ids: torch.Tensor,
