Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions ext/TensorKitMooncakeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# this permutation is done multiple times.
@is_primitive(
DefaultCtx,
ReverseMode,
Tuple{
typeof(TensorKit.blas_contract!),
AbstractTensorMap,
Expand Down Expand Up @@ -70,6 +69,36 @@ function Mooncake.rrule!!(
return C_ΔC, blas_contract_pullback
end

function Mooncake.frule!!(
::Dual{typeof(TensorKit.blas_contract!)},
C_ΔC::Dual{<:AbstractTensorMap},
A_ΔA::Dual{<:AbstractTensorMap}, pA_ΔpA::Dual{<:Index2Tuple},
B_ΔB::Dual{<:AbstractTensorMap}, pB_ΔpB::Dual{<:Index2Tuple},
pAB_ΔpAB::Dual{<:Index2Tuple},
α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
backend_Δbackend::Dual, allocator_Δallocator::Dual
)
# prepare arguments
(C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
α, Δα = Mooncake.extract(α_Δα)
β, Δβ = Mooncake.extract(β_Δβ)
backend, allocator = primal.((backend_Δbackend, allocator_Δallocator))
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if isa(Δβ, Mooncake.NoTangent)
scale!(ΔC, β)
else
add!(ΔC, C, Δβ, β)
end
if !isa(Δα, Mooncake.NoTangent)
TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator)
end
TensorKit.blas_contract!(ΔC, ΔA, pA, B, pB, pAB, α, One(), backend, allocator)
TensorKit.blas_contract!(ΔC, A, pA, ΔB, pB, pAB, α, One(), backend, allocator)
TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
return C_ΔC
end

function blas_contract_pullback_ΔA!(
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
)
Expand Down Expand Up @@ -124,7 +153,6 @@ end
# ------------
@is_primitive(
DefaultCtx,
ReverseMode,
Tuple{
typeof(TensorKit.trace_permute!),
AbstractTensorMap,
Expand Down Expand Up @@ -177,6 +205,37 @@ function Mooncake.rrule!!(
return C_ΔC, trace_permute_pullback
end

function Mooncake.frule!!(
::Dual{typeof(TensorKit.trace_permute!)},
C_ΔC::Dual{<:AbstractTensorMap},
A_ΔA::Dual{<:AbstractTensorMap}, p_Δp::Dual{<:Index2Tuple}, q_Δq::Dual{<:Index2Tuple},
α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
backend_Δbackend::Dual
)
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
p = primal(p_Δp)
q = primal(q_Δq)
α, Δα = Mooncake.extract(α_Δα)
β, Δβ = Mooncake.extract(β_Δβ)
backend = primal(backend_Δbackend)

# dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC
# dC1 = dβ * C + β * dC
if isa(Δβ, Mooncake.NoTangent)
scale!(ΔC, β)
else
add!(ΔC, C, Δβ, β)
end
if !isa(Δα, Mooncake.NoTangent)
TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend)
end
TensorKit.trace_permute!(ΔC, ΔA, p, q, α, One(), backend)
TensorKit.trace_permute!(C, A, p, q, α, β, backend)
return C_ΔC
end

function trace_permute_pullback_ΔA!(
ΔA, ΔC, A, p, q, α, backend
)
Expand Down
14 changes: 6 additions & 8 deletions test/mooncake/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ using VectorInterface: One, Zero
using Mooncake
using Random


mode = Mooncake.ReverseMode
rng = Random.default_rng()

spacelist = ad_spacelist(fast_tests)
Expand Down Expand Up @@ -53,32 +51,32 @@ eltypes = (Float64, ComplexF64)
rng, TensorKit.blas_contract!,
C, A, pA, B, pB, pAB, One(), Zero(),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, A, pA, B, pB, pAB, α, β,
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
if !(T <: Real)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, A, pA, B, pB, pAB, real(α), real(β),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, real(A), pA, B, pB, pAB, real(α), real(β),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, A, pA, real(B), pB, pAB, real(α), real(β),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
end
end
Expand All @@ -102,7 +100,7 @@ eltypes = (Float64, ComplexF64)
C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
Mooncake.TestUtils.test_rule(
rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend();
atol, rtol, mode
atol, rtol
)
end
end
Expand Down
Loading