# Vibe-coded, experimental Elixir macro (`defeinsum`) that brings np.einsum / torch.einsum notation to Elixir Nx.
defmodule MyOps do
  @moduledoc """
  Example tensor contractions declared with the `defeinsum` macro from `NxEinsum`.

  Each `defeinsum(name, spec)` call defines a function `name` from an einsum
  subscript string `spec`. NOTE(review): the formulas in the comments below
  assume `defeinsum` follows np.einsum semantics and takes one positional
  tensor argument per comma-separated input subscript. Confirm this against
  NxEinsum.
  """
  import NxEinsum

  # ---- "ij,ji" family: entry a[i][j] is paired with b[j][i] ----

  # out[i] = sum_j a[i][j] * b[j][i], i.e. the diagonal of a·b.
  defeinsum(row_self_inner, "ij,ji->i")

  # out = sum_{i,j} a[i][j] * b[j][i] = trace(a·b). This is the Frobenius
  # inner product of a with the *transpose* of b, not <a, b>.
  defeinsum(frobenius_inner, "ij,ji->")

  # out[j][i] = a[i][j] * b[j][i], i.e. the element-wise product aᵀ ⊙ b.
  defeinsum(hadamard_transpose, "ij,ji->ji")

  # out = sum_{i,j} a[i][j] * b[j][i] * w[j] = trace(a · diag(w) · b).
  defeinsum(weighted_trace, "ij,ji,j->")

  # ---- "ii,jj" family: both diagonals, contracted independently ----

  # out = trace(a) * trace(b).
  defeinsum(trace_product, "ii,jj->")

  # out[i] = a[i][i] * trace(b), i.e. the row sums of diag_outer.
  defeinsum(diag_outer_row, "ii,jj->i")

  # out[j] = trace(a) * b[j][j], i.e. the column sums of diag_outer.
  defeinsum(diag_outer_col, "ii,jj->j")

  # out[i][j] = a[i][i] * b[j][j], i.e. the outer product of the two diagonals.
  defeinsum(diag_outer, "ii,jj->ij")
end
# Worked example for weighted_trace ("ij,ji,j->"). With a = [[3, 2], [4, 5]],
# b = [[1, -1], [1, 3]] and w = [1, 2], the terms a[i][j] * b[j][i] * w[j]
# for (i, j) = (0,0), (0,1), (1,0), (1,1) are 3 + 4 - 4 + 30.
Nx.tensor([[3, 2], [4, 5]])
|> MyOps.weighted_trace(Nx.tensor([[1, -1], [1, 3]]), Nx.tensor([1, 2]))
|> Nx.to_number()
|> IO.inspect()
# => 33