From 0f22c0bcdb5892cf2a814898e81dc630e46fbd2e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 16:53:01 +0200 Subject: [PATCH 1/4] Forward rules for TensorOperations calls --- ext/TensorKitMooncakeExt/tensoroperations.jl | 61 +++++++++++++++++++- test/mooncake/tensoroperations.jl | 14 ++--- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 0627a7b2c..9d57bd04b 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -8,7 +8,6 @@ # this permutation is done multiple times. @is_primitive( DefaultCtx, - ReverseMode, Tuple{ typeof(TensorKit.blas_contract!), AbstractTensorMap, @@ -70,6 +69,35 @@ function Mooncake.rrule!!( return C_ΔC, blas_contract_pullback end +function Mooncake.frule!!( + ::Dual{typeof(TensorKit.blas_contract!)}, + C_ΔC::Dual{<:AbstractTensorMap}, + A_ΔA::Dual{<:AbstractTensorMap}, pA_ΔpA::Dual{<:Index2Tuple}, + B_ΔB::Dual{<:AbstractTensorMap}, pB_ΔpB::Dual{<:Index2Tuple}, + pAB_ΔpAB::Dual{<:Index2Tuple}, + α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}, + backend_Δbackend::Dual, allocator_Δallocator::Dual + ) + # prepare arguments + (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) + pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) + α, Δα = Mooncake.extract(α_Δα) + β, Δβ = Mooncake.extract(β_Δβ) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + scale!(ΔC, β) + if !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, C, Δβ) + end + if !isa(Δα, Mooncake.NoTangent) + TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator) + end + TensorKit.blas_contract!(ΔC, ΔA, pA, B, pB, pAB, α, One(), backend, allocator) + TensorKit.blas_contract!(ΔC, A, pA, ΔB, pB, pAB, α, One(), backend, allocator) + TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) + return C_ΔC +end + function blas_contract_pullback_ΔA!( ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) @@ -124,7 +152,6 @@ end # ------------ @is_primitive( DefaultCtx, - ReverseMode, Tuple{ typeof(TensorKit.trace_permute!), AbstractTensorMap, @@ -177,6 +204,36 @@ function Mooncake.rrule!!( return C_ΔC, trace_permute_pullback end +function Mooncake.frule!!( + ::Dual{typeof(TensorKit.trace_permute!)}, + C_ΔC::Dual{<:AbstractTensorMap}, + A_ΔA::Dual{<:AbstractTensorMap}, p_Δp::Dual{<:Index2Tuple}, q_Δq::Dual{<:Index2Tuple}, + α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}, + backend_Δbackend::Dual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, Δα = Mooncake.extract(α_Δα) + β, Δβ = Mooncake.extract(β_Δβ) + backend = primal(backend_Δbackend) + + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + scale!(ΔC, β) + if !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, C, Δβ) + end + if !isa(Δα, Mooncake.NoTangent) + TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend) + end + TensorKit.trace_permute!(ΔC, ΔA, p, q, α, One(), backend) + TensorKit.trace_permute!(C, A, p, q, α, β, backend) + return C_ΔC +end + function trace_permute_pullback_ΔA!( ΔA, ΔC, A, p, q, α, backend ) diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index b97b90c2f..cc99c00f1 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -5,8 +5,6 @@ using VectorInterface: One, Zero using Mooncake using Random - -mode = Mooncake.ReverseMode rng = Random.default_rng() spacelist = ad_spacelist(fast_tests) @@ -53,32 +51,32 @@ eltypes = (Float64, ComplexF64) rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, One(), Zero(), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, α, β, TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) if !(T <: Real) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, real(α), real(β), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, real(A), pA, B, pB, pAB, real(α), real(β), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, real(B), pB, pAB, real(α), real(β), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) end end @@ -102,7 +100,7 @@ eltypes = (Float64, ComplexF64) C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) Mooncake.TestUtils.test_rule( rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); - atol, rtol, mode + atol, rtol ) end end From ffb72f0fe6f47d2199719636851abf08b63a23bd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 17:38:14 +0200 Subject: [PATCH 2/4] Format --- ext/TensorKitMooncakeExt/tensoroperations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 9d57bd04b..8cc9f06ec 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -219,7 +219,7 @@ function Mooncake.frule!!( α, Δα = Mooncake.extract(α_Δα) β, Δβ = Mooncake.extract(β_Δβ) backend = primal(backend_Δbackend) - + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC # dC1 = dβ * C + β * dC scale!(ΔC, β) From 1ee6479d0fbd171ff0ff6a436ca4eb75be5fd1da Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 29 May 2026 13:53:02 +0200 Subject: [PATCH 3/4] Update ext/TensorKitMooncakeExt/tensoroperations.jl Co-authored-by: Lukas Devos --- ext/TensorKitMooncakeExt/tensoroperations.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 8cc9f06ec..c580056df 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -85,9 +85,10 @@ function Mooncake.frule!!( β, Δβ = Mooncake.extract(β_Δβ) backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α - scale!(ΔC, β) - if !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, C, Δβ) + if isa(Δβ, Mooncake.NoTangent) + scale!(ΔC, β) + else + add!(ΔC, C, Δβ, β) end if !isa(Δα, Mooncake.NoTangent) TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator) From 4008c4d1546583911abb9f8960589fbc5fed2b85 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 29 May 2026 13:57:08 +0200 Subject: [PATCH 4/4] Apply suggestion to trace_permute also --- ext/TensorKitMooncakeExt/tensoroperations.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index c580056df..0bfe16643 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -223,9 +223,10 @@ function Mooncake.frule!!( # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC # dC1 = dβ * C + β * dC - scale!(ΔC, β) - if !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, C, Δβ) + if isa(Δβ, Mooncake.NoTangent) + scale!(ΔC, β) + else + add!(ΔC, C, Δβ, β) end if !isa(Δα, Mooncake.NoTangent) TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend)