Skip to content

Commit 05a8a6b

Browse files
committed
Small fixes for upstream + CUDA
1 parent db4ae81 commit 05a8a6b

5 files changed

Lines changed: 19 additions & 9 deletions

File tree

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using TensorKit.Factorizations
1010
using TensorKit.Strided
1111
using TensorKit.Factorizations: AbstractAlgorithm
1212
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
13+
import TensorKit: randisometry, rand, randn, twist!
1414

1515
using TensorKit: MatrixAlgebraKit
1616

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
3737
fill!(t, $felt(T))
3838
return t
3939
end
40+
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA <: CuArray}
41+
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
42+
fill!(t, $felt(eltype(TorA)))
43+
return t
44+
end
4045
end
4146
end
4247

ext/TensorKitChainRulesCoreExt/constructors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
function ChainRulesCore.rrule(::Type{TensorMap}, d::DenseArray, args...; kwargs...)
88
function TensorMap_pullback(Δt)
9-
∂d = convert(Array, unthunk(Δt))
9+
∂d = convert(typeof(d), unthunk(Δt))
1010
return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
1111
end
1212
return TensorMap(d, args...; kwargs...), TensorMap_pullback

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function ChainRulesCore.rrule(
8080
end
8181

8282
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
83-
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A))
83+
tr_pullback(Δtr) = NoTangent(), Δtr * id(storagetype(A), domain(A))
8484
return tr(A), tr_pullback
8585
end
8686

src/tensors/tensor.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,19 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
306306
return Base.$fname(codomain domain)
307307
end
308308
function Base.$fname(
309-
::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310-
) where {T, S <: IndexSpace}
311-
return Base.$fname(T, codomain domain)
309+
::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310+
) where {TorA, S <: IndexSpace}
311+
return Base.$fname(TorA, codomain domain)
312+
end
313+
function Base.$fname(
314+
::Type{T}, ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
315+
) where {T, TorA, S <: IndexSpace}
316+
return Base.$fname(TorA, codomain domain)
312317
end
313318
Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V)
314-
function Base.$fname(::Type{T}, V::TensorMapSpace) where {T}
315-
t = TensorMap{T}(undef, V)
316-
fill!(t, $felt(T))
319+
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA}
320+
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
321+
fill!(t, $felt(TorA))
317322
return t
318323
end
319324
end

0 commit comments

Comments
 (0)