Skip to content

Commit f690710

Browse files
jishnub authored and lazarusA committed
LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul (JuliaLang#54303)
1 parent dda5799 commit f690710

File tree

3 files changed

+126
-46
lines changed

3 files changed

+126
-46
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -516,32 +516,56 @@ const ⋅ = dot
516516
const × = cross
517517
export ⋅, ×
518518

519+
# Separate the char corresponding to the wrapper from that corresponding to the uplo
520+
# In most cases, the former may be constant-propagated, while the latter usually can't be.
521+
# This improves type-inference in wrap for Symmetric/Hermitian matrices
522+
# A WrapperChar is equivalent to `isuppertri ? uppercase(wrapperchar) : lowercase(wrapperchar)`
523+
struct WrapperChar <: AbstractChar
524+
wrapperchar :: Char
525+
isuppertri :: Bool
526+
end
527+
function Base.Char(w::WrapperChar)
528+
T = w.wrapperchar
529+
if T ∈ ('N', 'T', 'C') # known cases where isuppertri is true
530+
T
531+
else
532+
_isuppertri(w) ? uppercase(T) : lowercase(T)
533+
end
534+
end
535+
Base.codepoint(w::WrapperChar) = codepoint(Char(w))
536+
WrapperChar(n::UInt32) = WrapperChar(Char(n))
537+
WrapperChar(c::Char) = WrapperChar(c, isuppercase(c))
538+
# We extract the wrapperchar so that the result may be constant-propagated
539+
# This doesn't return a value of the same type on purpose
540+
Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar)
541+
Base.lowercase(w::WrapperChar) = lowercase(w.wrapperchar)
542+
_isuppertri(w::WrapperChar) = w.isuppertri
543+
_isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation
544+
_uplosym(x) = _isuppertri(x) ? (:U) : (:L)
545+
519546
wrapper_char(::AbstractArray) = 'N'
520547
wrapper_char(::Adjoint) = 'C'
521548
wrapper_char(::Adjoint{<:Real}) = 'T'
522549
wrapper_char(::Transpose) = 'T'
523-
wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h'
524-
wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's'
525-
wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's'
550+
wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U')
551+
wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U')
552+
wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U')
526553

527554
Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar)
528555
# merge the result of this before return, so that we can type-assert the return such
529556
# that even if the tmerge is inaccurate, inference can still identify that the
530557
# `_generic_matmatmul` signature still matches and doesn't require missing backedges
531-
B = if tA == 'N'
558+
tA_uc = uppercase(tA)
559+
B = if tA_uc == 'N'
532560
A
533-
elseif tA == 'T'
561+
elseif tA_uc == 'T'
534562
transpose(A)
535-
elseif tA == 'C'
563+
elseif tA_uc == 'C'
536564
adjoint(A)
537-
elseif tA == 'H'
538-
Hermitian(A, :U)
539-
elseif tA == 'h'
540-
Hermitian(A, :L)
541-
elseif tA == 'S'
542-
Symmetric(A, :U)
543-
else # tA == 's'
544-
Symmetric(A, :L)
565+
elseif tA_uc == 'H'
566+
Hermitian(A, _uplosym(tA))
567+
elseif tA_uc == 'S'
568+
Symmetric(A, _uplosym(tA))
545569
end
546570
return B::AbstractVecOrMat
547571
end

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -364,28 +364,32 @@ lmul!(A, B)
364364
# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
365365
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
366366
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
367-
if all(in(('N', 'T', 'C')), (tA, tB))
368-
if tA == 'T' && tB == 'N' && A === B
367+
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
368+
# and extract the char corresponding to the wrapper type
369+
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
370+
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
371+
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
372+
if tA_uc == 'T' && tB_uc == 'N' && A === B
369373
return syrk_wrapper!(C, 'T', A, _add)
370-
elseif tA == 'N' && tB == 'T' && A === B
374+
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
371375
return syrk_wrapper!(C, 'N', A, _add)
372-
elseif tA == 'C' && tB == 'N' && A === B
376+
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
373377
return herk_wrapper!(C, 'C', A, _add)
374-
elseif tA == 'N' && tB == 'C' && A === B
378+
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
375379
return herk_wrapper!(C, 'N', A, _add)
376380
else
377381
return gemm_wrapper!(C, tA, tB, A, B, _add)
378382
end
379383
end
380384
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
381385
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
382-
if (tA == 'S' || tA == 's') && tB == 'N'
386+
if tA_uc == 'S' && tB_uc == 'N'
383387
return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
384-
elseif (tB == 'S' || tB == 's') && tA == 'N'
388+
elseif tA_uc == 'N' && tB_uc == 'S'
385389
return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
386-
elseif (tA == 'H' || tA == 'h') && tB == 'N'
390+
elseif tA_uc == 'H' && tB_uc == 'N'
387391
return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
388-
elseif (tB == 'H' || tB == 'h') && tA == 'N'
392+
elseif tA_uc == 'N' && tB_uc == 'H'
389393
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
390394
end
391395
end
@@ -395,7 +399,11 @@ end
395399
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
396400
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
397401
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
398-
if all(in(('N', 'T', 'C')), (tA, tB))
402+
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
403+
# and extract the char corresponding to the wrapper type
404+
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
405+
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
406+
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
399407
gemm_wrapper!(C, tA, tB, A, B, _add)
400408
else
401409
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
@@ -434,18 +442,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
434442
mA == 0 && return y
435443
nA == 0 && return _rmul_or_fill!(y, β)
436444
alpha, beta = promote(α, β, zero(T))
445+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
437446
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
438447
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
439448
!iszero(stride(x, 1)) && # We only check input's stride here.
440-
if tA in ('N', 'T', 'C')
449+
if tA_uc in ('N', 'T', 'C')
441450
return BLAS.gemv!(tA, alpha, A, x, beta, y)
442-
elseif tA in ('S', 's')
451+
elseif tA_uc == 'S'
443452
return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y)
444-
elseif tA in ('H', 'h')
453+
elseif tA_uc == 'H'
445454
return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y)
446455
end
447456
end
448-
if tA in ('S', 's', 'H', 'h')
457+
if tA_uc in ('S', 'H')
449458
# re-wrap again and use plain ('N') matvec mul algorithm,
450459
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
451460
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
@@ -464,14 +473,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
464473
mA == 0 && return y
465474
nA == 0 && return _rmul_or_fill!(y, β)
466475
alpha, beta = promote(α, β, zero(T))
476+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
467477
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
468478
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
469-
stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
479+
stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
470480
!iszero(stride(x, 1))
471481
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
472482
return y
473483
else
474-
Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA)
484+
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
475485
return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β))
476486
end
477487
end
@@ -487,15 +497,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
487497
mA == 0 && return y
488498
nA == 0 && return _rmul_or_fill!(y, β)
489499
alpha, beta = promote(α, β, zero(T))
500+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
490501
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
491502
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
492-
!iszero(stride(x, 1)) && tA in ('N', 'T', 'C')
503+
!iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C')
493504
xfl = reinterpret(reshape, T, x) # Use reshape here.
494505
yfl = reinterpret(reshape, T, y)
495506
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
496507
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
497508
return y
498-
elseif tA in ('S', 's', 'H', 'h')
509+
elseif tA_uc in ('S', 'H')
499510
# re-wrap again and use plain ('N') matvec mul algorithm,
500511
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
501512
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
@@ -504,10 +515,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
504515
end
505516
end
506517

507-
function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
518+
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
519+
# to be concretely inferred
520+
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
508521
_add = MulAddMul()) where {T<:BlasFloat}
509522
nC = checksquare(C)
510-
if tA == 'T'
523+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
524+
if tA_uc == 'T'
511525
(nA, mA) = size(A,1), size(A,2)
512526
tAt = 'N'
513527
else
@@ -542,10 +556,13 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
542556
return gemm_wrapper!(C, tA, tAt, A, A, _add)
543557
end
544558

545-
function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
559+
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
560+
# to be concretely inferred
561+
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
546562
_add = MulAddMul()) where {T<:BlasReal}
547563
nC = checksquare(C)
548-
if tA == 'C'
564+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
565+
if tA_uc == 'C'
549566
(nA, mA) = size(A,1), size(A,2)
550567
tAt = 'N'
551568
else
@@ -581,20 +598,28 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA
581598
return gemm_wrapper!(C, tA, tAt, A, A, _add)
582599
end
583600

584-
function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
601+
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
602+
# makes the calls concretely inferred
603+
Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
585604
A::StridedVecOrMat{T},
586605
B::StridedVecOrMat{T}) where {T<:BlasFloat}
587606
mA, nA = lapack_size(tA, A)
588607
mB, nB = lapack_size(tB, B)
589608
C = similar(B, T, mA, nB)
590-
if all(in(('N', 'T', 'C')), (tA, tB))
609+
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
610+
# and extract the char corresponding to the wrapper type
611+
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
612+
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
613+
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
591614
gemm_wrapper!(C, tA, tB, A, B)
592615
else
593616
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
594617
end
595618
end
596619

597-
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
620+
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
621+
# makes the calls concretely inferred
622+
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
598623
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
599624
_add = MulAddMul()) where {T<:BlasFloat}
600625
mA, nA = lapack_size(tA, A)
@@ -634,7 +659,9 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
634659
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
635660
end
636661

637-
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
662+
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
663+
# makes the calls concretely inferred
664+
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
638665
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
639666
_add = MulAddMul()) where {T<:BlasReal}
640667
mA, nA = lapack_size(tA, A)
@@ -664,13 +691,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
664691

665692
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
666693

694+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
695+
667696
# Make-sure reinterpret-based optimization is BLAS-compatible.
668697
if (alpha isa Union{Bool,T} &&
669698
beta isa Union{Bool,T} &&
670699
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
671700
stride(A, 2) >= size(A, 1) &&
672701
stride(B, 2) >= size(B, 1) &&
673-
stride(C, 2) >= size(C, 1) && tA == 'N')
702+
stride(C, 2) >= size(C, 1) && tA_uc == 'N')
674703
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
675704
return C
676705
end
@@ -703,9 +732,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and
703732
See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref).
704733
"""
705734
function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
706-
if tM == 'N'
735+
tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char
736+
if tM_uc == 'N'
707737
copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src)
708-
elseif tM == 'T'
738+
elseif tM_uc == 'T'
709739
copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src)
710740
else
711741
copy_adjoint!(B, ir_dest, jr_dest, M, jr_src, ir_src)
@@ -734,11 +764,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and
734764
See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref).
735765
"""
736766
function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
737-
if tM == 'N'
767+
tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char
768+
if tM_uc == 'N'
738769
copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src)
739770
else
740771
copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src)
741-
tM == 'C' && conj!(@view B[ir_dest, jr_dest])
772+
tM_uc == 'C' && conj!(@view B[ir_dest, jr_dest])
742773
end
743774
B
744775
end
@@ -751,7 +782,8 @@ end
751782

752783
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
753784
_add::MulAddMul = MulAddMul())
754-
Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA)
785+
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
786+
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
755787
return _generic_matvecmul!(C, ta, Anew, B, _add)
756788
end
757789

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,30 @@ mul_wrappers = [
3030
h(A) = LinearAlgebra.wrap(LinearAlgebra._unwrap(A), LinearAlgebra.wrapper_char(A))
3131
@test @inferred(h(transpose(A))) === transpose(A)
3232
@test @inferred(h(adjoint(A))) === transpose(A)
33+
34+
M = rand(2,2)
35+
for S in (Symmetric(M), Hermitian(M))
36+
@test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === Symmetric(M)
37+
end
38+
M = rand(ComplexF64,2,2)
39+
for S in (Symmetric(M), Hermitian(M))
40+
@test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === S
41+
end
42+
43+
@testset "WrapperChar" begin
44+
@test LinearAlgebra.WrapperChar('c') == 'c'
45+
@test LinearAlgebra.WrapperChar('C') == 'C'
46+
@testset "constant propagation in uppercase/lowercase" begin
47+
v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('C'))))()
48+
@test v isa Val{'C'}
49+
v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('s'))))()
50+
@test v isa Val{'S'}
51+
v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('C'))))()
52+
@test v isa Val{'c'}
53+
v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('s'))))()
54+
@test v isa Val{'s'}
55+
end
56+
end
3357
end
3458

3559
@testset "matrices with zero dimensions" begin

0 commit comments

Comments
 (0)