@@ -364,28 +364,32 @@ lmul!(A, B)
364
364
# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
365
365
Base. @constprop :aggressive function generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
366
366
_add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat }
367
- if all (in ((' N' , ' T' , ' C' )), (tA, tB))
368
- if tA == ' T' && tB == ' N' && A === B
367
+ # We convert the chars to uppercase to potentially unwrap a WrapperChar,
368
+ # and extract the char corresponding to the wrapper type
369
+ tA_uc, tB_uc = uppercase (tA), uppercase (tB)
370
+ # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
371
+ if all (map (in ((' N' , ' T' , ' C' )), (tA_uc, tB_uc)))
372
+ if tA_uc == ' T' && tB_uc == ' N' && A === B
369
373
return syrk_wrapper! (C, ' T' , A, _add)
370
- elseif tA == ' N' && tB == ' T' && A === B
374
+ elseif tA_uc == ' N' && tB_uc == ' T' && A === B
371
375
return syrk_wrapper! (C, ' N' , A, _add)
372
- elseif tA == ' C' && tB == ' N' && A === B
376
+ elseif tA_uc == ' C' && tB_uc == ' N' && A === B
373
377
return herk_wrapper! (C, ' C' , A, _add)
374
- elseif tA == ' N' && tB == ' C' && A === B
378
+ elseif tA_uc == ' N' && tB_uc == ' C' && A === B
375
379
return herk_wrapper! (C, ' N' , A, _add)
376
380
else
377
381
return gemm_wrapper! (C, tA, tB, A, B, _add)
378
382
end
379
383
end
380
384
alpha, beta = promote (_add. alpha, _add. beta, zero (T))
381
385
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
382
- if (tA == ' S' || tA == ' s ' ) && tB == ' N'
386
+ if tA_uc == ' S' && tB_uc == ' N'
383
387
return BLAS. symm! (' L' , tA == ' S' ? ' U' : ' L' , alpha, A, B, beta, C)
384
- elseif (tB == ' S ' || tB == ' s ' ) && tA == ' N '
388
+ elseif tA_uc == ' N ' && tB_uc == ' S '
385
389
return BLAS. symm! (' R' , tB == ' S' ? ' U' : ' L' , alpha, B, A, beta, C)
386
- elseif (tA == ' H' || tA == ' h ' ) && tB == ' N'
390
+ elseif tA_uc == ' H' && tB_uc == ' N'
387
391
return BLAS. hemm! (' L' , tA == ' H' ? ' U' : ' L' , alpha, A, B, beta, C)
388
- elseif (tB == ' H ' || tB == ' h ' ) && tA == ' N '
392
+ elseif tA_uc == ' N ' && tB_uc == ' H '
389
393
return BLAS. hemm! (' R' , tB == ' H' ? ' U' : ' L' , alpha, B, A, beta, C)
390
394
end
391
395
end
395
399
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
396
400
Base. @constprop :aggressive function generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
397
401
_add:: MulAddMul = MulAddMul ()) where {T<: BlasReal }
398
- if all (in ((' N' , ' T' , ' C' )), (tA, tB))
402
+ # We convert the chars to uppercase to potentially unwrap a WrapperChar,
403
+ # and extract the char corresponding to the wrapper type
404
+ tA_uc, tB_uc = uppercase (tA), uppercase (tB)
405
+ # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
406
+ if all (map (in ((' N' , ' T' , ' C' )), (tA_uc, tB_uc)))
399
407
gemm_wrapper! (C, tA, tB, A, B, _add)
400
408
else
401
409
_generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
@@ -434,18 +442,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
434
442
mA == 0 && return y
435
443
nA == 0 && return _rmul_or_fill! (y, β)
436
444
alpha, beta = promote (α, β, zero (T))
445
+ tA_uc = uppercase (tA) # potentially convert a WrapperChar to a Char
437
446
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
438
447
stride (A, 1 ) == 1 && abs (stride (A, 2 )) >= size (A, 1 ) &&
439
448
! iszero (stride (x, 1 )) && # We only check input's stride here.
440
- if tA in (' N' , ' T' , ' C' )
449
+ if tA_uc in (' N' , ' T' , ' C' )
441
450
return BLAS. gemv! (tA, alpha, A, x, beta, y)
442
- elseif tA in ( ' S' , ' s ' )
451
+ elseif tA_uc == ' S'
443
452
return BLAS. symv! (tA == ' S' ? ' U' : ' L' , alpha, A, x, beta, y)
444
- elseif tA in ( ' H' , ' h ' )
453
+ elseif tA_uc == ' H'
445
454
return BLAS. hemv! (tA == ' H' ? ' U' : ' L' , alpha, A, x, beta, y)
446
455
end
447
456
end
448
- if tA in (' S' , ' s ' , ' H ' , ' h ' )
457
+ if tA_uc in (' S' , ' H ' )
449
458
# re-wrap again and use plain ('N') matvec mul algorithm,
450
459
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
451
460
return _generic_matvecmul! (y, ' N' , wrap (A, tA), x, MulAddMul (α, β))
@@ -464,14 +473,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
464
473
mA == 0 && return y
465
474
nA == 0 && return _rmul_or_fill! (y, β)
466
475
alpha, beta = promote (α, β, zero (T))
476
+ tA_uc = uppercase (tA) # potentially convert a WrapperChar to a Char
467
477
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
468
478
stride (A, 1 ) == 1 && abs (stride (A, 2 )) >= size (A, 1 ) &&
469
- stride (y, 1 ) == 1 && tA == ' N' && # reinterpret-based optimization is valid only for contiguous `y`
479
+ stride (y, 1 ) == 1 && tA_uc == ' N' && # reinterpret-based optimization is valid only for contiguous `y`
470
480
! iszero (stride (x, 1 ))
471
481
BLAS. gemv! (tA, alpha, reinterpret (T, A), x, beta, reinterpret (T, y))
472
482
return y
473
483
else
474
- Anew, ta = tA in (' S' , ' s ' , ' H ' , ' h ' ) ? (wrap (A, tA), ' N' ) : (A, tA)
484
+ Anew, ta = tA_uc in (' S' , ' H ' ) ? (wrap (A, tA), oftype (tA, ' N' ) ) : (A, tA)
475
485
return _generic_matvecmul! (y, ta, Anew, x, MulAddMul (α, β))
476
486
end
477
487
end
@@ -487,15 +497,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
487
497
mA == 0 && return y
488
498
nA == 0 && return _rmul_or_fill! (y, β)
489
499
alpha, beta = promote (α, β, zero (T))
500
+ tA_uc = uppercase (tA) # potentially convert a WrapperChar to a Char
490
501
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
491
502
stride (A, 1 ) == 1 && abs (stride (A, 2 )) >= size (A, 1 ) &&
492
- ! iszero (stride (x, 1 )) && tA in (' N' , ' T' , ' C' )
503
+ ! iszero (stride (x, 1 )) && tA_uc in (' N' , ' T' , ' C' )
493
504
xfl = reinterpret (reshape, T, x) # Use reshape here.
494
505
yfl = reinterpret (reshape, T, y)
495
506
BLAS. gemv! (tA, alpha, A, xfl[1 , :], beta, yfl[1 , :])
496
507
BLAS. gemv! (tA, alpha, A, xfl[2 , :], beta, yfl[2 , :])
497
508
return y
498
- elseif tA in (' S' , ' s ' , ' H ' , ' h ' )
509
+ elseif tA_uc in (' S' , ' H ' )
499
510
# re-wrap again and use plain ('N') matvec mul algorithm,
500
511
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
501
512
return _generic_matvecmul! (y, ' N' , wrap (A, tA), x, MulAddMul (α, β))
@@ -504,10 +515,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
504
515
end
505
516
end
506
517
507
- function syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
518
+ # the aggressive constprop pushes tA (and the tAt derived from it) into gemm_wrapper!, which is needed for the wrap calls within it
519
+ # to be concretely inferred
520
+ Base. @constprop :aggressive function syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
508
521
_add = MulAddMul ()) where {T<: BlasFloat }
509
522
nC = checksquare (C)
510
- if tA == ' T'
523
+ tA_uc = uppercase (tA) # potentially convert a WrapperChar to a Char
524
+ if tA_uc == ' T'
511
525
(nA, mA) = size (A,1 ), size (A,2 )
512
526
tAt = ' N'
513
527
else
@@ -542,10 +556,13 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
542
556
return gemm_wrapper! (C, tA, tAt, A, A, _add)
543
557
end
544
558
545
- function herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
559
+ # the aggressive constprop pushes tA (and the tAt derived from it) into gemm_wrapper!, which is needed for the wrap calls within it
560
+ # to be concretely inferred
561
+ Base. @constprop :aggressive function herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
546
562
_add = MulAddMul ()) where {T<: BlasReal }
547
563
nC = checksquare (C)
548
- if tA == ' C'
564
+ tA_uc = uppercase (tA) # potentially convert a WrapperChar to a Char
565
+ if tA_uc == ' C'
549
566
(nA, mA) = size (A,1 ), size (A,2 )
550
567
tAt = ' N'
551
568
else
@@ -581,20 +598,28 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA
581
598
return gemm_wrapper! (C, tA, tAt, A, A, _add)
582
599
end
583
600
584
# Allocating front end for `gemm_wrapper!`: sizes a fresh output `C` from the operands and
# their flag characters, then dispatches the multiplication into it.
#
# Arguments:
# - `tA`, `tB`: flag characters describing how `A` and `B` enter the product. Their
#   uppercase forms are compared against 'N', 'T' and 'C'; any other flag goes through the
#   `wrap`-based generic fallback.
# - `A`, `B`: strided operands sharing the element type `T <: BlasFloat`.
#
# Returns whatever the selected kernel (`gemm_wrapper!` or `_generic_matmatmul!`) returns
# for the newly allocated `C`.
#
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
# makes the calls concretely inferred.
Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
                                                  A::StridedVecOrMat{T},
                                                  B::StridedVecOrMat{T}) where {T<:BlasFloat}
    # Only the row count of op(A) and the column count of op(B) size the output.
    mA, _ = lapack_size(tA, A)
    _, nB = lapack_size(tB, B)
    C = similar(B, T, mA, nB)
    # We convert the chars to uppercase to potentially unwrap a WrapperChar,
    # and extract the char corresponding to the wrapper type.
    tA_uc, tB_uc = uppercase(tA), uppercase(tB)
    # The map in all ensures constprop by acting on tA and tB individually,
    # instead of looping over them.
    if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
        return gemm_wrapper!(C, tA, tB, A, B)
    else
        # This method has no `_add` parameter or local, so referencing `_add` here would
        # fail at runtime. Pass the same default that `gemm_wrapper!` uses instead.
        return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul())
    end
end
596
619
597
- function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
620
+ # Aggressive constprop helps propagate the values of tA and tB into wrap, which
621
+ # makes the calls concretely inferred
622
+ Base. @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
598
623
A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
599
624
_add = MulAddMul ()) where {T<: BlasFloat }
600
625
mA, nA = lapack_size (tA, A)
@@ -634,7 +659,9 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
634
659
_generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
635
660
end
636
661
637
- function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
662
+ # Aggressive constprop helps propagate the values of tA and tB into wrap, which
663
+ # makes the calls concretely inferred
664
+ Base. @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
638
665
A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
639
666
_add = MulAddMul ()) where {T<: BlasReal }
640
667
mA, nA = lapack_size (tA, A)
@@ -664,13 +691,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
664
691
665
692
alpha, beta = promote (_add. alpha, _add. beta, zero (T))
666
693
694
+ tA_uc = uppercase (tA) # potentially convert a WrapperChar to a Char
695
+
667
696
# Make-sure reinterpret-based optimization is BLAS-compatible.
668
697
if (alpha isa Union{Bool,T} &&
669
698
beta isa Union{Bool,T} &&
670
699
stride (A, 1 ) == stride (B, 1 ) == stride (C, 1 ) == 1 &&
671
700
stride (A, 2 ) >= size (A, 1 ) &&
672
701
stride (B, 2 ) >= size (B, 1 ) &&
673
- stride (C, 2 ) >= size (C, 1 ) && tA == ' N' )
702
+ stride (C, 2 ) >= size (C, 1 ) && tA_uc == ' N' )
674
703
BLAS. gemm! (tA, tB, alpha, reinterpret (T, A), B, beta, reinterpret (T, C))
675
704
return C
676
705
end
@@ -703,9 +732,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and
703
732
See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref).
704
733
"""
705
734
function copyto! (B:: AbstractVecOrMat , ir_dest:: AbstractUnitRange{Int} , jr_dest:: AbstractUnitRange{Int} , tM:: AbstractChar , M:: AbstractVecOrMat , ir_src:: AbstractUnitRange{Int} , jr_src:: AbstractUnitRange{Int} )
706
- if tM == ' N'
735
+ tM_uc = uppercase (tM) # potentially convert a WrapperChar to a Char
736
+ if tM_uc == ' N'
707
737
copyto! (B, ir_dest, jr_dest, M, ir_src, jr_src)
708
- elseif tM == ' T'
738
+ elseif tM_uc == ' T'
709
739
copy_transpose! (B, ir_dest, jr_dest, M, jr_src, ir_src)
710
740
else
711
741
copy_adjoint! (B, ir_dest, jr_dest, M, jr_src, ir_src)
@@ -734,11 +764,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and
734
764
See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref).
735
765
"""
736
766
function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
    # Uppercasing potentially converts a WrapperChar into a plain Char.
    flag = uppercase(tM)
    if flag != 'N'
        # Any flag other than 'N': do a plain block copy with the two source ranges
        # swapped, rather than calling the flag-less copy_transpose!.
        copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src)
        if flag == 'C'
            # The 'C' flag additionally conjugates the destination block in place.
            conj!(@view B[ir_dest, jr_dest])
        end
    else
        # 'N': defer to the flag-less method with the source ranges in the given order.
        copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src)
    end
    return B
end
751
782
752
783
# Entry point for the generic matrix-vector product. It normalizes the flag/matrix pair
# and forwards everything to `_generic_matvecmul!`.
#
# When the uppercased flag is 'S' or 'H', the matrix is re-wrapped via `wrap(A, tA)` and
# the flag handed on becomes 'N', converted to the type of `tA`. Otherwise `A` and `tA`
# are passed through untouched. Returns the result of `_generic_matvecmul!`.
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                                    _add::MulAddMul = MulAddMul())
    # Uppercasing potentially converts a WrapperChar into a plain Char.
    flag = uppercase(tA)
    if flag == 'S' || flag == 'H'
        wrapped = wrap(A, tA)
        plainflag = oftype(tA, 'N')
        return _generic_matvecmul!(C, plainflag, wrapped, B, _add)
    end
    return _generic_matvecmul!(C, tA, A, B, _add)
end
757
789
0 commit comments