From 4fe80eece5a51167b13d90a30d8bb20ea7e1c576 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 21 Oct 2024 17:55:08 +0530 Subject: [PATCH 01/12] Fix multiplying a triangular matrix and a Diagonal --- stdlib/LinearAlgebra/src/diagonal.jl | 63 ++++++++++++++++------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1ed599fbb120e..727e0a0151194 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -397,17 +397,9 @@ function lmul!(D::Diagonal, T::Tridiagonal) return T end -@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul) - @inbounds for j in axes(B, 2) - @simd for i in axes(B, 1) - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) - end - end - out -end -_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true -_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true -_has_matching_zeros(out, A) = false +_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true +_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true +_has_matching_storage(out, A) = false function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col) isunit = B isa UnitUpperTriangular 1:min(col-isunit, size(B,1)) @@ -416,31 +408,44 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col) isunit = B isa UnitLowerTriangular col+isunit:size(B,1) end -_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) -_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1 +_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) +_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1 + +@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul) + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + end + end + out +end function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul) isunit = B isa UnitUpperOrUnitLowerTriangular - out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B) + out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B) for j in axes(B, 2) # store the diagonal separately for unit triangular matrices if isunit @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j)) end - # The indices of out corresponding to the stored indices of B + # indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) @inbounds @simd for i in rowrange _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) end - # Fill the indices of out corresponding to the zeros of B + # indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros - if !_has_matching_zeros(out, B) - rowrange = _rowrange_tri_zeros(B, j) - @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + if !_has_matching_storage(out, B) + rowrange = _rowrange_tri_nonstored(B, j) + if haszero(eltype(out)) + _rmul_or_fill!(@view(out[rowrange,j]), _add.beta) + else + @inbounds @simd for i in rowrange + _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + end end end end - return out + out end @inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} @@ -462,7 +467,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) # if both A and out have the same upper/lower triangular structure, # we may directly read and write from the parents - out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) + out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A) for j in axes(A, 2) dja = _add(@inbounds D.diag[j]) # store the diagonal separately for unit triangular matrices @@ -474,12 +479,16 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a @inbounds @simd for i in rowrange _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) end - # Fill the indices of out corresponding to the zeros of A + # indices of out corresponding to the zeros of A # we only fill these if out and A don't have matching zeros - if !_has_matching_zeros(out, A) - rowrange = _rowrange_tri_zeros(A, j) - @inbounds @simd for i in rowrange - _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) + if !_has_matching_storage(out, A) + rowrange = _rowrange_tri_nonstored(A, j) + if haszero(eltype(out)) + _rmul_or_fill!(@view(out[rowrange,j]), _add.beta) + else + @inbounds @simd for i in rowrange + _modify!(_add, A[i,j] * dja, out, (i,j)) + end end end end From ba86176df8635b90e40cbc0a7151389bcfb4b2de Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 25 Oct 2024 12:55:40 +0530 Subject: [PATCH 02/12] Rename _has_matching_storage to _has_matching_zeros --- stdlib/LinearAlgebra/src/diagonal.jl | 61 ++++++++++++---------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 727e0a0151194..e7dc12c6e6d29 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -397,20 +397,6 @@ function lmul!(D::Diagonal, T::Tridiagonal) return T end -_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true -_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true -_has_matching_storage(out, A) = false -function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col) - isunit = B isa UnitUpperTriangular - 1:min(col-isunit, size(B,1)) -end -function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col) - isunit = B isa UnitLowerTriangular - col+isunit:size(B,1) -end -_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) -_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1 - @inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul) @inbounds for j in axes(B, 2) @simd for i in axes(B, 1) @@ -419,29 +405,38 @@ _rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1 end out end +_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true +_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true +_has_matching_zeros(out, A) = false +function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col) + isunit = B isa UnitUpperTriangular + 1:min(col-isunit, size(B,1)) +end +function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col) + isunit = B isa UnitLowerTriangular + col+isunit:size(B,1) +end +_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) +_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1 function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul) isunit = B isa UnitUpperOrUnitLowerTriangular - out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B) + out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B) for j in axes(B, 2) # store the diagonal separately for unit triangular matrices if isunit @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j)) end - # indices of out corresponding to the stored indices of B + # The indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) @inbounds @simd for i in rowrange _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) end - # indices of out corresponding to the zeros of B + # Fill the indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros - if !_has_matching_storage(out, B) - rowrange = _rowrange_tri_nonstored(B, j) - if haszero(eltype(out)) - _rmul_or_fill!(@view(out[rowrange,j]), _add.beta) - else - @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) - end + if !_has_matching_zeros(out, B) + rowrange = _rowrange_tri_zeros(B, j) + @inbounds @simd for i in rowrange + _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) end end end @@ -467,7 +462,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) # if both A and out have the same upper/lower triangular structure, # we may directly read and write from the parents - out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A) + out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) for j in axes(A, 2) dja = _add(@inbounds D.diag[j]) # store the diagonal separately for unit triangular matrices @@ -479,16 +474,12 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a @inbounds @simd for i in rowrange _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) end - # indices of out corresponding to the zeros of A + # Fill the indices of out corresponding to the zeros of A # we only fill these if out and A don't have matching zeros - if !_has_matching_storage(out, A) - rowrange = _rowrange_tri_nonstored(A, j) - if haszero(eltype(out)) - _rmul_or_fill!(@view(out[rowrange,j]), _add.beta) - else - @inbounds @simd for i in rowrange - _modify!(_add, A[i,j] * dja, out, (i,j)) - end + if !_has_matching_zeros(out, A) + rowrange = _rowrange_tri_zeros(A, j) + @inbounds @simd for i in rowrange + _modify!(_add, A[i,j] * dja, out, (i,j)) end end end From 3315b2ee40cbc42663b6d85b4f61d2e1c725d87d Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 25 Oct 2024 15:34:51 +0530 Subject: [PATCH 03/12] MulAddMul type parameters in _mul methods --- stdlib/LinearAlgebra/src/diagonal.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index e7dc12c6e6d29..d31a4fdf8d299 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -518,15 +518,15 @@ function _mul_diag!(out, A, B, _add) return out end -_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) = +_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add::MulAddMul) = _mul_diag!(out, D, V, _add) -_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) = +_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add::MulAddMul) = _mul_diag!(out, D, B, _add) -_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) = +_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add::MulAddMul) = _mul_diag!(out, A, D, _add) -_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) = +_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add::MulAddMul) = _mul_diag!(C, Da, Db, _add) -_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) = +_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add::MulAddMul) = _mul_diag!(C, Da, Db, _add) function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) From a6c05498aee8f47577bc8b2d0341e6890fe8e54c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 12:01:42 +0530 Subject: [PATCH 04/12] Correct MulAddMul in triangular * Diagonal __muldiag_nonzeroalpha! --- stdlib/LinearAlgebra/src/diagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index d31a4fdf8d299..f8b6e6f3e3c1a 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -479,7 +479,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a if !_has_matching_zeros(out, A) rowrange = _rowrange_tri_zeros(A, j) @inbounds @simd for i in rowrange - _modify!(_add, A[i,j] * dja, out, (i,j)) + _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) end end end From c9d1885bdcd5a8e5f52083a8fde81902cfa05ccf Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 16:53:07 +0530 Subject: [PATCH 05/12] Merge __muldiag! methods --- stdlib/LinearAlgebra/src/diagonal.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index f8b6e6f3e3c1a..2f835c8a20f1b 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -496,7 +496,6 @@ end out end -# ambiguity resolution @inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul) @inbounds for j in axes(D2, 2), i in axes(D2, 1) _modify!(_add, D1.diag[i] * D2[i,j], out, (i,j)) @@ -504,9 +503,9 @@ end out end -# muldiag mainly handles the zero-alpha case, so that we need only +# muldiag mainly handles the zero-alpha case, so that we need only # specialize the non-trivial case -function _mul_diag!(out, A, B, _add) +function _mul_diag!(out, A, B, _add::MulAddMul) require_one_based_indexing(out, A, B) _muldiag_size_check(size(out), size(A), size(B)) alpha, beta = _add.alpha, _add.beta From 45a1a41839ebbd20bf48a8dc89fef4c6d966bb7f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 17:15:56 +0530 Subject: [PATCH 06/12] Fix whitespace --- stdlib/LinearAlgebra/src/diagonal.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 2f835c8a20f1b..26eca7fe01a2e 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -440,7 +440,7 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _a end end end - out + return out end @inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} @@ -503,7 +503,7 @@ end out end -# muldiag mainly handles the zero-alpha case, so that we need only +# muldiag mainly handles the zero-alpha case, so that we need only # specialize the non-trivial case function _mul_diag!(out, A, B, _add::MulAddMul) require_one_based_indexing(out, A, B) From a83e139a312050ede6d997470f5a0556dcc99ecf Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 17:17:23 +0530 Subject: [PATCH 07/12] Remove MulAddMul type from _mul! signature --- stdlib/LinearAlgebra/src/diagonal.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 26eca7fe01a2e..f7ca68f569bf6 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -517,15 +517,15 @@ function _mul_diag!(out, A, B, _add::MulAddMul) return out end -_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add::MulAddMul) = +_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) = _mul_diag!(out, D, V, _add) -_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add::MulAddMul) = +_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) = _mul_diag!(out, D, B, _add) -_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add::MulAddMul) = +_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) = _mul_diag!(out, A, D, _add) -_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add::MulAddMul) = +_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) = _mul_diag!(C, Da, Db, _add) -_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add::MulAddMul) = +_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) = _mul_diag!(C, Da, Db, _add) function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) From 375d7b84a6fda68a77007a08f9a211d73c67df0c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 11:59:54 +0530 Subject: [PATCH 08/12] Replace MulAddMul by alpha,beta in __muldiag methods --- stdlib/LinearAlgebra/src/bidiag.jl | 16 +++-- stdlib/LinearAlgebra/src/diagonal.jl | 100 ++++++++++++++------------- 2 files changed, 64 insertions(+), 52 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index b38a983296065..4edd8da4d19f2 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -472,10 +472,14 @@ const BiTri = Union{Bidiagonal,Tridiagonal} @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = - @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = - @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +for T in (:AbstractMatrix, :Diagonal) + @eval begin + @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::$T, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) + @inline _mul!(C::AbstractMatrix, A::$T, B::BandedMatrix, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) + end +end @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) @@ -831,6 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) C end +_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) @@ -1067,6 +1073,8 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd C end +_mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) _mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul) = _dibimul!(C, A, B, _add) _mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) = diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index f7ca68f569bf6..8a231444c5b6d 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -397,10 +397,10 @@ function lmul!(D::Diagonal, T::Tridiagonal) return T end -@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul) +@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number) @inbounds for j in axes(B, 2) - @simd for i in axes(B, 1) - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + for i in axes(B, 1) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) end end out @@ -418,115 +418,119 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col) end _rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) _rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1 -function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul) +function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, alpha::Number, beta::Number) isunit = B isa UnitUpperOrUnitLowerTriangular out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B) for j in axes(B, 2) # store the diagonal separately for unit triangular matrices if isunit - @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j)) + @inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j)) end # The indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) - @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) + @inbounds for i in rowrange + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) end # Fill the indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros if !_has_matching_zeros(out, B) rowrange = _rowrange_tri_zeros(B, j) - @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + @inbounds for i in rowrange + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) end end end - return out + out end -@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} - beta = _add.beta - _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) +@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) @inbounds for j in axes(A, 2) - dja = _add(D.diag[j]) - @simd for i in axes(A, 1) - _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) + dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j]) + for i in axes(A, 1) + @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end end out end -function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} +function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number) isunit = A isa UnitUpperOrUnitLowerTriangular - beta = _add.beta - # since alpha is multiplied to the diagonal element of D, - # we may skip alpha in the second multiplication by setting ais1 to true - _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) # if both A and out have the same upper/lower triangular structure, # we may directly read and write from the parents - out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) + out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) for j in axes(A, 2) - dja = _add(@inbounds D.diag[j]) + dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j]) # store the diagonal separately for unit triangular matrices if isunit - @inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j)) + # since alpha is multiplied to the diagonal element of D, + # we may skip alpha in the second multiplication by setting ais1 to true + @inbounds @stable_muladdmul _modify!(MulAddMul(true,beta), A[j,j] * dja, out, (j,j)) end # indices of out corresponding to the stored indices of A rowrange = _rowrange_tri_stored(A, j) - @inbounds @simd for i in rowrange - _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) + @inbounds for i in rowrange + # since alpha is multiplied to the diagonal element of D, + # we may skip alpha in the second multiplication by setting ais1 to true + @stable_muladdmul _modify!(MulAddMul(true,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) end # Fill the indices of out corresponding to the zeros of A # we only fill these if out and A don't have matching zeros if !_has_matching_zeros(out, A) rowrange = _rowrange_tri_zeros(A, j) - @inbounds @simd for i in rowrange - _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) + @inbounds for i in rowrange + @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end end end out end -@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul) +@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) d1 = D1.diag d2 = D2.diag outd = out.diag - @inbounds @simd for i in eachindex(d1, d2, outd) - _modify!(_add, d1[i] * d2[i], outd, i) + @inbounds for i in eachindex(d1, d2, outd) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i) end out end -@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul) - @inbounds for j in axes(D2, 2), i in axes(D2, 1) - _modify!(_add, D1.diag[i] * D2[i,j], out, (i,j)) +@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) + @inbounds for j in axes(D1, 2) + dja = @stable_muladdmul MulAddMul(alpha,false)(D2.diag[j]) + for i in axes(D1, 1) + @stable_muladdmul _modify!(MulAddMul(true,beta), D1[i,j] * dja, out, (i,j)) + end end out end -# muldiag mainly handles the zero-alpha case, so that we need only +# muldiag handles the zero-alpha case, so that we need only # specialize the non-trivial case -function _mul_diag!(out, A, B, _add::MulAddMul) +function _mul_diag!(out, A, B, alpha, beta) require_one_based_indexing(out, A, B) _muldiag_size_check(size(out), size(A), size(B)) - alpha, beta = _add.alpha, _add.beta if iszero(alpha) _rmul_or_fill!(out, beta) else - __muldiag_nonzeroalpha!(out, A, B, _add) + __muldiag_nonzeroalpha!(out, A, B, alpha, beta) end return out end -_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) = - _mul_diag!(out, D, V, _add) -_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) = - _mul_diag!(out, D, B, _add) -_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) = - _mul_diag!(out, A, D, _add) -_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) = - _mul_diag!(C, Da, Db, _add) -_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) = - _mul_diag!(C, Da, Db, _add) +_mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = + _mul_diag!(out, D, V, alpha, beta) +_mul!(out::AbstractMatrix, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = + _mul_diag!(out, D, V, alpha, beta) +for MT in (:AbstractMatrix, :AbstractTriangular) + @eval begin + _mul!(out::AbstractMatrix, D::Diagonal, B::$MT, alpha::Number, beta::Number) = + _mul_diag!(out, D, B, alpha, beta) + _mul!(out::AbstractMatrix, A::$MT, D::Diagonal, alpha::Number, beta::Number) = + _mul_diag!(out, A, D, alpha, beta) + end +end +_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = + _mul_diag!(C, Da, Db, alpha, beta) function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) _muldiag_size_check(size(Da), size(A)) From 96b1d1b23b1faa3bd721a3d58cd7ef3f399d18e9 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 14:24:33 +0530 Subject: [PATCH 09/12] Remove annotations --- stdlib/LinearAlgebra/src/diagonal.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 8a231444c5b6d..24b295b46e4e2 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -398,10 +398,8 @@ function lmul!(D::Diagonal, T::Tridiagonal) end @inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number) - @inbounds for j in axes(B, 2) - for i in axes(B, 1) - @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) - end + @inbounds for j in axes(B, 2), i in axes(B, 1) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) end out end From 1014fae19d7fc8de3355d2f6b7bc1c28e969fcce Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 28 Oct 2024 14:50:33 +0530 Subject: [PATCH 10/12] inline __muldiag_nonzeroalpha! for Diagonal destination --- stdlib/LinearAlgebra/src/diagonal.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 24b295b46e4e2..f79d96d12456f 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -491,7 +491,6 @@ end end out end - @inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) @inbounds for j in axes(D1, 2) dja = @stable_muladdmul MulAddMul(alpha,false)(D2.diag[j]) From 3329958680d13e14ea689fd6d35b67e610b450d1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 29 Oct 2024 19:34:03 +0530 Subject: [PATCH 11/12] Collect common method --- stdlib/LinearAlgebra/src/diagonal.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index f79d96d12456f..13db58e10c6f9 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -441,7 +441,7 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al out end -@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) +@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number) @inbounds for j in axes(A, 2) dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j]) for i in axes(A, 1) @@ -450,6 +450,10 @@ end end out end + +function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) + __muldiag_nonzeroalpha_right!(out, A, D, alpha, beta) +end function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number) isunit = A isa UnitUpperOrUnitLowerTriangular # if both A and out have the same upper/lower triangular structure, @@ -482,6 +486,11 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, al out end +# ambiguity resolution +function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) + __muldiag_nonzeroalpha_right!(out, D1, D2, alpha, beta) +end + @inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) d1 = D1.diag d2 = D2.diag @@ -491,15 +500,6 @@ end end out end -@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) - @inbounds for j in axes(D1, 2) - dja = @stable_muladdmul MulAddMul(alpha,false)(D2.diag[j]) - for i in axes(D1, 1) - @stable_muladdmul _modify!(MulAddMul(true,beta), D1[i,j] * dja, out, (i,j)) - end - end - out -end # muldiag handles the zero-alpha case, so that we need only # specialize the non-trivial case From de1d06c3aba8c70d1f1d735b365f88a35d0aa7a9 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 29 Oct 2024 19:37:31 +0530 Subject: [PATCH 12/12] Add annotations back --- stdlib/LinearAlgebra/src/diagonal.jl | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 13db58e10c6f9..7359c2867f6a4 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -398,10 +398,12 @@ function lmul!(D::Diagonal, T::Tridiagonal) end @inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number) - @inbounds for j in axes(B, 2), i in axes(B, 1) - @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) + end end - out + return out end _has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true _has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true @@ -426,29 +428,29 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al end # The indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) - @inbounds for i in rowrange + @inbounds @simd for i in rowrange @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) end # Fill the indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros if !_has_matching_zeros(out, B) rowrange = _rowrange_tri_zeros(B, j) - @inbounds for i in rowrange + @inbounds @simd for i in rowrange @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) end end end - out + return out end @inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number) @inbounds for j in axes(A, 2) dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j]) - for i in axes(A, 1) + @simd for i in axes(A, 1) @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end end - out + return out end function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) @@ -469,7 +471,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, al end # indices of out corresponding to the stored indices of A rowrange = _rowrange_tri_stored(A, j) - @inbounds for i in rowrange + @inbounds @simd for i in rowrange # since alpha is multiplied to the diagonal element of D, # we may skip alpha in the second multiplication by setting ais1 to true @stable_muladdmul _modify!(MulAddMul(true,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) @@ -478,12 +480,12 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, al # we only fill these if out and A don't have matching zeros if !_has_matching_zeros(out, A) rowrange = _rowrange_tri_zeros(A, j) - @inbounds for i in rowrange + @inbounds @simd for i in rowrange @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end end end - out + return out end # ambiguity resolution @@ -495,10 +497,10 @@ end d1 = D1.diag d2 = D2.diag outd = out.diag - @inbounds for i in eachindex(d1, d2, outd) + @inbounds @simd for i in eachindex(d1, d2, outd) @stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i) end - out + return out end # muldiag handles the zero-alpha case, so that we need only