|
8 | 8 | from jaxtyping import Float |
9 | 9 | from torch import Tensor |
10 | 10 |
|
| 11 | +from pyfastkron import fastkrontorch as fktorch |
| 12 | + |
11 | 13 | from linear_operator import settings |
12 | 14 | from linear_operator.operators._linear_operator import IndexType, LinearOperator |
13 | 15 | from linear_operator.operators.dense_linear_operator import to_linear_operator |
def _matmul(
    self: Float[LinearOperator, "*batch M N"],
    rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
    """Multiply the Kronecker product by ``rhs`` without materializing the full matrix.

    Delegates to ``fktorch.gekmm`` (pyfastkron), which multiplies a list of dense
    Kronecker factors against a dense right-hand side.

    :param rhs: Dense right-hand side; either a matrix ``(*batch2, N, C)`` or a
        vector ``(N,)`` as the signature advertises.
    :return: The product, with the trailing dimension squeezed away again when
        ``rhs`` was a vector.
    """
    # The signature still accepts a 1-D rhs; gekmm expects a matrix, so wrap
    # the vector case exactly as the pre-fastkron implementation did.
    is_vec = rhs.ndimension() == 1
    if is_vec:
        rhs = rhs.unsqueeze(-1)

    # NOTE(review): assumes gekmm handles batched rhs the same way the old
    # _matmul helper did — confirm against pyfastkron broadcasting semantics.
    res = fktorch.gekmm([op.to_dense() for op in self.linear_ops], rhs.contiguous())

    if is_vec:
        res = res.squeeze(-1)
    return res
| 274 | + |
def rmatmul(self: Float[LinearOperator, "... M N"],
            rhs: Union[Float[Tensor, "... P M"], Float[Tensor, "... M"], Float[LinearOperator, "... P M"]],
            ) -> Union[Float[Tensor, "... P N"], Float[Tensor, "N"], Float[LinearOperator, "... P N"]]:
    """Right-multiply: compute ``rhs @ self`` using pyfastkron's ``gemkm``.

    :param rhs: Left operand; a matrix ``(..., P, M)`` or a vector ``(M,)`` as
        the signature advertises.
    :return: ``rhs @ self``, with the row dimension squeezed away again when
        ``rhs`` was a vector.
    """
    # Mirror _matmul's vector handling: the signature permits a 1-D rhs, but
    # gemkm expects a matrix, so promote the vector to a single-row matrix.
    is_vec = rhs.ndimension() == 1
    if is_vec:
        rhs = rhs.unsqueeze(-2)

    # NOTE(review): the signature also admits a LinearOperator rhs, but
    # .contiguous() here requires a torch.Tensor — confirm callers never pass
    # a LinearOperator, or densify it first.
    res = fktorch.gemkm(rhs.contiguous(), [op.to_dense() for op in self.linear_ops])

    if is_vec:
        res = res.squeeze(-2)
    return res
279 | 280 |
|
280 | 281 | @cached(name="root_decomposition") |
def _t_matmul(
    self: Float[LinearOperator, "*batch M N"],
    rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
    """Multiply the transpose of the Kronecker product by ``rhs``.

    Uses the identity ``(A ⊗ B)^T = A^T ⊗ B^T``: each dense factor is
    transposed via ``.mT`` and the list is handed to ``fktorch.gekmm``.

    :param rhs: Dense right-hand side ``(*batch2, M, P)`` (or a vector ``(M,)``).
    :return: ``self^T @ rhs``, squeezed back to a vector when ``rhs`` was 1-D.
    """
    # Restore the 1-D guard the pre-fastkron implementation had (see the old
    # is_vec lines this hunk removed): gekmm needs a matrix rhs.
    is_vec = rhs.ndimension() == 1
    if is_vec:
        rhs = rhs.unsqueeze(-1)

    res = fktorch.gekmm([op.to_dense().mT for op in self.linear_ops], rhs.contiguous())

    if is_vec:
        res = res.squeeze(-1)
    return res
369 | 363 |
|
370 | 364 | def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: |
|
0 commit comments