Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace MulAddMul by alpha,beta in __muldiag #56360

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
16 changes: 12 additions & 4 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,14 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
# `alpha`/`beta` entry point for a banded matrix times a vector: the scalars are packed
# into a `MulAddMul` through `@stable_muladdmul` and the call is forwarded to the
# 4-argument `_mul!`.
@inline function _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
# Generate the `alpha`/`beta` `_mul!` forwarders that pair a `BandedMatrix` with either
# a generic `AbstractMatrix` or a `Diagonal`, with the banded operand on either side.
# Every generated method packs the scalars into a `MulAddMul` through
# `@stable_muladdmul` and forwards to the 4-argument `_mul!`.
for Other in (:AbstractMatrix, :Diagonal)
    @eval begin
        # banded operand on the left
        @inline function _mul!(C::AbstractMatrix, A::BandedMatrix, B::$Other,
                               alpha::Number, beta::Number)
            return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
        end
        # banded operand on the right
        @inline function _mul!(C::AbstractMatrix, A::$Other, B::BandedMatrix,
                               alpha::Number, beta::Number)
            return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
        end
    end
end
# `alpha`/`beta` entry point for a product of two banded matrices: the scalars are packed
# into a `MulAddMul` through `@stable_muladdmul` and the call is forwarded to the
# 4-argument `_mul!`.
@inline function _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix,
                       alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end

Expand Down Expand Up @@ -831,6 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
C
end

# `alpha`/`beta` entry point for `BiTriSym * Diagonal`: builds the `MulAddMul` through
# `@stable_muladdmul` and hands off to the `MulAddMul`-based `_mul!` method.
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
Expand Down Expand Up @@ -1067,6 +1073,8 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
C
end

# `alpha`/`beta` entry point for `Diagonal * BiTriSym`: builds the `MulAddMul` through
# `@stable_muladdmul` and hands off to the `MulAddMul`-based `_mul!` methods.
function _mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number)
    return @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end
# `Diagonal * Bidiagonal` with a ready-made `MulAddMul`: delegate to `_dibimul!`.
function _mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul)
    return _dibimul!(C, A, B, _add)
end
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
Expand Down
100 changes: 51 additions & 49 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,13 @@ function lmul!(D::Diagonal, T::Tridiagonal)
return T
end

@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
end
end
out
return out
end
_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
Expand All @@ -418,116 +418,118 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
end
_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, alpha::Number, beta::Number)
isunit = B isa UnitUpperOrUnitLowerTriangular
out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B)
for j in axes(B, 2)
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
@inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j))
end
# The indices of out corresponding to the stored indices of B
rowrange = _rowrange_tri_stored(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
# Fill the indices of out corresponding to the zeros of B
# we only fill these if out and B don't have matching zeros
if !_has_matching_zeros(out, B)
rowrange = _rowrange_tri_zeros(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
end
end
end
return out
end

@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
beta = _add.beta
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j])
@simd for i in axes(A, 1)
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
end
end
out
return out
end

# Generic `A * D` case (the `Diagonal` is the right-hand factor): defer to the shared
# right-multiplication kernel and return whatever it returns.
__muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) =
    __muldiag_nonzeroalpha_right!(out, A, D, alpha, beta)
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number)
isunit = A isa UnitUpperOrUnitLowerTriangular
beta = _add.beta
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
for j in axes(A, 2)
dja = _add(@inbounds D.diag[j])
dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j])
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
@inbounds @stable_muladdmul _modify!(MulAddMul(true,beta), A[j,j] * dja, out, (j,j))
end
# indices of out corresponding to the stored indices of A
rowrange = _rowrange_tri_stored(A, j)
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
@stable_muladdmul _modify!(MulAddMul(true,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
end
# Fill the indices of out corresponding to the zeros of A
# we only fill these if out and A don't have matching zeros
if !_has_matching_zeros(out, A)
rowrange = _rowrange_tri_zeros(A, j)
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
end
end
end
out
return out
end

# Ambiguity resolution: when both factors are `Diagonal`, the "Diagonal on the left" and
# "Diagonal on the right" methods would both apply. Treat the second factor as the
# diagonal one and reuse the right-multiplication kernel.
__muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) =
    __muldiag_nonzeroalpha_right!(out, D1, D2, alpha, beta)

@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number)
d1 = D1.diag
d2 = D2.diag
outd = out.diag
@inbounds @simd for i in eachindex(d1, d2, outd)
_modify!(_add, d1[i] * d2[i], outd, i)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i)
end
out
end

# ambiguity resolution
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
@inbounds for j in axes(D2, 2), i in axes(D2, 1)
_modify!(_add, D1.diag[i] * D2[i,j], out, (i,j))
end
out
return out
end

# muldiag mainly handles the zero-alpha case, so that we need only
# muldiag handles the zero-alpha case, so that we need only
# specialize the non-trivial case
function _mul_diag!(out, A, B, _add)
function _mul_diag!(out, A, B, alpha, beta)
require_one_based_indexing(out, A, B)
_muldiag_size_check(size(out), size(A), size(B))
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
__muldiag_nonzeroalpha!(out, A, B, _add)
__muldiag_nonzeroalpha!(out, A, B, alpha, beta)
end
return out
end

_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) =
_mul_diag!(out, D, V, _add)
_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) =
_mul_diag!(out, D, B, _add)
_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) =
_mul_diag!(out, A, D, _add)
_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) =
_mul_diag!(C, Da, Db, _add)
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
_mul_diag!(C, Da, Db, _add)
# `Diagonal * vector` with `alpha`/`beta`, for a destination that is either a vector or
# a matrix; both methods forward unchanged to `_mul_diag!`.
for OutT in (:AbstractVector, :AbstractMatrix)
    @eval function _mul!(out::$OutT, D::Diagonal, V::AbstractVector,
                         alpha::Number, beta::Number)
        return _mul_diag!(out, D, V, alpha, beta)
    end
end
# Generate the `alpha`/`beta` `_mul!` methods for a `Diagonal` multiplied with a generic
# or triangular matrix, with the `Diagonal` on either side. All of them forward to
# `_mul_diag!`.
for MatT in (:AbstractMatrix, :AbstractTriangular)
    @eval begin
        # Diagonal on the left
        function _mul!(out::AbstractMatrix, D::Diagonal, B::$MatT,
                       alpha::Number, beta::Number)
            return _mul_diag!(out, D, B, alpha, beta)
        end
        # Diagonal on the right
        function _mul!(out::AbstractMatrix, A::$MatT, D::Diagonal,
                       alpha::Number, beta::Number)
            return _mul_diag!(out, A, D, alpha, beta)
        end
    end
end
# `Diagonal * Diagonal` into an `AbstractMatrix` destination with `alpha`/`beta`:
# forward to `_mul_diag!`.
function _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
    return _mul_diag!(C, Da, Db, alpha, beta)
end

function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
_muldiag_size_check(size(Da), size(A))
Expand Down