Skip to content

Commit 4e22032

Browse files
kshyattlkdvos
andauthored
Small fixes for upstream + CUDA (#366)
* Small fixes for upstream + CUDA * Update ext/TensorKitChainRulesCoreExt/linalg.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Update src/tensors/tensor.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Update ext/TensorKitCUDAExt/cutensormap.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Remove unneeded method --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent db4ae81 commit 4e22032

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function ChainRulesCore.rrule(
8080
end
8181

8282
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
83-
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A))
83+
tr_pullback(Δtr) = NoTangent(), scale!!(id!(similar(A)), unthunk(Δtr))
8484
return tr(A), tr_pullback
8585
end
8686

src/tensors/tensor.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,14 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
306306
return Base.$fname(codomain domain)
307307
end
308308
function Base.$fname(
309-
::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310-
) where {T, S <: IndexSpace}
311-
return Base.$fname(T, codomain domain)
309+
::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310+
) where {TorA, S <: IndexSpace}
311+
return Base.$fname(TorA, codomain domain)
312312
end
313313
Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V)
314-
function Base.$fname(::Type{T}, V::TensorMapSpace) where {T}
315-
t = TensorMap{T}(undef, V)
316-
fill!(t, $felt(T))
314+
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA}
315+
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
316+
fill!(t, $felt(scalartype(t)))
317317
return t
318318
end
319319
end

0 commit comments

Comments
 (0)