Skip to content

Commit 2d2abc5

Browse files
committed
Small fixes for upstream + CUDA
1 parent db4ae81 commit 2d2abc5

3 files changed

Lines changed: 17 additions & 7 deletions

File tree

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
3737
fill!(t, $felt(T))
3838
return t
3939
end
40+
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA <: CuArray}
41+
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
42+
fill!(t, $felt(eltype(TorA)))
43+
return t
44+
end
4045
end
4146
end
4247

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function ChainRulesCore.rrule(
8080
end
8181

8282
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
83-
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A))
83+
tr_pullback(Δtr) = NoTangent(), Δtr * id(storagetype(A), domain(A))
8484
return tr(A), tr_pullback
8585
end
8686

src/tensors/tensor.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,19 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
306306
return Base.$fname(codomain domain)
307307
end
308308
function Base.$fname(
309-
::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310-
) where {T, S <: IndexSpace}
311-
return Base.$fname(T, codomain domain)
309+
::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310+
) where {TorA, S <: IndexSpace}
311+
return Base.$fname(TorA, codomain domain)
312+
end
313+
function Base.$fname(
314+
::Type{T}, ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
315+
) where {T, TorA, S <: IndexSpace}
316+
return Base.$fname(TorA, codomain domain)
312317
end
313318
Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V)
314-
function Base.$fname(::Type{T}, V::TensorMapSpace) where {T}
315-
t = TensorMap{T}(undef, V)
316-
fill!(t, $felt(T))
319+
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA}
320+
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
321+
fill!(t, $felt(TorA))
317322
return t
318323
end
319324
end

0 commit comments

Comments
 (0)