@@ -283,12 +283,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283283end
284284
285285# TensorMap multiplication
286- function LinearAlgebra. mul! (tC:: AbstractTensorMap ,
287- tA:: AbstractTensorMap ,
288- tB:: AbstractTensorMap , α= true , β= false )
286+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
287+ tB:: AbstractTensorMap ,
288+ α:: Number , β:: Number ,
289+ backend:: AbstractBackend = TO. DefaultBackend ())
290+ if backend isa TO. DefaultBackend
291+ newbackend = TO. select_backend (mul!, tC, tA, tB)
292+ return mul! (tC, tA, tB, α, β, newbackend)
293+ elseif backend isa TO. NoBackend # error for missing backend
294+ TC = typeof (tC)
295+ TA = typeof (tA)
296+ TB = typeof (tB)
297+ throw (ArgumentError (" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
298+ else # error for unknown backend
299+ TC = typeof (tC)
300+ TA = typeof (tA)
301+ TB = typeof (tB)
302+ throw (ArgumentError (" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
303+ end
304+ end
305+
306+ function TO. select_backend (:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
307+ B:: AbstractTensorMap )
308+ return TensorKitBackend ()
309+ end
310+
311+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
312+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
313+ backend:: TensorKitBackend )
289314 compose (space (tA), space (tB)) == space (tC) ||
290315 throw (SpaceMismatch (lazy " $(space(tC)) ≠ $(space(tA)) * $(space(tB))" ))
291316
317+ scheduler = backend. blockscheduler
318+ if isnothing (scheduler)
319+ return sequential_mul! (tC, tA, tB, α, β)
320+ else
321+ return threaded_mul! (tC, tA, tB, α, β, scheduler)
322+ end
323+ end
324+
325+ function sequential_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
326+ tB:: AbstractTensorMap , α:: Number , β:: Number )
292327 iterC = blocks (tC)
293328 iterA = blocks (tA)
294329 iterB = blocks (tB)
@@ -310,13 +345,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
310345 elseif cB < cC
311346 nextB = iterate (iterB, stateB)
312347 else
313- if β != one (β)
348+ if ! isone (β)
314349 rmul! (C, β)
315350 end
316351 nextC = iterate (iterC, stateC)
317352 end
318353 else
319- if β != one (β)
354+ if ! isone (β)
320355 rmul! (C, β)
321356 end
322357 nextC = iterate (iterC, stateC)
@@ -325,7 +360,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325360 return tC
326361end
327362
328- # TODO : consider spawning threads for different blocks, support backends
363+ function threaded_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
364+ α:: Number , β:: Number , scheduler:: Scheduler )
365+ # obtain cached data before multithreading
366+ bCs, bAs, bBs = blocks (tC), blocks (tA), blocks (tB)
367+
368+ tforeach (blocksectors (tC); scheduler) do c
369+ if haskey (bAs, c) # then also bBs should have it
370+ mul! (bCs[c], bAs[c], bBs[c], α, β)
371+ elseif ! isone (β)
372+ scale! (bCs[c], β)
373+ end
374+ end
375+
376+ return tC
377+ end
329378
330379# TensorMap inverse
331380function Base. inv (t:: AbstractTensorMap )
0 commit comments