From ff832b2b8bec6d070274fcb0d1349e2bb04d41bd Mon Sep 17 00:00:00 2001 From: sarah Date: Wed, 2 Oct 2024 04:00:40 +0200 Subject: [PATCH] replace the old constrained api --- src/iter/mod.rs | 84 +- src/linalg/matmul/mod.rs | 14 +- src/mat/matmut.rs | 21 +- src/mat/matown.rs | 4 +- src/mat/matref.rs | 15 +- src/perm/mod.rs | 16 +- src/perm/permref.rs | 23 +- src/sparse/csc/matmut.rs | 2 +- src/sparse/csc/matown.rs | 2 +- src/sparse/csc/matref.rs | 2 +- src/sparse/csc/symbolic_own.rs | 23 +- src/sparse/csc/symbolic_ref.rs | 99 +- src/sparse/csr/matmut.rs | 2 +- src/sparse/csr/matown.rs | 2 +- src/sparse/csr/matref.rs | 2 +- src/sparse/csr/symbolic_own.rs | 23 +- src/sparse/csr/symbolic_ref.rs | 33 +- src/sparse/linalg/amd.rs | 14 +- src/sparse/linalg/cholesky.rs | 140 +- src/sparse/linalg/lu.rs | 22 +- src/sparse/linalg/matmul.rs | 13 +- src/sparse/linalg/mod.rs | 2 +- src/sparse/linalg/qr.rs | 85 +- src/sparse/linalg/triangular_solve.rs | 2934 +++++++++++++------------ src/sparse/mod.rs | 202 +- src/utils/bound.rs | 375 +++- src/utils/constrained/mat.rs | 201 -- src/utils/constrained/mod.rs | 283 --- src/utils/constrained/perm.rs | 70 - src/utils/constrained/sparse.rs | 292 --- src/utils/mod.rs | 2 - 31 files changed, 2239 insertions(+), 2763 deletions(-) delete mode 100644 src/utils/constrained/mat.rs delete mode 100644 src/utils/constrained/mod.rs delete mode 100644 src/utils/constrained/perm.rs delete mode 100644 src/utils/constrained/sparse.rs diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 0a7d2996..d256999d 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -1,4 +1,6 @@ -use crate::{linalg::entity::GroupFor, ColMut, ColRef, Entity, MatMut, MatRef, RowMut, RowRef}; +use crate::{ + linalg::entity::GroupFor, mat, ColMut, ColRef, Entity, MatMut, MatRef, RowMut, RowRef, Shape, +}; use self::chunks::ChunkPolicy; @@ -173,23 +175,23 @@ pub struct ElemIterMut<'a, E: Entity> { /// Iterator over the columns of a matrix. #[derive(Debug, Clone)] -pub struct ColIter<'a, E: Entity> { - pub(crate) inner: MatRef<'a, E>, +pub struct ColIter<'a, E: Entity, R: Shape = usize> { + pub(crate) inner: MatRef<'a, E, R>, } /// Iterator over the columns of a matrix. #[derive(Debug)] -pub struct ColIterMut<'a, E: Entity> { - pub(crate) inner: MatMut<'a, E>, +pub struct ColIterMut<'a, E: Entity, R: Shape = usize> { + pub(crate) inner: MatMut<'a, E, R>, } /// Iterator over the rows of a matrix. #[derive(Debug, Clone)] -pub struct RowIter<'a, E: Entity> { - pub(crate) inner: MatRef<'a, E>, +pub struct RowIter<'a, E: Entity, C: Shape = usize> { + pub(crate) inner: MatRef<'a, E, usize, C>, } /// Iterator over the rows of a matrix. #[derive(Debug)] -pub struct RowIterMut<'a, E: Entity> { - pub(crate) inner: MatMut<'a, E>, +pub struct RowIterMut<'a, E: Entity, C: Shape = usize> { + pub(crate) inner: MatMut<'a, E, usize, C>, } impl<'a, E: Entity> Iterator for ElemIter<'a, E> { @@ -272,12 +274,22 @@ impl<'a, E: Entity> ExactSizeIterator for ElemIterMut<'a, E> { } } -impl<'a, E: Entity> Iterator for ColIter<'a, E> { - type Item = ColRef<'a, E>; +impl<'a, E: Entity, R: Shape> Iterator for ColIter<'a, E, R> { + type Item = ColRef<'a, E, R>; #[inline] fn next(&mut self) -> Option { - match core::mem::take(&mut self.inner).split_first_col() { + let nrows = self.inner.nrows(); + match core::mem::replace( + &mut self.inner, + mat::from_column_major_slice_generic( + E::faer_map(E::UNIT, |()| &[] as &[E::Unit]), + nrows, + 0, + ), + ) + .split_first_col() + { Some((head, tail)) => { self.inner = tail; Some(head) @@ -291,10 +303,20 @@ impl<'a, E: Entity> Iterator for ColIter<'a, E> { (self.inner.ncols(), Some(self.inner.ncols())) } } -impl<'a, E: Entity> DoubleEndedIterator for ColIter<'a, E> { +impl<'a, E: Entity, R: Shape> DoubleEndedIterator for ColIter<'a, E, R> { #[inline] fn next_back(&mut self) -> Option { - match core::mem::take(&mut self.inner).split_last_col() { + let nrows = self.inner.nrows(); + match core::mem::replace( + &mut self.inner, + mat::from_column_major_slice_generic( + E::faer_map(E::UNIT, |()| &[] as &[E::Unit]), + nrows, + 0, + ), + ) + .split_last_col() + { Some((head, tail)) => { self.inner = tail; Some(head) @@ -303,19 +325,29 @@ impl<'a, E: Entity> DoubleEndedIterator for ColIter<'a, E> { } } } -impl<'a, E: Entity> ExactSizeIterator for ColIter<'a, E> { +impl<'a, E: Entity, R: Shape> ExactSizeIterator for ColIter<'a, E, R> { #[inline] fn len(&self) -> usize { self.inner.ncols() } } -impl<'a, E: Entity> Iterator for ColIterMut<'a, E> { - type Item = ColMut<'a, E>; +impl<'a, E: Entity, R: Shape> Iterator for ColIterMut<'a, E, R> { + type Item = ColMut<'a, E, R>; #[inline] fn next(&mut self) -> Option { - match core::mem::take(&mut self.inner).split_first_col_mut() { + let nrows = self.inner.nrows(); + match core::mem::replace( + &mut self.inner, + mat::from_column_major_slice_mut_generic( + E::faer_map(E::UNIT, |()| &mut [] as &mut [E::Unit]), + nrows, + 0, + ), + ) + .split_first_col_mut() + { Some((head, tail)) => { self.inner = tail; Some(head) @@ -329,10 +361,20 @@ impl<'a, E: Entity> Iterator for ColIterMut<'a, E> { (self.inner.ncols(), Some(self.inner.ncols())) } } -impl<'a, E: Entity> DoubleEndedIterator for ColIterMut<'a, E> { +impl<'a, E: Entity, R: Shape> DoubleEndedIterator for ColIterMut<'a, E, R> { #[inline] fn next_back(&mut self) -> Option { - match core::mem::take(&mut self.inner).split_last_col_mut() { + let nrows = self.inner.nrows(); + match core::mem::replace( + &mut self.inner, + mat::from_column_major_slice_mut_generic( + E::faer_map(E::UNIT, |()| &mut [] as &mut [E::Unit]), + nrows, + 0, + ), + ) + .split_last_col_mut() + { Some((head, tail)) => { self.inner = tail; Some(head) @@ -341,7 +383,7 @@ impl<'a, E: Entity> DoubleEndedIterator for ColIterMut<'a, E> { } } } -impl<'a, E: Entity> ExactSizeIterator for ColIterMut<'a, E> { +impl<'a, E: Entity, R: Shape> ExactSizeIterator for ColIterMut<'a, E, R> { #[inline] fn len(&self) -> usize { self.inner.ncols() diff --git a/src/linalg/matmul/mod.rs b/src/linalg/matmul/mod.rs index ef9010d0..d1cec3e8 100644 --- a/src/linalg/matmul/mod.rs +++ b/src/linalg/matmul/mod.rs @@ -555,22 +555,18 @@ pub mod inner_prod { } } else { with_dim!(nrows, nrows); - with_dim!(ncols, 1); - let zero_idx = ncols.check(0); - let a = crate::utils::constrained::mat::MatRef::new(a.as_2d(), nrows, ncols); - let b = crate::utils::constrained::mat::MatRef::new(b.as_2d(), nrows, ncols); + let a = a.as_shape(nrows); + let b = b.as_shape(nrows); + let mut acc = E::faer_zero(); if conj_lhs == conj_rhs { for i in nrows.indices() { - acc = acc.faer_add(E::faer_mul(a.read(i, zero_idx), b.read(i, zero_idx))); + acc = acc.faer_add(E::faer_mul(a.read(i), b.read(i))); } } else { for i in nrows.indices() { - acc = acc.faer_add(E::faer_mul( - a.read(i, zero_idx).faer_conj(), - b.read(i, zero_idx), - )); + acc = acc.faer_add(E::faer_mul(a.read(i).faer_conj(), b.read(i))); } } acc diff --git a/src/mat/matmut.rs b/src/mat/matmut.rs index 79881c65..52464fac 100644 --- a/src/mat/matmut.rs +++ b/src/mat/matmut.rs @@ -7,6 +7,7 @@ use crate::{ linalg::zip, unzipped, zipped, Idx, IdxInc, Unbind, }; +use core::ops::Range; /// Mutable view over a matrix, similar to a mutable reference to a 2D strided [prim@slice]. /// @@ -1460,6 +1461,18 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatMut<'a, E, R, C> { unsafe { self.into_const().subcols(col_start, ncols).const_cast() } } + #[track_caller] + #[inline(always)] + pub fn subcols_range(self, cols: Range>) -> MatRef<'a, E, R, usize> { + self.into_const().subcols_range(cols) + } + + #[track_caller] + #[inline(always)] + pub fn subcols_range_mut(self, cols: Range>) -> MatMut<'a, E, R, usize> { + unsafe { self.into_const().subcols_range(cols).const_cast() } + } + /// Returns a view over the row at the given index. /// /// # Safety @@ -1755,7 +1768,7 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatMut<'a, E, R, C> { /// Returns an iterator over the columns of the matrix. #[inline] - pub fn col_iter(self) -> iter::ColIter<'a, E> { + pub fn col_iter(self) -> iter::ColIter<'a, E, R> { self.into_const().col_iter() } @@ -1767,9 +1780,11 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatMut<'a, E, R, C> { /// Returns an iterator over the columns of the matrix. #[inline] - pub fn col_iter_mut(self) -> iter::ColIterMut<'a, E> { + pub fn col_iter_mut(self) -> iter::ColIterMut<'a, E, R> { + let nrows = self.nrows(); + let ncols = self.ncols(); iter::ColIterMut { - inner: self.as_dyn_mut(), + inner: self.as_shape_mut(nrows, ncols.unbound()), } } diff --git a/src/mat/matown.rs b/src/mat/matown.rs index 2248eeba..65b42e52 100644 --- a/src/mat/matown.rs +++ b/src/mat/matown.rs @@ -429,7 +429,7 @@ impl Mat { /// Returns an iterator over the columns of the matrix. #[inline] - pub fn col_iter(&self) -> iter::ColIter<'_, E> { + pub fn col_iter(&self) -> iter::ColIter<'_, E, R> { self.as_ref().col_iter() } @@ -441,7 +441,7 @@ impl Mat { /// Returns an iterator over the columns of the matrix. #[inline] - pub fn col_iter_mut(&mut self) -> iter::ColIterMut<'_, E> { + pub fn col_iter_mut(&mut self) -> iter::ColIterMut<'_, E, R> { self.as_mut().col_iter_mut() } diff --git a/src/mat/matref.rs b/src/mat/matref.rs index b774de7c..e2a7e6b1 100644 --- a/src/mat/matref.rs +++ b/src/mat/matref.rs @@ -3,6 +3,7 @@ use crate::{ assert, debug_assert, diag::DiagRef, iter, iter::chunks::ChunkPolicy, unzipped, utils::bound::*, zipped, Idx, IdxInc, Shape, Unbind, }; +use core::ops::Range; use generativity::make_guard; /// Immutable view over a matrix, similar to an immutable reference to a 2D strided [prim@slice]. @@ -818,6 +819,14 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatRef<'a, E, R, C> { unsafe { self.subcols_unchecked(col_start, ncols) } } + #[track_caller] + #[inline(always)] + pub fn subcols_range(self, cols: Range>) -> MatRef<'a, E, R, usize> { + assert!(all(cols.start <= self.ncols(), cols.end <= self.ncols())); + let ncols = cols.end.unbound().saturating_sub(cols.start.unbound()); + unsafe { self.subcols_unchecked(cols.start, ncols) } + } + /// Returns a view over the row at the given index. /// /// # Safety @@ -1167,9 +1176,11 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatRef<'a, E, R, C> { /// Returns an iterator over the columns of the matrix. #[inline] - pub fn col_iter(self) -> iter::ColIter<'a, E> { + pub fn col_iter(self) -> iter::ColIter<'a, E, R> { + let nrows = self.nrows(); + let ncols = self.ncols(); iter::ColIter { - inner: self.as_dyn(), + inner: self.as_shape(nrows, ncols.unbound()), } } diff --git a/src/perm/mod.rs b/src/perm/mod.rs index 1351e8ff..09220069 100644 --- a/src/perm/mod.rs +++ b/src/perm/mod.rs @@ -1,4 +1,4 @@ -use crate::{assert, col::*, linalg::temp_mat_uninit, mat::*, row::*, utils::constrained, *}; +use crate::{assert, col::*, linalg::temp_mat_uninit, mat::*, row::*, *}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; @@ -217,13 +217,11 @@ pub fn permute_rows( with_dim!(m, src.nrows()); with_dim!(n, src.ncols()); - let mut dst = constrained::mat::MatMut::new(dst, m, n); - let src = constrained::mat::MatRef::new(src, m, n); - let perm = constrained::perm::PermRef::new(perm_indices, m).arrays().0; + let mut dst = dst.as_shape_mut(m, n); + let src = src.as_shape(m, n); + let perm = perm_indices.as_shape(m).bound_arrays().0; - if dst.rb().into_inner().row_stride().unsigned_abs() - < dst.rb().into_inner().col_stride().unsigned_abs() - { + if dst.rb().row_stride().unsigned_abs() < dst.rb().col_stride().unsigned_abs() { for j in n.indices() { for i in m.indices() { dst.rb_mut().write(i, j, src.read(perm[i].zx(), j)); @@ -231,8 +229,8 @@ pub fn permute_rows( } } else { for i in m.indices() { - let src_i = src.into_inner().row(perm[i].zx().unbound()); - let mut dst_i = dst.rb_mut().into_inner().row_mut(i.unbound()); + let src_i = src.row(perm[i].zx()); + let mut dst_i = dst.rb_mut().row_mut(i); dst_i.copy_from(src_i); } diff --git a/src/perm/permref.rs b/src/perm/permref.rs index e1e52348..65eb9d47 100644 --- a/src/perm/permref.rs +++ b/src/perm/permref.rs @@ -1,5 +1,8 @@ use super::*; -use crate::assert; +use crate::{ + assert, + utils::bound::{Array, Dim}, +}; /// Immutable permutation matrix view. #[derive(Debug)] @@ -176,3 +179,21 @@ impl<'a, I: Index, N: Shape> PermRef<'a, I, N> { } } } + +impl<'a, 'N, I: Index> PermRef<'a, I, Dim<'N>> { + /// Returns the permutation as an array. + #[inline] + pub fn bound_arrays( + self, + ) -> ( + &'a Array<'N, Idx, I>>, + &'a Array<'N, Idx, I>>, + ) { + unsafe { + ( + &*(self.forward as *const [Idx, I>] as *const Array<'N, Idx, I>>), + &*(self.inverse as *const [Idx, I>] as *const Array<'N, Idx, I>>), + ) + } + } +} diff --git a/src/sparse/csc/matmut.rs b/src/sparse/csc/matmut.rs index 4b914025..91440a35 100644 --- a/src/sparse/csc/matmut.rs +++ b/src/sparse/csc/matmut.rs @@ -349,7 +349,7 @@ impl<'a, I: Index, E: Entity, R: Shape, C: Shape> SparseColMatMut<'a, I, E, R, C /// Returns the row indices. #[inline] - pub fn row_indices(&self) -> &'a [Idx] { + pub fn row_indices(&self) -> &'a [I] { self.symbolic.row_ind } diff --git a/src/sparse/csc/matown.rs b/src/sparse/csc/matown.rs index 9b23d072..4039a496 100644 --- a/src/sparse/csc/matown.rs +++ b/src/sparse/csc/matown.rs @@ -340,7 +340,7 @@ impl SparseColMat { /// Returns the row indices. #[inline] - pub fn row_indices(&self) -> &'_ [Idx] { + pub fn row_indices(&self) -> &'_ [I] { &self.symbolic.row_ind } diff --git a/src/sparse/csc/matref.rs b/src/sparse/csc/matref.rs index 5912d775..e5c0dda2 100644 --- a/src/sparse/csc/matref.rs +++ b/src/sparse/csc/matref.rs @@ -372,7 +372,7 @@ impl<'a, I: Index, E: Entity, R: Shape, C: Shape> SparseColMatRef<'a, I, E, R, C /// Returns the row indices. #[inline] - pub fn row_indices(&self) -> &'a [Idx] { + pub fn row_indices(&self) -> &'a [I] { self.symbolic.row_ind } diff --git a/src/sparse/csc/symbolic_own.rs b/src/sparse/csc/symbolic_own.rs index abd15244..5da6b183 100644 --- a/src/sparse/csc/symbolic_own.rs +++ b/src/sparse/csc/symbolic_own.rs @@ -31,7 +31,7 @@ pub struct SymbolicSparseColMat { pub(crate) ncols: C, pub(crate) col_ptr: alloc::vec::Vec, pub(crate) col_nnz: Option>, - pub(crate) row_ind: alloc::vec::Vec>, + pub(crate) row_ind: alloc::vec::Vec, } impl SymbolicSparseColMat { @@ -47,7 +47,7 @@ impl SymbolicSparseColMat { ncols: C, col_ptrs: alloc::vec::Vec, nnz_per_col: Option>, - row_indices: alloc::vec::Vec>, + row_indices: alloc::vec::Vec, ) -> Self { SymbolicSparseColMatRef::new_checked( nrows, @@ -79,7 +79,7 @@ impl SymbolicSparseColMat { ncols: C, col_ptrs: alloc::vec::Vec, nnz_per_col: Option>, - row_indices: alloc::vec::Vec>, + row_indices: alloc::vec::Vec, ) -> Self { SymbolicSparseColMatRef::new_unsorted_checked( nrows, @@ -110,7 +110,7 @@ impl SymbolicSparseColMat { ncols: C, col_ptrs: alloc::vec::Vec, nnz_per_col: Option>, - row_indices: alloc::vec::Vec>, + row_indices: alloc::vec::Vec, ) -> Self { SymbolicSparseColMatRef::new_unchecked( nrows, @@ -143,7 +143,7 @@ impl SymbolicSparseColMat { C, alloc::vec::Vec, Option>, - alloc::vec::Vec>, + alloc::vec::Vec, ) { ( self.nrows, @@ -242,7 +242,7 @@ impl SymbolicSparseColMat { /// Returns the row indices. #[inline] - pub fn row_indices(&self) -> &[Idx] { + pub fn row_indices(&self) -> &[I] { &self.row_ind } @@ -257,6 +257,17 @@ impl SymbolicSparseColMat { self.as_ref().row_indices_of_col_raw(j) } + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw_unbound(&self, j: Idx) -> &[I] { + self.as_ref().row_indices_of_col_raw_unbound(j) + } + /// Returns the row indices of column `j`. /// /// # Panics diff --git a/src/sparse/csc/symbolic_ref.rs b/src/sparse/csc/symbolic_ref.rs index 6503194e..ab9e6b06 100644 --- a/src/sparse/csc/symbolic_ref.rs +++ b/src/sparse/csc/symbolic_ref.rs @@ -29,7 +29,7 @@ pub struct SymbolicSparseColMatRef<'a, I: Index, R: Shape = usize, C: Shape = us pub(crate) ncols: C, pub(crate) col_ptr: &'a [I], pub(crate) col_nnz: Option<&'a [I]>, - pub(crate) row_ind: &'a [Idx], + pub(crate) row_ind: &'a [I], } impl Copy for SymbolicSparseColMatRef<'_, I, R, C> {} @@ -84,7 +84,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { ncols: C, col_ptrs: &'a [I], nnz_per_col: Option<&'a [I]>, - row_indices: &'a [Idx], + row_indices: &'a [I], ) -> Self { assert!(all( ncols.unbound() <= I::Signed::MAX.zx(), @@ -99,34 +99,30 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { if let Some(nnz_per_col) = nnz_per_col { for (&nnz_j, &[c, c_next]) in zip(nnz_per_col, windows2(col_ptrs)) { assert!(nnz_j <= c_next - c); - if const { !R::IS_BOUND } { - let row_indices = &row_indices[c.zx()..c.zx() + nnz_j.zx()]; - if !row_indices.is_empty() { - let mut i_prev = row_indices[0]; - for &i in &row_indices[1..] { - assert!(i_prev < i); - i_prev = i; - } - let i_prev = i_prev.unbound(); - let nrows = I::truncate(nrows.unbound()); - assert!(i_prev < nrows); + let row_indices = &row_indices[c.zx()..c.zx() + nnz_j.zx()]; + if !row_indices.is_empty() { + let mut i_prev = row_indices[0]; + for &i in &row_indices[1..] { + assert!(i_prev < i); + i_prev = i; } + let i_prev = i_prev.unbound(); + let nrows = I::truncate(nrows.unbound()); + assert!(i_prev < nrows); } } } else { for &[c, c_next] in windows2(col_ptrs) { - if const { !R::IS_BOUND } { - let row_indices = &row_indices[c.zx()..c_next.zx()]; - if !row_indices.is_empty() { - let mut i_prev = row_indices[0]; - for &i in &row_indices[1..] { - assert!(i_prev <= i); - i_prev = i; - } - let i_prev = i_prev.unbound(); - let nrows = I::truncate(nrows.unbound()); - assert!(i_prev < nrows); + let row_indices = &row_indices[c.zx()..c_next.zx()]; + if !row_indices.is_empty() { + let mut i_prev = row_indices[0]; + for &i in &row_indices[1..] { + assert!(i_prev <= i); + i_prev = i; } + let i_prev = i_prev.unbound(); + let nrows = I::truncate(nrows.unbound()); + assert!(i_prev < nrows); } } } @@ -153,7 +149,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { ncols: C, col_ptrs: &'a [I], nnz_per_col: Option<&'a [I]>, - row_indices: &'a [Idx], + row_indices: &'a [I], ) -> Self { #[track_caller] fn check_ptrs( @@ -182,25 +178,23 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { check_ptrs(nrows.unbound(), ncols.unbound(), col_ptrs, nnz_per_col); assert!(col_ptrs[ncols.unbound()].zx() <= row_indices.len()); - if const { !R::IS_BOUND } { - if let Some(nnz_per_col) = nnz_per_col { - for (&nnz_j, &c) in zip(nnz_per_col, col_ptrs) { - for &i in &row_indices[c.zx()..c.zx() + nnz_j.zx()] { - let i = i.unbound(); - let nrows = nrows.unbound(); - assert!(i < I::truncate(nrows)); - } - } - } else { - let c0 = col_ptrs[0].zx(); - let cn = col_ptrs[ncols.unbound()].zx(); - - for &i in &row_indices[c0..cn] { + if let Some(nnz_per_col) = nnz_per_col { + for (&nnz_j, &c) in zip(nnz_per_col, col_ptrs) { + for &i in &row_indices[c.zx()..c.zx() + nnz_j.zx()] { let i = i.unbound(); let nrows = nrows.unbound(); assert!(i < I::truncate(nrows)); } } + } else { + let c0 = col_ptrs[0].zx(); + let cn = col_ptrs[ncols.unbound()].zx(); + + for &i in &row_indices[c0..cn] { + let i = i.unbound(); + let nrows = nrows.unbound(); + assert!(i < I::truncate(nrows)); + } } Self { @@ -224,7 +218,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { ncols: C, col_ptrs: &'a [I], nnz_per_col: Option<&'a [I]>, - row_indices: &'a [Idx], + row_indices: &'a [I], ) -> Self { assert!(all( ncols.unbound() <= ::MAX.zx(), @@ -359,7 +353,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { /// Returns the row indices. #[inline] - pub fn row_indices(&self) -> &'a [Idx] { + pub fn row_indices(&self) -> &'a [I] { self.row_ind } @@ -370,7 +364,22 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { /// Panics if `j >= self.ncols()`. #[inline] #[track_caller] - pub fn row_indices_of_col_raw(&self, j: Idx) -> &'a [Idx] { + pub fn row_indices_of_col_raw(self, j: Idx) -> &'a [Idx] { + unsafe { + let slice = __get_unchecked(self.row_ind, self.col_range(j)); + let len = slice.len(); + core::slice::from_raw_parts(slice.as_ptr() as *const Idx, len) + } + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw_unbound(self, j: Idx) -> &'a [I] { unsafe { __get_unchecked(self.row_ind, self.col_range(j)) } } @@ -382,12 +391,12 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseColMatRef<'a, I, R, C> { #[inline] #[track_caller] pub fn row_indices_of_col( - &self, + self, j: Idx, ) -> impl 'a + Clone + ExactSizeIterator + DoubleEndedIterator> { - self.row_indices_of_col_raw(j).iter().map( + self.row_indices_of_col_raw_unbound(j).iter().map( #[inline(always)] - |&i| unsafe { Idx::::new_unbound(i.unbound().zx()) }, + |&i| unsafe { Idx::::new_unbound(i.zx()) }, ) } diff --git a/src/sparse/csr/matmut.rs b/src/sparse/csr/matmut.rs index 3b3f3613..47a6004a 100644 --- a/src/sparse/csr/matmut.rs +++ b/src/sparse/csr/matmut.rs @@ -348,7 +348,7 @@ impl<'a, I: Index, E: Entity, R: Shape, C: Shape> SparseRowMatMut<'a, I, E, R, C /// Returns the column indices. #[inline] - pub fn col_indices(&self) -> &'a [Idx] { + pub fn col_indices(&self) -> &'a [I] { self.symbolic.col_indices() } diff --git a/src/sparse/csr/matown.rs b/src/sparse/csr/matown.rs index 05e83596..9275d241 100644 --- a/src/sparse/csr/matown.rs +++ b/src/sparse/csr/matown.rs @@ -340,7 +340,7 @@ impl SparseRowMat { /// Returns the column indices. #[inline] - pub fn col_indices(&self) -> &'_ [Idx] { + pub fn col_indices(&self) -> &'_ [I] { self.symbolic.col_indices() } diff --git a/src/sparse/csr/matref.rs b/src/sparse/csr/matref.rs index f2be504c..05b2f8d0 100644 --- a/src/sparse/csr/matref.rs +++ b/src/sparse/csr/matref.rs @@ -307,7 +307,7 @@ impl<'a, I: Index, E: Entity, R: Shape, C: Shape> SparseRowMatRef<'a, I, E, R, C /// Returns the column indices. #[inline] - pub fn col_indices(&self) -> &'a [Idx] { + pub fn col_indices(&self) -> &'a [I] { self.symbolic.col_indices() } diff --git a/src/sparse/csr/symbolic_own.rs b/src/sparse/csr/symbolic_own.rs index b05591fc..3e98c766 100644 --- a/src/sparse/csr/symbolic_own.rs +++ b/src/sparse/csr/symbolic_own.rs @@ -31,7 +31,7 @@ pub struct SymbolicSparseRowMat { pub(crate) ncols: C, pub(crate) row_ptr: alloc::vec::Vec, pub(crate) row_nnz: Option>, - pub(crate) col_ind: alloc::vec::Vec>, + pub(crate) col_ind: alloc::vec::Vec, } impl SymbolicSparseRowMat { @@ -47,7 +47,7 @@ impl SymbolicSparseRowMat { ncols: C, row_ptrs: alloc::vec::Vec, nnz_per_row: Option>, - col_indices: alloc::vec::Vec>, + col_indices: alloc::vec::Vec, ) -> Self { SymbolicSparseRowMatRef::new_checked( nrows, @@ -79,7 +79,7 @@ impl SymbolicSparseRowMat { ncols: C, row_ptrs: alloc::vec::Vec, nnz_per_row: Option>, - col_indices: alloc::vec::Vec>, + col_indices: alloc::vec::Vec, ) -> Self { SymbolicSparseRowMatRef::new_unsorted_checked( nrows, @@ -110,7 +110,7 @@ impl SymbolicSparseRowMat { ncols: C, row_ptrs: alloc::vec::Vec, nnz_per_row: Option>, - col_indices: alloc::vec::Vec>, + col_indices: alloc::vec::Vec, ) -> Self { SymbolicSparseRowMatRef::new_unchecked( nrows, @@ -143,7 +143,7 @@ impl SymbolicSparseRowMat { C, alloc::vec::Vec, Option>, - alloc::vec::Vec>, + alloc::vec::Vec, ) { ( self.nrows, @@ -242,7 +242,7 @@ impl SymbolicSparseRowMat { /// Returns the column indices. #[inline] - pub fn col_indices(&self) -> &[Idx] { + pub fn col_indices(&self) -> &[I] { &self.col_ind } @@ -257,6 +257,17 @@ impl SymbolicSparseRowMat { self.as_ref().col_indices_of_row_raw(i) } + /// Returns the column indices of row `i`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw_unbound(&self, i: Idx) -> &[I] { + self.as_ref().col_indices_of_row_raw_unbound(i) + } + /// Returns the column indices of row `i`. /// /// # Panics diff --git a/src/sparse/csr/symbolic_ref.rs b/src/sparse/csr/symbolic_ref.rs index 35e8444d..f08777e9 100644 --- a/src/sparse/csr/symbolic_ref.rs +++ b/src/sparse/csr/symbolic_ref.rs @@ -29,7 +29,7 @@ pub struct SymbolicSparseRowMatRef<'a, I: Index, R: Shape = usize, C: Shape = us pub(crate) ncols: C, pub(crate) row_ptr: &'a [I], pub(crate) row_nnz: Option<&'a [I]>, - pub(crate) col_ind: &'a [Idx], + pub(crate) col_ind: &'a [I], } impl Copy for SymbolicSparseRowMatRef<'_, I, R, C> {} @@ -84,7 +84,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseRowMatRef<'a, I, R, C> { ncols: C, row_ptrs: &'a [I], nnz_per_row: Option<&'a [I]>, - col_indices: &'a [Idx], + col_indices: &'a [I], ) -> Self { #[track_caller] fn check_ptrs( @@ -167,7 +167,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseRowMatRef<'a, I, R, C> { ncols: C, row_ptrs: &'a [I], nnz_per_row: Option<&'a [I]>, - col_indices: &'a [Idx], + col_indices: &'a [I], ) -> Self { assert!(all( ncols.unbound() <= I::Signed::MAX.zx(), @@ -223,7 +223,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseRowMatRef<'a, I, R, C> { ncols: C, row_ptrs: &'a [I], nnz_per_row: Option<&'a [I]>, - col_indices: &'a [Idx], + col_indices: &'a [I], ) -> Self { assert!(all( ncols.unbound() <= ::MAX.zx(), @@ -316,7 +316,7 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseRowMatRef<'a, I, R, C> { /// Returns the column indices. #[inline] - pub fn col_indices(&self) -> &'a [Idx] { + pub fn col_indices(self) -> &'a [I] { self.col_ind } @@ -327,7 +327,22 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseRowMatRef<'a, I, R, C> { /// Panics if `i >= self.nrows()`. #[inline] #[track_caller] - pub fn col_indices_of_row_raw(&self, i: Idx) -> &'a [Idx] { + pub fn col_indices_of_row_raw(self, i: Idx) -> &'a [Idx] { + unsafe { + let slice = __get_unchecked(self.col_ind, self.row_range(i)); + let len = slice.len(); + core::slice::from_raw_parts(slice.as_ptr() as *const Idx, len) + } + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw_unbound(self, i: Idx) -> &'a [I] { unsafe { __get_unchecked(self.col_ind, self.row_range(i)) } } @@ -339,12 +354,12 @@ impl<'a, I: Index, R: Shape, C: Shape> SymbolicSparseRowMatRef<'a, I, R, C> { #[inline] #[track_caller] pub fn col_indices_of_row( - &self, + self, i: Idx, ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator> { - self.col_indices_of_row_raw(i).iter().map( + self.col_indices_of_row_raw_unbound(i).iter().map( #[inline(always)] - |&i| unsafe { Idx::::new_unbound(i.unbound().zx()) }, + |&i| unsafe { Idx::::new_unbound(i.zx()) }, ) } diff --git a/src/sparse/linalg/amd.rs b/src/sparse/linalg/amd.rs index e0ef08db..2495f38e 100644 --- a/src/sparse/linalg/amd.rs +++ b/src/sparse/linalg/amd.rs @@ -35,11 +35,15 @@ // DAMAGE. use super::{ - ghost::{self, Array, Idx, MaybeIdx}, + ghost, mem::{self, NONE}, windows2, FaerError, Index, SignedIndex, SymbolicSparseColMatRef, }; -use crate::{assert, ComplexField}; +use crate::{ + assert, + utils::bound::{Array, Idx, MaybeIdx}, + ComplexField, +}; use core::{cell::Cell, iter::zip}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; @@ -800,7 +804,7 @@ fn amd_1( with_dim!(N, n); let (t_p, _) = stack.rb_mut().make_raw::(n); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let s_p = Array::from_mut(s_p, N); let t_p = Array::from_mut(t_p, N); @@ -908,7 +912,7 @@ fn preprocess<'out, I: Index>( let w = Array::from_mut(w, N); let flag = Array::from_mut(flag, N); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); mem::fill_zero(w.as_mut()); mem::fill_none(flag.as_mut()); @@ -968,7 +972,7 @@ fn aat( let I = I::Signed::truncate; let zero = I(0); let one = I(1); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let n = *N; diff --git a/src/sparse/linalg/cholesky.rs b/src/sparse/linalg/cholesky.rs index e8e075a1..7fd515b6 100644 --- a/src/sparse/linalg/cholesky.rs +++ b/src/sparse/linalg/cholesky.rs @@ -525,8 +525,8 @@ use super::{ amd::{self, Control, FlopCount}, - ghost::{self, Array, Idx, MaybeIdx}, - ghost_permute_hermitian_unsorted, ghost_permute_hermitian_unsorted_symbolic, make_raw_req, mem, + ghost, ghost_permute_hermitian_unsorted, ghost_permute_hermitian_unsorted_symbolic, + make_raw_req, mem, mem::NONE, nomem, triangular_solve, try_collect, try_zeroed, utils::ghost_adjoint, @@ -541,7 +541,9 @@ pub use crate::linalg::cholesky::{ use crate::{ assert, linalg::{temp_mat_req, temp_mat_uninit}, - unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, SignedIndex, + unzipped, + utils::bound::{Array, Dim, Idx, MaybeIdx}, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, SignedIndex, }; use core::{cell::Cell, iter::zip}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; @@ -599,7 +601,7 @@ pub mod simplicial { ghost_prefactorize_symbolic_cholesky( Array::from_mut(etree, N), Array::from_mut(col_counts, N), - ghost::SymbolicSparseColMatRef::new(A, N, N), + A.as_shape(N, N), stack, ); @@ -608,7 +610,7 @@ pub mod simplicial { fn ereach<'n, 'a, I: Index>( stack: &'a mut Array<'n, I>, - A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, + A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>, etree: &Array<'n, MaybeIdx<'n, I>>, k: Idx<'n, usize>, visited: &mut Array<'n, I::Signed>, @@ -683,7 +685,7 @@ pub mod simplicial { with_dim!(N, n); ghost_factorize_simplicial_symbolic_cholesky( - ghost::SymbolicSparseColMatRef::new(A, N, N), + A.as_shape(N, N), etree.ghost_inner(N), Array::from_ref(col_counts, N), stack, @@ -691,7 +693,7 @@ pub mod simplicial { } pub(crate) fn ghost_factorize_simplicial_symbolic_cholesky<'n, I: Index>( - A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, + A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>, etree: &Array<'n, MaybeIdx<'n, I>>, col_counts: &Array<'n, I>, stack: &mut PodStack, @@ -786,7 +788,7 @@ pub mod simplicial { with_dim!(L_NNZ, l_nnz); let etree = etree.ghost_inner(N); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let eps = regularization.dynamic_regularization_epsilon.faer_abs(); let delta = regularization.dynamic_regularization_delta.faer_abs(); @@ -817,7 +819,7 @@ pub mod simplicial { ); for k in N.indices() { - let reach = ereach(ereach_stack, *A, etree, k, visited); + let reach = ereach(ereach_stack, A.symbolic(), etree, k, visited); for (i, aik) in zip( A.row_indices_of_col(k), @@ -941,7 +943,7 @@ pub mod simplicial { MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice::(etree), N), N, ); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let eps = regularization.dynamic_regularization_epsilon.faer_abs(); let delta = regularization.dynamic_regularization_delta.faer_abs(); @@ -972,7 +974,7 @@ pub mod simplicial { ); for k in N.indices() { - let reach = ereach(ereach_stack, *A, etree, k, visited); + let reach = ereach(ereach_stack, A.symbolic(), etree, k, visited); for (i, aik) in zip( A.row_indices_of_col(k), @@ -1470,7 +1472,7 @@ pub mod supernodal { #[doc(hidden)] pub fn ereach_super<'n, 'nsuper, I: Index>( - A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, + A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>, super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>, index_to_super: &Array<'n, Idx<'nsuper, I>>, current_row_positions: &mut Array<'nsuper, I>, @@ -1500,8 +1502,8 @@ pub mod supernodal { } fn ereach_super_ata<'m, 'n, 'nsuper, I: Index>( - A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - perm: Option>, + A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>, + perm: Option>>, min_col: &Array<'m, MaybeIdx<'n, I>>, super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>, index_to_super: &Array<'n, Idx<'nsuper, I>>, @@ -1513,7 +1515,7 @@ pub mod supernodal { let k_: I = *k.truncate(); visited[index_to_super[k].zx()] = k_.to_signed(); - let fwd = perm.map(|perm| perm.arrays().0); + let fwd = perm.map(|perm| perm.bound_arrays().0); let fwd = |i: Idx<'n, usize>| fwd.map(|fwd| fwd[k].zx()).unwrap_or(i); for i in A.row_indices_of_col(fwd(k)) { let Some(i) = min_col[i].idx() else { continue }; @@ -2334,7 +2336,7 @@ pub mod supernodal { let N = self.nrows().bind(N); let N_SUPERNODES = self.nrows().bind(N_SUPERNODES); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let n = *N; let n_supernodes = *N_SUPERNODES; let none = I::Signed::truncate(NONE); @@ -2387,8 +2389,8 @@ pub mod supernodal { mem::fill_zero(current_row_positions.as_mut()); for s in N_SUPERNODES.indices() { - let k1 = ghost::IdxInclusive::new_checked(self.supernode_begin()[*s].zx(), N); - let k2 = ghost::IdxInclusive::new_checked(self.supernode_end()[*s].zx(), N); + let k1 = ghost::IdxInc::new_checked(self.supernode_begin()[*s].zx(), N); + let k2 = ghost::IdxInc::new_checked(self.supernode_end()[*s].zx(), N); for k in k1.range_to(k2) { ereach_super( @@ -2444,7 +2446,7 @@ pub mod supernodal { assert!(col_counts.len() == n); with_dim!(N, n); ghost_factorize_supernodal_symbolic( - ghost::SymbolicSparseColMatRef::new(A, N, N), + A.as_shape(N, N), None, None, CholeskyInput::A, @@ -2461,8 +2463,8 @@ pub mod supernodal { } pub(crate) fn ghost_factorize_supernodal_symbolic<'m, 'n, I: Index>( - A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - col_perm: Option>, + A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>, + col_perm: Option>>, min_col: Option<&Array<'m, MaybeIdx<'n, I>>>, input: CholeskyInput, etree: &Array<'n, MaybeIdx<'n, I>>, @@ -2871,10 +2873,10 @@ pub mod supernodal { mem::fill_none::(visited.as_mut()); if matches!(input, CholeskyInput::A) { - let A = ghost::SymbolicSparseColMatRef::new(A.into_inner(), N, N); + let A = A.as_shape(N, N); for s in N_SUPERNODES.indices() { - let k1 = ghost::IdxInclusive::new_checked(supernode_begin__[*s].zx(), N); - let k2 = ghost::IdxInclusive::new_checked(supernode_begin__[*s + 1].zx(), N); + let k1 = ghost::IdxInc::new_checked(supernode_begin__[*s].zx(), N); + let k2 = ghost::IdxInc::new_checked(supernode_begin__[*s + 1].zx(), N); for k in k1.range_to(k2) { ereach_super( @@ -2891,8 +2893,8 @@ pub mod supernodal { } else { let min_col = min_col.unwrap(); for s in N_SUPERNODES.indices() { - let k1 = ghost::IdxInclusive::new_checked(supernode_begin__[*s].zx(), N); - let k2 = ghost::IdxInclusive::new_checked(supernode_begin__[*s + 1].zx(), N); + let k1 = ghost::IdxInc::new_checked(supernode_begin__[*s].zx(), N); + let k2 = ghost::IdxInc::new_checked(supernode_begin__[*s + 1].zx(), N); for k in k1.range_to(k2) { ereach_super_ata( @@ -3956,7 +3958,7 @@ pub mod supernodal { fn ghost_prefactorize_symbolic_cholesky<'n, 'out, I: Index>( etree: &'out mut Array<'n, I::Signed>, col_counts: &mut Array<'n, I>, - A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, + A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>, stack: &mut PodStack, ) -> &'out mut Array<'n, MaybeIdx<'n, I>> { let N = A.ncols(); @@ -4254,7 +4256,7 @@ impl SymbolicCholesky { let mut L_values = L_values; let A_nnz = self.A_nnz; - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let (mut new_values, stack) = crate::sparse::linalg::make_raw::(A_nnz, stack); let (new_col_ptr, stack) = stack.make_raw::(n + 1); @@ -4267,7 +4269,7 @@ impl SymbolicCholesky { let A = match self.perm() { Some(perm) => { - let perm = ghost::PermRef::new(perm, N); + let perm = perm.as_shape(N); unsafe { ghost_permute_hermitian_unsorted( new_values.rb_mut(), @@ -4303,7 +4305,7 @@ impl SymbolicCholesky { SymbolicCholeskyRaw::Simplicial(this) => { simplicial::factorize_simplicial_numeric_llt( E::faer_rb_mut(E::faer_as_mut(&mut L_values)), - A.into_inner().into_const(), + A.as_dyn().into_const(), regularization, this, stack, @@ -4312,7 +4314,7 @@ impl SymbolicCholesky { SymbolicCholeskyRaw::Supernodal(this) => { supernodal::factorize_supernodal_numeric_llt( E::faer_rb_mut(E::faer_as_mut(&mut L_values)), - A.into_inner().into_const(), + A.as_dyn().into_const(), regularization, this, parallelism, @@ -4343,7 +4345,7 @@ impl SymbolicCholesky { with_dim!(N, n); let A_nnz = self.A_nnz; - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let (new_signs, stack) = stack.make_raw::( if regularization.dynamic_regularization_signs.is_some() && self.perm().is_some() { @@ -4364,7 +4366,7 @@ impl SymbolicCholesky { let (A, signs) = match self.perm() { Some(perm) => { - let perm = ghost::PermRef::new(perm, N); + let perm = perm.as_shape(N); let A = unsafe { ghost_permute_hermitian_unsorted( new_values.rb_mut(), @@ -4379,7 +4381,7 @@ impl SymbolicCholesky { ) .into_const() }; - let fwd = perm.arrays().0; + let fwd = perm.bound_arrays().0; let signs = regularization.dynamic_regularization_signs.map(|signs| { { let new_signs = Array::from_mut(new_signs, N); @@ -4421,7 +4423,7 @@ impl SymbolicCholesky { SymbolicCholeskyRaw::Simplicial(this) => { simplicial::factorize_simplicial_numeric_ldlt( E::faer_rb_mut(E::faer_as_mut(&mut L_values)), - A.into_inner().into_const(), + A.as_dyn().into_const(), regularization, this, stack, @@ -4430,7 +4432,7 @@ impl SymbolicCholesky { SymbolicCholeskyRaw::Supernodal(this) => { supernodal::factorize_supernodal_numeric_ldlt( E::faer_rb_mut(E::faer_as_mut(&mut L_values)), - A.into_inner().into_const(), + A.as_dyn().into_const(), regularization, this, parallelism, @@ -4463,7 +4465,7 @@ impl SymbolicCholesky { with_dim!(N, n); let A_nnz = self.A_nnz; - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let (new_signs, stack) = stack.make_raw::(if regularization.dynamic_regularization_signs.is_some() { @@ -4483,8 +4485,8 @@ impl SymbolicCholesky { let (A, signs) = match self.perm() { Some(perm) => { - let perm = ghost::PermRef::new(perm, N); - let fwd = perm.arrays().0; + let perm = perm.as_shape(N); + let fwd = perm.bound_arrays().0; let signs = regularization.dynamic_regularization_signs.map(|signs| { { let new_signs = Array::from_mut(new_signs, N); @@ -4547,7 +4549,7 @@ impl SymbolicCholesky { } simplicial::factorize_simplicial_numeric_ldlt( E::faer_rb_mut(E::faer_as_mut(&mut L_values)), - A.into_inner().into_const(), + A.as_dyn().into_const(), regularization, this, stack, @@ -4565,7 +4567,7 @@ impl SymbolicCholesky { E::faer_rb_mut(E::faer_as_mut(&mut subdiag)), perm_forward, perm_inverse, - A.into_inner().into_const(), + A.as_dyn().into_const(), regularization, this, parallelism, @@ -5068,7 +5070,7 @@ pub fn factorize_symbolic_cholesky( assert!(A.nrows() == A.ncols()); with_dim!(N, n); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let req = || -> Result { let n_req = StackReq::try_new::(n)?; @@ -5117,7 +5119,7 @@ pub fn factorize_symbolic_cholesky( SymmetricOrdering::Amd => Some(amd::order_maybe_unsorted( perm_fwd.as_mut().unwrap(), perm_inv.as_mut().unwrap(), - A.into_inner(), + A.as_dyn(), params.amd_params, stack.rb_mut(), )?), @@ -5139,10 +5141,8 @@ pub fn factorize_symbolic_cholesky( new_col_ptr, new_row_ind, A, - ghost::PermRef::new( - PermRef::new_checked(perm_fwd.as_ref().unwrap(), perm_inv.as_ref().unwrap(), n), - N, - ), + PermRef::new_checked(perm_fwd.as_ref().unwrap(), perm_inv.as_ref().unwrap(), n) + .as_shape(N), side, Side::Upper, stack.rb_mut(), @@ -5248,7 +5248,7 @@ pub(crate) mod tests { let mut col_count = vec![zero; n]; { with_dim!(N, n); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let etree = ghost_prefactorize_symbolic_cholesky( Array::from_mut(&mut etree, N), Array::from_mut(&mut col_count, N), @@ -5511,16 +5511,16 @@ pub(crate) mod tests { let mut col_count = vec![zero; n]; with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let etree = ghost_prefactorize_symbolic_cholesky( Array::from_mut(&mut etree, N), Array::from_mut(&mut col_count, N), - *A, + A.symbolic(), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(5 * n))), ); let symbolic = supernodal::ghost_factorize_supernodal_symbolic( - *A, + A.symbolic(), None, None, CholeskyInput::A, @@ -5546,7 +5546,7 @@ pub(crate) mod tests { supernodal::factorize_supernodal_numeric_ldlt( values.col_as_slice_mut(0), - A_lower.into_inner().into_const(), + A_lower.as_dyn().into_const(), Default::default(), &symbolic, Parallelism::None, @@ -5558,7 +5558,7 @@ pub(crate) mod tests { .unwrap(), )), ); - let mut A = sparse_to_dense(A.into_inner()); + let mut A = sparse_to_dense(A.as_dyn()); for j in 0..n { for i in j + 1..n { A.write(i, j, A.read(j, i).faer_conj()); @@ -5617,16 +5617,16 @@ pub(crate) mod tests { let mut col_count = vec![zero; n]; with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let etree = ghost_prefactorize_symbolic_cholesky( Array::from_mut(&mut etree, N), Array::from_mut(&mut col_count, N), - *A, + A.symbolic(), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(5 * n))), ); let symbolic = supernodal::ghost_factorize_supernodal_symbolic( - *A, + A.symbolic(), None, None, CholeskyInput::A, @@ -5652,7 +5652,7 @@ pub(crate) mod tests { supernodal::factorize_supernodal_numeric_ldlt( values.col_as_slice_mut(0), - A_lower.into_inner().into_const(), + A_lower.as_dyn().into_const(), Default::default(), &symbolic, Parallelism::None, @@ -5742,16 +5742,16 @@ pub(crate) mod tests { let mut col_count = vec![zero; n]; with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let etree = ghost_prefactorize_symbolic_cholesky( Array::from_mut(&mut etree, N), Array::from_mut(&mut col_count, N), - *A, + A.symbolic(), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(5 * n))), ); let symbolic = supernodal::ghost_factorize_supernodal_symbolic( - *A, + A.symbolic(), None, None, CholeskyInput::A, @@ -5784,7 +5784,7 @@ pub(crate) mod tests { subdiag.col_as_slice_mut(0), &mut fwd, &mut inv, - A_lower.into_inner().into_const(), + A_lower.as_dyn().into_const(), Default::default(), &symbolic, Parallelism::None, @@ -5883,16 +5883,16 @@ pub(crate) mod tests { let mut col_count = vec![zero; n]; with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let etree = ghost_prefactorize_symbolic_cholesky( Array::from_mut(&mut etree, N), Array::from_mut(&mut col_count, N), - *A, + A.symbolic(), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(5 * n))), ); let symbolic = supernodal::ghost_factorize_supernodal_symbolic( - *A, + A.symbolic(), None, None, CholeskyInput::A, @@ -5925,7 +5925,7 @@ pub(crate) mod tests { subdiag.col_as_slice_mut(0), &mut fwd, &mut inv, - A_lower.into_inner().into_const(), + A_lower.as_dyn().into_const(), Default::default(), &symbolic, Parallelism::None, @@ -6025,16 +6025,16 @@ pub(crate) mod tests { let mut col_count = vec![zero; n]; with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); let etree = ghost_prefactorize_symbolic_cholesky( Array::from_mut(&mut etree, N), Array::from_mut(&mut col_count, N), - *A, + A.symbolic(), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(5 * n))), ); let symbolic = simplicial::ghost_factorize_simplicial_symbolic_cholesky( - *A, + A.symbolic(), etree, Array::from_ref(&col_count, N), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(20 * n))), @@ -6045,14 +6045,14 @@ pub(crate) mod tests { simplicial::factorize_simplicial_numeric_ldlt::( values.col_as_slice_mut(0), - A.into_inner(), + A.as_dyn(), Default::default(), &symbolic, PodStack::new(&mut GlobalPodBuffer::new( simplicial::factorize_simplicial_numeric_ldlt_req::(n).unwrap(), )), ); - let mut A = sparse_to_dense(A.into_inner()); + let mut A = sparse_to_dense(A.as_dyn()); for j in 0..n { for i in j + 1..n { A.write(i, j, A.read(j, i).faer_conj()); diff --git a/src/sparse/linalg/lu.rs b/src/sparse/linalg/lu.rs index b0d45edf..466ca44c 100644 --- a/src/sparse/linalg/lu.rs +++ b/src/sparse/linalg/lu.rs @@ -10,8 +10,6 @@ use super::{ cholesky::simplicial::EliminationTreeRef, colamd::Control, - ghost, - ghost::{Array, Idx, MaybeIdx}, mem::{ NONE, {self}, }, @@ -23,7 +21,11 @@ use crate::{ linalg::{matmul, temp_mat_req, temp_mat_uninit, triangular_solve as solve}, perm::PermRef, sparse::SparseColMatRef, - utils::{slice::*, vec::*}, + utils::{ + bound::{Array, Idx, MaybeIdx}, + slice::*, + vec::*, + }, Conj, MatMut, Parallelism, SignedIndex, }; use core::{iter::zip, mem::MaybeUninit}; @@ -583,7 +585,7 @@ pub mod supernodal { with_dim!(N, n); let I = I::truncate; - let A = ghost::SymbolicSparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let min_col = Array::from_ref( MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(min_col), N), M, @@ -593,7 +595,7 @@ pub mod supernodal { let L = crate::sparse::linalg::cholesky::supernodal::ghost_factorize_supernodal_symbolic( A, - col_perm.map(|perm| ghost::PermRef::new(perm, N)), + col_perm.map(|perm| perm.as_shape(N)), Some(min_col), crate::sparse::linalg::cholesky::supernodal::CholeskyInput::ATA, etree, @@ -2202,7 +2204,7 @@ pub fn factorize_symbolic_lu( with_dim!(M, m); with_dim!(N, n); - let A = ghost::SymbolicSparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let req = || -> Result { let n_req = StackReq::try_new::(n)?; @@ -2243,12 +2245,12 @@ pub fn factorize_symbolic_lu( crate::sparse::linalg::colamd::order( &mut col_perm_fwd, &mut col_perm_inv, - A.into_inner(), + A.as_dyn(), params.colamd_params, stack.rb_mut(), )?; - let col_perm = ghost::PermRef::new(PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n), N); + let col_perm = PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n).as_shape(N); let (new_col_ptr, stack) = stack.make_raw::(m + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); @@ -2322,8 +2324,8 @@ pub fn factorize_symbolic_lu( if threshold == SupernodalThreshold::FORCE_SUPERNODAL { let symbolic = supernodal::factorize_supernodal_symbolic_lu::( - A.into_inner(), - Some(col_perm.into_inner()), + A.as_dyn(), + Some(col_perm.as_shape(n)), &min_col, EliminationTreeRef::<'_, I> { inner: etree }, col_counts, diff --git a/src/sparse/linalg/matmul.rs b/src/sparse/linalg/matmul.rs index 948574cc..dff19155 100644 --- a/src/sparse/linalg/matmul.rs +++ b/src/sparse/linalg/matmul.rs @@ -6,7 +6,6 @@ use super::*; use crate::{ assert, mat::{As2D, As2DMut}, - utils::constrained::{self}, }; use core::cell::UnsafeCell; @@ -310,9 +309,9 @@ pub fn sparse_dense_matmul< with_dim!(n, n); with_dim!(k, k); - let mut acc = constrained::mat::MatMut::new(acc, m, n); - let lhs = constrained::sparse::SparseColMatRef::new(lhs, m, k); - let rhs = constrained::mat::MatRef::new(rhs, k, n); + let mut acc = acc.as_shape_mut(m, n); + let lhs = lhs.as_shape(m, k); + let rhs = rhs.as_shape(k, n); for j in n.indices() { for depth in k.indices() { @@ -401,9 +400,9 @@ pub fn dense_sparse_matmul< with_dim!(m, m); with_dim!(n, n); with_dim!(k, k); - let mut acc = constrained::mat::MatMut::new(acc, m, n); - let lhs = constrained::mat::MatRef::new(lhs, m, k); - let rhs = constrained::sparse::SparseColMatRef::new(rhs, k, n); + let mut acc = acc.as_shape_mut(m, n); + let lhs = lhs.as_shape(m, k); + let rhs = rhs.as_shape(k, n); for i in m.indices() { for j in n.indices() { diff --git a/src/sparse/linalg/mod.rs b/src/sparse/linalg/mod.rs index 8250e131..aa9fdbd6 100644 --- a/src/sparse/linalg/mod.rs +++ b/src/sparse/linalg/mod.rs @@ -239,7 +239,7 @@ impl From for CholeskyError { pub mod solvers; mod ghost { - pub use crate::utils::constrained::{perm::*, sparse::*, *}; + pub use crate::utils::bound::*; use crate::Index; pub const NONE_BYTE: u8 = u8::MAX; diff --git a/src/sparse/linalg/qr.rs b/src/sparse/linalg/qr.rs index 63d09e8b..dc80e09e 100644 --- a/src/sparse/linalg/qr.rs +++ b/src/sparse/linalg/qr.rs @@ -13,7 +13,7 @@ use super::{ supernodal::{SupernodalLltRef, SymbolicSupernodalCholesky}, }, colamd::{self, Control}, - ghost::{self, Array, Idx, MaybeIdx}, + ghost, mem::{self, NONE}, nomem, try_zeroed, FaerError, Index, SupernodalThreshold, SymbolicSupernodalParams, }; @@ -26,7 +26,10 @@ use crate::{ perm::PermRef, sparse::{SparseColMatRef, SymbolicSparseColMatRef}, unzipped, - utils::{constrained::Dim, slice::*}, + utils::{ + bound::{Array, Dim, Idx, MaybeIdx}, + slice::*, + }, zipped, Conj, MatMut, Parallelism, SignedIndex, }; use core::iter::zip; @@ -36,8 +39,8 @@ use reborrow::*; #[inline] pub(crate) fn ghost_col_etree<'n, I: Index>( - A: ghost::SymbolicSparseColMatRef<'_, 'n, '_, I>, - col_perm: Option>, + A: SymbolicSparseColMatRef<'_, I, Dim<'_>, Dim<'n>>, + col_perm: Option>>, etree: &mut Array<'n, I::Signed>, stack: &mut PodStack, ) { @@ -54,7 +57,9 @@ pub(crate) fn ghost_col_etree<'n, I: Index>( mem::fill_none(etree.as_mut()); for j in N.indices() { - let pj = col_perm.map(|perm| perm.arrays().0[j].zx()).unwrap_or(j); + let pj = col_perm + .map(|perm| perm.bound_arrays().0[j].zx()) + .unwrap_or(j); for i_ in A.row_indices_of_col(pj) { let mut i = prev[i_].sx(); while let Some(i_) = i.idx() { @@ -98,8 +103,8 @@ pub fn col_etree<'out, I: Index>( with_dim!(M, A.nrows()); with_dim!(N, A.ncols()); ghost_col_etree( - ghost::SymbolicSparseColMatRef::new(A, M, N), - col_perm.map(|perm| ghost::PermRef::new(perm, N)), + A.as_shape(M, N), + col_perm.map(|perm| perm.as_shape(N)), Array::from_mut(bytemuck::cast_slice_mut(etree), N), stack, ); @@ -145,8 +150,8 @@ pub(crate) fn ghost_least_common_ancestor<'n, I: Index>( pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( col_counts: &mut Array<'m, I>, min_row: &mut Array<'n, I::Signed>, - A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - row_perm: Option>, + A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>, + row_perm: Option>>, etree: &Array<'m, MaybeIdx<'m, I>>, post: &Array<'m, Idx<'m, I>>, stack: &mut PodStack, @@ -177,7 +182,7 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( for j in N.indices() { if let Some(perm) = row_perm { - let inv = perm.arrays().1; + let inv = perm.bound_arrays().1; min_row[j] = match Iterator::min(A.row_indices_of_col(j).map(|j| inv[j].zx())) { Some(first_row) => I::Signed::truncate(*first_row), None => *MaybeIdx::<'_, I>::none(), @@ -190,7 +195,7 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( } let min_row = if let Some(perm) = row_perm { - let inv = perm.arrays().1; + let inv = perm.bound_arrays().1; Iterator::min(A.row_indices_of_col(j).map(|row| post_inv[inv[row].zx()])) } else { Iterator::min(A.row_indices_of_col(j).map(|row| post_inv[row])) @@ -248,7 +253,9 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( while let Some(j_) = j.idx() { for i in A.row_indices_of_col(j_) { - let i = row_perm.map(|perm| perm.arrays().1[i].zx()).unwrap_or(i); + let i = row_perm + .map(|perm| perm.bound_arrays().1[i].zx()) + .unwrap_or(i); let lca = ghost_least_common_ancestor::(i, pk, first, max_first, prev_leaf, ancestor); @@ -314,12 +321,12 @@ pub fn column_counts_ata<'m, 'n, I: Index>( with_dim!(M, AT.nrows()); with_dim!(N, AT.ncols()); - let A = ghost::SymbolicSparseColMatRef::new(AT, M, N); + let A = AT.as_shape(M, N); ghost_column_counts_aat( Array::from_mut(col_counts, M), Array::from_mut(bytemuck::cast_slice_mut(min_col), N), A, - col_perm.map(|perm| ghost::PermRef::new(perm, M)), + col_perm.map(|perm| perm.as_shape(M)), etree.ghost_inner(M), Array::from_ref(Idx::from_slice_ref_checked(post, M), M), stack, @@ -493,7 +500,7 @@ pub mod supernodal { with_dim!(M, m); with_dim!(N, n); - let A = ghost::SymbolicSparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let mut stack = stack; let (L, H) = { let etree = etree.ghost_inner(N); @@ -504,7 +511,7 @@ pub mod supernodal { let L = crate::sparse::linalg::cholesky::supernodal::ghost_factorize_supernodal_symbolic( A, - col_perm.map(|perm| ghost::PermRef::new(perm, N)), + col_perm.map(|perm| perm.as_shape(N)), Some(min_col), crate::sparse::linalg::cholesky::supernodal::CholeskyInput::ATA, etree, @@ -2358,7 +2365,7 @@ pub fn factorize_symbolic_qr( with_dim!(M, m); with_dim!(N, n); - let A = ghost::SymbolicSparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let req = || -> Result { let n_req = StackReq::try_new::(n)?; @@ -2400,12 +2407,12 @@ pub fn factorize_symbolic_qr( colamd::order( &mut col_perm_fwd, &mut col_perm_inv, - A.into_inner(), + A.as_dyn(), params.colamd_params, stack.rb_mut(), )?; - let col_perm = ghost::PermRef::new(PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n), N); + let col_perm = PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n).as_shape(N); let (new_col_ptr, stack) = stack.make_raw::(m + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); @@ -2470,8 +2477,8 @@ pub fn factorize_symbolic_qr( if threshold == SupernodalThreshold::FORCE_SUPERNODAL { let symbolic = supernodal::factorize_supernodal_symbolic_qr::( - A.into_inner(), - Some(col_perm.into_inner()), + A.as_dyn(), + Some(col_perm.as_shape(n)), min_col, EliminationTreeRef::<'_, I> { inner: etree }, col_counts, @@ -2586,7 +2593,7 @@ mod tests { let mut col_counts = vec![zero; n]; with_dim!(N, n); - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); + let A = A.as_shape(N, N); ghost_col_etree( A, None, @@ -2711,7 +2718,7 @@ mod tests { with_dim!(M, m); with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let mut new_col_ptrs = vec![zero; m + 1]; let mut new_row_ind = vec![zero; nnz]; let mut new_values = vec![0.0; nnz]; @@ -2731,7 +2738,7 @@ mod tests { let mut min_row = vec![zero; m]; ghost_col_etree( - *A, + A.symbolic(), None, Array::from_mut(&mut etree, N), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(*M + *N))), @@ -2746,7 +2753,7 @@ mod tests { ghost_column_counts_aat( Array::from_mut(&mut col_counts, N), Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M), - *AT, + AT.symbolic(), None, etree_, Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), @@ -2756,7 +2763,7 @@ mod tests { let min_col = min_row; let symbolic = factorize_supernodal_symbolic_qr::( - A.symbolic().into_inner(), + A.symbolic().as_dyn(), None, min_col, EliminationTreeRef::<'_, I> { inner: &etree }, @@ -2797,7 +2804,7 @@ mod tests { &mut L_values, &mut householder_values, &mut tau_values, - AT.into_inner(), + AT.as_dyn(), None, &symbolic, crate::Parallelism::None, @@ -2810,7 +2817,7 @@ mod tests { )), ); let llt = reconstruct_from_supernodal_llt::(symbolic.r_adjoint(), &L_values); - let a = sparse_to_dense(A.into_inner()); + let a = sparse_to_dense(A.as_dyn()); let ata = a.adjoint() * &a; let llt_diff = &llt - &ata; @@ -2843,7 +2850,7 @@ mod tests { with_dim!(M, m); with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let mut new_col_ptrs = vec![zero; m + 1]; let mut new_row_ind = vec![zero; nnz]; let mut new_values = vec![E::faer_zero(); nnz]; @@ -2867,7 +2874,7 @@ mod tests { let mut min_row = vec![zero; m]; ghost_col_etree( - *A, + A.symbolic(), None, Array::from_mut(&mut etree, N), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(*M + *N))), @@ -2882,7 +2889,7 @@ mod tests { ghost_column_counts_aat( Array::from_mut(&mut col_counts, N), Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M), - *AT, + AT.symbolic(), None, etree_, Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), @@ -2892,7 +2899,7 @@ mod tests { let min_col = min_row; let symbolic = factorize_supernodal_symbolic_qr::( - A.symbolic().into_inner(), + A.symbolic().as_dyn(), None, min_col, EliminationTreeRef::<'_, I> { inner: &etree }, @@ -2936,7 +2943,7 @@ mod tests { &mut L_values, &mut householder_values, &mut tau_values, - AT.into_inner(), + AT.as_dyn(), None, &symbolic, crate::Parallelism::None, @@ -2949,7 +2956,7 @@ mod tests { )), ); - let a = sparse_to_dense(A.into_inner()); + let a = sparse_to_dense(A.as_dyn()); let rhs = Mat::::from_fn(m, 2, |_, _| c64::new(gen.gen(), gen.gen())); let mut x = rhs.clone(); @@ -3002,7 +3009,7 @@ mod tests { with_dim!(M, m); with_dim!(N, n); - let A = ghost::SparseColMatRef::new(A, M, N); + let A = A.as_shape(M, N); let mut new_col_ptrs = vec![zero; m + 1]; let mut new_row_ind = vec![zero; nnz]; let mut new_values = vec![E::faer_zero(); nnz]; @@ -3026,7 +3033,7 @@ mod tests { let mut min_row = vec![zero; m]; ghost_col_etree( - *A, + A.symbolic(), None, Array::from_mut(&mut etree, N), PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(*M + *N))), @@ -3041,7 +3048,7 @@ mod tests { ghost_column_counts_aat( Array::from_mut(&mut col_counts, N), Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M), - *AT, + AT.symbolic(), None, etree_, Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), @@ -3075,14 +3082,14 @@ mod tests { &mut householder_row_indices, &mut householder_values, &mut tau_values, - A.into_inner(), + A.as_dyn(), None, &symbolic, PodStack::new(&mut GlobalPodBuffer::new( factorize_simplicial_numeric_qr_req::(&symbolic).unwrap(), )), ); - let a = sparse_to_dense(A.into_inner()); + let a = sparse_to_dense(A.as_dyn()); let rhs = Mat::::from_fn(m, 2, |_, _| c64::new(gen.gen(), gen.gen())); { let mut x = rhs.clone(); diff --git a/src/sparse/linalg/triangular_solve.rs b/src/sparse/linalg/triangular_solve.rs index 070e7e38..face6751 100644 --- a/src/sparse/linalg/triangular_solve.rs +++ b/src/sparse/linalg/triangular_solve.rs @@ -1,6 +1,14 @@ -use crate::{assert, sparse::SparseColMatRef, utils::slice::*, Conj, Index, MatMut, Parallelism}; -use core::iter::zip; +use crate::{ + assert, + sparse::SparseColMatRef, + utils::{bound, slice::*}, + Conj, Index, MatMut, Parallelism, +}; +use core::iter; use faer_entity::ComplexField; +use reborrow::*; + +// FIXME: unsound get_unchecked(1..) calls /// Assuming `self` is a lower triangular matrix, solves the equation `Op(self) * X = rhs`, and /// stores the result in `rhs`, where `Op` is either the conjugate or the identity depending on the @@ -23,161 +31,161 @@ pub fn solve_lower_triangular_in_place( with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - - for j in N.indices() { - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - - let xj0 = x.read(j, k0).faer_mul(d); - x.write(j, k0, xj0); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - } + let mut x = rhs.as_shape_mut(N, K); + let l = l.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices() { + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + + let xj0 = x0.read(j).faer_mul(d); + x0.write(j, xj0); + let xj1 = x1.read(j).faer_mul(d); + x1.write(j, xj1); + let xj2 = x2.read(j).faer_mul(d); + x2.write(j, xj2); + let xj3 = x3.read(j).faer_mul(d); + x3.write(j, xj3); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); + x1.write(i, x1.read(i).faer_sub(lij.faer_mul(xj1))); + x2.write(i, x2.read(i).faer_sub(lij.faer_mul(xj2))); + x3.write(i, x3.read(i).faer_sub(lij.faer_mul(xj3))); } } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - for j in N.indices() { - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - - let xj0 = x.read(j, k0).faer_mul(d); - x.write(j, k0, xj0); - let xj1 = x.read(j, k1).faer_mul(d); - x.write(j, k1, xj1); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(lij.faer_mul(xj1))); - } + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + + let xj0 = x0.read(j).faer_mul(d); + x0.write(j, xj0); + let xj1 = x1.read(j).faer_mul(d); + x1.write(j, xj1); + let xj2 = x2.read(j).faer_mul(d); + x2.write(j, xj2); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); + x1.write(i, x1.read(i).faer_sub(lij.faer_mul(xj1))); + x2.write(i, x2.read(i).faer_sub(lij.faer_mul(xj2))); } } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - for j in N.indices() { - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - - let xj0 = x.read(j, k0).faer_mul(d); - x.write(j, k0, xj0); - let xj1 = x.read(j, k1).faer_mul(d); - x.write(j, k1, xj1); - let xj2 = x.read(j, k2).faer_mul(d); - x.write(j, k2, xj2); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(lij.faer_mul(xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(lij.faer_mul(xj2))); - } + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + + let xj0 = x0.read(j).faer_mul(d); + x0.write(j, xj0); + let xj1 = x1.read(j).faer_mul(d); + x1.write(j, xj1); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); + x1.write(i, x1.read(i).faer_sub(lij.faer_mul(xj1))); } } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - for j in N.indices() { - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - - let xj0 = x.read(j, k0).faer_mul(d); - x.write(j, k0, xj0); - let xj1 = x.read(j, k1).faer_mul(d); - x.write(j, k1, xj1); - let xj2 = x.read(j, k2).faer_mul(d); - x.write(j, k2, xj2); - let xj3 = x.read(j, k3).faer_mul(d); - x.write(j, k3, xj3); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(lij.faer_mul(xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(lij.faer_mul(xj2))); - x.write(i, k3, x.read(i, k3).faer_sub(lij.faer_mul(xj3))); - } + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices() { + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + + let xj0 = x0.read(j).faer_mul(d); + x0.write(j, xj0); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); } } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -202,168 +210,165 @@ pub fn solve_lower_triangular_transpose_in_place( with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - } - - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - x.write(j, k0, x.read(j, k0).faer_sub(acc0a).faer_mul(d)); + let mut x = rhs.as_shape_mut(N, K); + let l = l.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + let mut acc3a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); + acc2a = acc2a.faer_add(lij.faer_mul(x2.read(i))); + acc3a = acc3a.faer_add(lij.faer_mul(x3.read(i))); } + + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + x0.write(j, x0.read(j).faer_sub(acc0a).faer_mul(d)); + x1.write(j, x1.read(j).faer_sub(acc1a).faer_mul(d)); + x2.write(j, x2.read(j).faer_sub(acc2a).faer_mul(d)); + x3.write(j, x3.read(j).faer_sub(acc3a).faer_mul(d)); } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - } - - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - x.write(j, k0, x.read(j, k0).faer_sub(acc0a).faer_mul(d)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a).faer_mul(d)); + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); + acc2a = acc2a.faer_add(lij.faer_mul(x2.read(i))); } + + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + x0.write(j, x0.read(j).faer_sub(acc0a).faer_mul(d)); + x1.write(j, x1.read(j).faer_sub(acc1a).faer_mul(d)); + x2.write(j, x2.read(j).faer_sub(acc2a).faer_mul(d)); } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - acc2a = acc2a.faer_add(lij.faer_mul(x.read(i, k2))); - } - - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - x.write(j, k0, x.read(j, k0).faer_sub(acc0a).faer_mul(d)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a).faer_mul(d)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a).faer_mul(d)); + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); } + + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + x0.write(j, x0.read(j).faer_sub(acc0a).faer_mul(d)); + x1.write(j, x1.read(j).faer_sub(acc1a).faer_mul(d)); } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - let mut acc3a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - acc2a = acc2a.faer_add(lij.faer_mul(x.read(i, k2))); - acc3a = acc3a.faer_add(lij.faer_mul(x.read(i, k3))); - } - - let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); - let d = if conj == Conj::Yes { d.faer_conj() } else { d }; - x.write(j, k0, x.read(j, k0).faer_sub(acc0a).faer_mul(d)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a).faer_mul(d)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a).faer_mul(d)); - x.write(j, k3, x.read(j, k3).faer_sub(acc3a).faer_mul(d)); + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in iter::zip( + &row_ind[1..], + slice_group(l.values_of_col(j)) + .subslice(1..len) + .into_ref_iter(), + ) { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); } + + let d = slice_group(l.values_of_col(j)).read(0).faer_inv(); + let d = if conj == Conj::Yes { d.faer_conj() } else { d }; + x0.write(j, x0.read(j).faer_sub(acc0a).faer_mul(d)); } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -389,144 +394,151 @@ pub fn solve_unit_lower_triangular_in_place( with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - - for j in N.indices() { - let xj0 = x.read(j, k0); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - } + let mut x = rhs.as_shape_mut(N, K); + let l = l.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices() { + let xj0 = x0.read(j); + let xj1 = x1.read(j); + let xj2 = x2.read(j); + let xj3 = x3.read(j); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); + x1.write(i, x1.read(i).faer_sub(lij.faer_mul(xj1))); + x2.write(i, x2.read(i).faer_sub(lij.faer_mul(xj2))); + x3.write(i, x3.read(i).faer_sub(lij.faer_mul(xj3))); } } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - for j in N.indices() { - let xj0 = x.read(j, k0); - let xj1 = x.read(j, k1); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(lij.faer_mul(xj1))); - } + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let xj0 = x0.read(j); + let xj1 = x1.read(j); + let xj2 = x2.read(j); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); + x1.write(i, x1.read(i).faer_sub(lij.faer_mul(xj1))); + x2.write(i, x2.read(i).faer_sub(lij.faer_mul(xj2))); } } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - for j in N.indices() { - let xj0 = x.read(j, k0); - let xj1 = x.read(j, k1); - let xj2 = x.read(j, k2); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(lij.faer_mul(xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(lij.faer_mul(xj2))); - } + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let xj0 = x0.read(j); + let xj1 = x1.read(j); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); + x1.write(i, x1.read(i).faer_sub(lij.faer_mul(xj1))); } } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - for j in N.indices() { - let xj0 = x.read(j, k0); - let xj1 = x.read(j, k1); - let xj2 = x.read(j, k2); - let xj3 = x.read(j, k3); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - x.write(i, k0, x.read(i, k0).faer_sub(lij.faer_mul(xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(lij.faer_mul(xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(lij.faer_mul(xj2))); - x.write(i, k3, x.read(i, k3).faer_sub(lij.faer_mul(xj3))); - } + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices() { + let xj0 = x0.read(j); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + x0.write(i, x0.read(i).faer_sub(lij.faer_mul(xj0))); } } - _ => unreachable!(), + k = k0.next(); } - - k += bs; } } @@ -552,160 +564,165 @@ pub fn solve_unit_lower_triangular_transpose_in_place with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a)); + let mut x = rhs.as_shape_mut(N, K); + let l = l.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + let mut acc3a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); + acc2a = acc2a.faer_add(lij.faer_mul(x2.read(i))); + acc3a = acc3a.faer_add(lij.faer_mul(x3.read(i))); } + + x0.write(j, x0.read(j).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_sub(acc1a)); + x2.write(j, x2.read(j).faer_sub(acc2a)); + x3.write(j, x3.read(j).faer_sub(acc3a)); } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a)); + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); + acc2a = acc2a.faer_add(lij.faer_mul(x2.read(i))); } + + x0.write(j, x0.read(j).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_sub(acc1a)); + x2.write(j, x2.read(j).faer_sub(acc2a)); } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - acc2a = acc2a.faer_add(lij.faer_mul(x.read(i, k2))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a)); + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); } + + x0.write(j, x0.read(j).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_sub(acc1a)); } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices().rev() { - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - let mut acc3a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - acc2a = acc2a.faer_add(lij.faer_mul(x.read(i, k2))); - acc3a = acc3a.faer_add(lij.faer_mul(x.read(i, k3))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a)); - x.write(j, k3, x.read(j, k3).faer_sub(acc3a)); + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices().rev() { + let mut acc0a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); } + + x0.write(j, x0.read(j).faer_sub(acc0a)); } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -725,180 +742,185 @@ pub fn ldlt_scale_solve_unit_lower_triangular_transpose_in_place { - let k0 = K.check(k); - - for j in N.indices().rev() { - let d = slice_group(l.values_of_col(j)) - .read(0) - .faer_real() - .faer_inv(); - - let mut acc0a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - } - - x.write(j, k0, x.read(j, k0).faer_scale_real(d).faer_sub(acc0a)); + let mut x = rhs.as_shape_mut(N, K); + let l = l.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices().rev() { + let d = slice_group(l.values_of_col(j)) + .read(0) + .faer_real() + .faer_inv(); + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + let mut acc3a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); + acc2a = acc2a.faer_add(lij.faer_mul(x2.read(i))); + acc3a = acc3a.faer_add(lij.faer_mul(x3.read(i))); } + + x0.write(j, x0.read(j).faer_scale_real(d).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_scale_real(d).faer_sub(acc1a)); + x2.write(j, x2.read(j).faer_scale_real(d).faer_sub(acc2a)); + x3.write(j, x3.read(j).faer_scale_real(d).faer_sub(acc3a)); } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - - for j in N.indices().rev() { - let d = slice_group(l.values_of_col(j)) - .read(0) - .faer_real() - .faer_inv(); - - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - } - - x.write(j, k0, x.read(j, k0).faer_scale_real(d).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_scale_real(d).faer_sub(acc1a)); + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let d = slice_group(l.values_of_col(j)) + .read(0) + .faer_real() + .faer_inv(); + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); + acc2a = acc2a.faer_add(lij.faer_mul(x2.read(i))); } + + x0.write(j, x0.read(j).faer_scale_real(d).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_scale_real(d).faer_sub(acc1a)); + x2.write(j, x2.read(j).faer_scale_real(d).faer_sub(acc2a)); } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices().rev() { - let d = slice_group(l.values_of_col(j)) - .read(0) - .faer_real() - .faer_inv(); - - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - acc2a = acc2a.faer_add(lij.faer_mul(x.read(i, k2))); - } - - x.write(j, k0, x.read(j, k0).faer_scale_real(d).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_scale_real(d).faer_sub(acc1a)); - x.write(j, k2, x.read(j, k2).faer_scale_real(d).faer_sub(acc2a)); + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let d = slice_group(l.values_of_col(j)) + .read(0) + .faer_real() + .faer_inv(); + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); + acc1a = acc1a.faer_add(lij.faer_mul(x1.read(i))); } + + x0.write(j, x0.read(j).faer_scale_real(d).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_scale_real(d).faer_sub(acc1a)); } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices().rev() { - let d = slice_group(l.values_of_col(j)) - .read(0) - .faer_real() - .faer_inv(); - - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - let mut acc3a = E::faer_zero(); - - let row_ind = l.row_indices_of_col_raw(j); - let len = row_ind.len(); - assert!(len >= 1); - for (i, lij) in unsafe { - zip( - row_ind.get_unchecked(1..), - slice_group(l.values_of_col(j)) - .subslice_unchecked(1..len) - .into_ref_iter(), - ) - } { - let i = i.zx(); - let lij = lij.read(); - let lij = if conj == Conj::Yes { - lij.faer_conj() - } else { - lij - }; - acc0a = acc0a.faer_add(lij.faer_mul(x.read(i, k0))); - acc1a = acc1a.faer_add(lij.faer_mul(x.read(i, k1))); - acc2a = acc2a.faer_add(lij.faer_mul(x.read(i, k2))); - acc3a = acc3a.faer_add(lij.faer_mul(x.read(i, k3))); - } - - x.write(j, k0, x.read(j, k0).faer_scale_real(d).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_scale_real(d).faer_sub(acc1a)); - x.write(j, k2, x.read(j, k2).faer_scale_real(d).faer_sub(acc2a)); - x.write(j, k3, x.read(j, k3).faer_scale_real(d).faer_sub(acc3a)); + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices().rev() { + let d = slice_group(l.values_of_col(j)) + .read(0) + .faer_real() + .faer_inv(); + + let mut acc0a = E::faer_zero(); + + let row_ind = l.row_indices_of_col_raw(j); + let len = row_ind.len(); + assert!(len >= 1); + for (i, lij) in unsafe { + iter::zip( + row_ind.get_unchecked(1..), + slice_group(l.values_of_col(j)) + .subslice_unchecked(1..len) + .into_ref_iter(), + ) + } { + let i = i.zx(); + let lij = lij.read(); + let lij = if conj == Conj::Yes { + lij.faer_conj() + } else { + lij + }; + acc0a = acc0a.faer_add(lij.faer_mul(x0.read(i))); } + + x0.write(j, x0.read(j).faer_scale_real(d).faer_sub(acc0a)); } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -922,164 +944,169 @@ pub fn solve_upper_triangular_in_place( with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + let mut x = rhs.as_shape_mut(N, K); + let u = u.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + let xj0 = x0.read(j).faer_mul(u_inv); + let xj1 = x1.read(j).faer_mul(u_inv); + let xj2 = x2.read(j).faer_mul(u_inv); + let xj3 = x3.read(j).faer_mul(u_inv); + x0.write(j, xj0); + x1.write(j, xj1); + x2.write(j, xj2); + x3.write(j, xj3); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() } else { - u_inv + u.read() }; - let xj = x.read(j, k0).faer_mul(u_inv); - x.write(j, k0, xj); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj))); - } + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj0))); + x1.write(i, x1.read(i).faer_sub(E::faer_mul(u, xj1))); + x2.write(i, x2.read(i).faer_sub(E::faer_mul(u, xj2))); + x3.write(i, x3.read(i).faer_sub(E::faer_mul(u, xj3))); } } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + let xj0 = x0.read(j).faer_mul(u_inv); + let xj1 = x1.read(j).faer_mul(u_inv); + let xj2 = x2.read(j).faer_mul(u_inv); + x0.write(j, xj0); + x1.write(j, xj1); + x2.write(j, xj2); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() } else { - u_inv + u.read() }; - let xj0 = x.read(j, k0).faer_mul(u_inv); - let xj1 = x.read(j, k1).faer_mul(u_inv); - x.write(j, k0, xj0); - x.write(j, k1, xj1); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(E::faer_mul(u, xj1))); - } + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj0))); + x1.write(i, x1.read(i).faer_sub(E::faer_mul(u, xj1))); + x2.write(i, x2.read(i).faer_sub(E::faer_mul(u, xj2))); } } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + let xj0 = x0.read(j).faer_mul(u_inv); + let xj1 = x1.read(j).faer_mul(u_inv); + x0.write(j, xj0); + x1.write(j, xj1); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() } else { - u_inv + u.read() }; - let xj0 = x.read(j, k0).faer_mul(u_inv); - let xj1 = x.read(j, k1).faer_mul(u_inv); - let xj2 = x.read(j, k2).faer_mul(u_inv); - x.write(j, k0, xj0); - x.write(j, k1, xj1); - x.write(j, k2, xj2); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(E::faer_mul(u, xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(E::faer_mul(u, xj2))); - } + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj0))); + x1.write(i, x1.read(i).faer_sub(E::faer_mul(u, xj1))); } } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + let xj = x0.read(j).faer_mul(u_inv); + x0.write(j, xj); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() } else { - u_inv + u.read() }; - let xj0 = x.read(j, k0).faer_mul(u_inv); - let xj1 = x.read(j, k1).faer_mul(u_inv); - let xj2 = x.read(j, k2).faer_mul(u_inv); - let xj3 = x.read(j, k3).faer_mul(u_inv); - x.write(j, k0, xj0); - x.write(j, k1, xj1); - x.write(j, k2, xj2); - x.write(j, k3, xj3); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(E::faer_mul(u, xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(E::faer_mul(u, xj2))); - x.write(i, k3, x.read(i, k3).faer_sub(E::faer_mul(u, xj3))); - } + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj))); } } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -1103,252 +1130,249 @@ pub fn solve_upper_triangular_transpose_in_place( with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + let mut x = rhs.as_shape_mut(N, K); + let u = u.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + let mut acc3a = E::faer_zero(); + + let rows = &ui[..ui.len() - 1]; + let values = ux.subslice(0..ui.len() - 1); + + for (i, uij) in iter::zip(rows, values.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() } else { - u_inv + uija }; - - let mut acc0a = E::faer_zero(); - let mut acc0b = E::faer_zero(); - let mut acc0c = E::faer_zero(); - let mut acc0d = E::faer_zero(); - - let a = 0; - let b = 1; - let c = 2; - let d = 3; - - let rows_head = ui[..ui.len() - 1].chunks_exact(4); - let rows_tail = rows_head.remainder(); - let (values_head, values_tail) = - ux.subslice(0..ui.len() - 1).into_chunks_exact(4); - - for (i, uij) in zip(rows_head, values_head) { - let uija = uij.read(a); - let uijb = uij.read(b); - let uijc = uij.read(c); - let uijd = uij.read(d); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - let uijb = if conj == Conj::Yes { - uijb.faer_conj() - } else { - uijb - }; - let uijc = if conj == Conj::Yes { - uijc.faer_conj() - } else { - uijc - }; - let uijd = if conj == Conj::Yes { - uijd.faer_conj() - } else { - uijd - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i[a].zx(), k0))); - acc0b = acc0b.faer_add(uijb.faer_mul(x.read(i[b].zx(), k0))); - acc0c = acc0c.faer_add(uijc.faer_mul(x.read(i[c].zx(), k0))); - acc0d = acc0d.faer_add(uijd.faer_mul(x.read(i[d].zx(), k0))); - } - - for (i, uij) in zip(rows_tail, values_tail.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - } - - x.write( - j, - k0, - x.read(j, k0) - .faer_sub(acc0a.faer_add(acc0b).faer_add(acc0c.faer_add(acc0d))) - .faer_mul(u_inv), - ); + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i.zx()))); + acc2a = acc2a.faer_add(uija.faer_mul(x2.read(i.zx()))); + acc3a = acc3a.faer_add(uija.faer_mul(x3.read(i.zx()))); } - } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + x0.write(j, x0.read(j).faer_sub(acc0a).faer_mul(u_inv)); + x1.write(j, x1.read(j).faer_sub(acc1a).faer_mul(u_inv)); + x2.write(j, x2.read(j).faer_sub(acc2a).faer_mul(u_inv)); + x3.write(j, x3.read(j).faer_sub(acc3a).faer_mul(u_inv)); + } + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + + let rows = &ui[..ui.len() - 1]; + let values = ux.subslice(0..ui.len() - 1); + + for (i, uij) in iter::zip(rows, values.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() } else { - u_inv + uija }; - - let mut acc0a = E::faer_zero(); - let mut acc0b = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc1b = E::faer_zero(); - - let a = 0; - let b = 1; - - let rows_head = ui[..ui.len() - 1].chunks_exact(2); - let rows_tail = rows_head.remainder(); - let (values_head, values_tail) = - ux.subslice(0..ui.len() - 1).into_chunks_exact(2); - - for (i, uij) in zip(rows_head, values_head) { - let uija = uij.read(a); - let uijb = uij.read(b); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - let uijb = if conj == Conj::Yes { - uijb.faer_conj() - } else { - uijb - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i[a].zx(), k0))); - acc0b = acc0b.faer_add(uijb.faer_mul(x.read(i[b].zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i[a].zx(), k1))); - acc1b = acc1b.faer_add(uijb.faer_mul(x.read(i[b].zx(), k1))); - } - - for (i, uij) in zip(rows_tail, values_tail.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i.zx(), k1))); - } - - x.write( - j, - k0, - x.read(j, k0) - .faer_sub(acc0a.faer_add(acc0b)) - .faer_mul(u_inv), - ); - x.write( - j, - k1, - x.read(j, k1) - .faer_sub(acc1a.faer_add(acc1b)) - .faer_mul(u_inv), - ); + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i.zx()))); + acc2a = acc2a.faer_add(uija.faer_mul(x2.read(i.zx()))); } + + x0.write(j, x0.read(j).faer_sub(acc0a).faer_mul(u_inv)); + x1.write(j, x1.read(j).faer_sub(acc1a).faer_mul(u_inv)); + x2.write(j, x2.read(j).faer_sub(acc2a).faer_mul(u_inv)); } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + + let mut acc0a = E::faer_zero(); + let mut acc0b = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc1b = E::faer_zero(); + + let a = 0; + let b = 1; + + let rows_head = ui[..ui.len() - 1].chunks_exact(2); + let rows_tail = rows_head.remainder(); + let (values_head, values_tail) = ux.subslice(0..ui.len() - 1).into_chunks_exact(2); + + for (i, uij) in iter::zip(rows_head, values_head) { + let uija = uij.read(a); + let uijb = uij.read(b); + let uija = if conj == Conj::Yes { + uija.faer_conj() } else { - u_inv + uija }; + let uijb = if conj == Conj::Yes { + uijb.faer_conj() + } else { + uijb + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i[a].zx()))); + acc0b = acc0b.faer_add(uijb.faer_mul(x0.read(i[b].zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i[a].zx()))); + acc1b = acc1b.faer_add(uijb.faer_mul(x1.read(i[b].zx()))); + } - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - - let rows = &ui[..ui.len() - 1]; - let values = ux.subslice(0..ui.len() - 1); - - for (i, uij) in zip(rows, values.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i.zx(), k1))); - acc2a = acc2a.faer_add(uija.faer_mul(x.read(i.zx(), k2))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a).faer_mul(u_inv)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a).faer_mul(u_inv)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a).faer_mul(u_inv)); + for (i, uij) in iter::zip(rows_tail, values_tail.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i.zx()))); } + + x0.write( + j, + x0.read(j).faer_sub(acc0a.faer_add(acc0b)).faer_mul(u_inv), + ); + x1.write( + j, + x1.read(j).faer_sub(acc1a.faer_add(acc1b)).faer_mul(u_inv), + ); } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let u_inv = ux.read(ui.len() - 1).faer_inv(); - let u_inv = if conj == Conj::Yes { - u_inv.faer_conj() + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let u_inv = ux.read(ui.len() - 1).faer_inv(); + let u_inv = if conj == Conj::Yes { + u_inv.faer_conj() + } else { + u_inv + }; + + let mut acc0a = E::faer_zero(); + let mut acc0b = E::faer_zero(); + let mut acc0c = E::faer_zero(); + let mut acc0d = E::faer_zero(); + + let a = 0; + let b = 1; + let c = 2; + let d = 3; + + let rows_head = ui[..ui.len() - 1].chunks_exact(4); + let rows_tail = rows_head.remainder(); + let (values_head, values_tail) = ux.subslice(0..ui.len() - 1).into_chunks_exact(4); + + for (i, uij) in iter::zip(rows_head, values_head) { + let uija = uij.read(a); + let uijb = uij.read(b); + let uijc = uij.read(c); + let uijd = uij.read(d); + let uija = if conj == Conj::Yes { + uija.faer_conj() } else { - u_inv + uija }; + let uijb = if conj == Conj::Yes { + uijb.faer_conj() + } else { + uijb + }; + let uijc = if conj == Conj::Yes { + uijc.faer_conj() + } else { + uijc + }; + let uijd = if conj == Conj::Yes { + uijd.faer_conj() + } else { + uijd + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i[a].zx()))); + acc0b = acc0b.faer_add(uijb.faer_mul(x0.read(i[b].zx()))); + acc0c = acc0c.faer_add(uijc.faer_mul(x0.read(i[c].zx()))); + acc0d = acc0d.faer_add(uijd.faer_mul(x0.read(i[d].zx()))); + } - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - let mut acc3a = E::faer_zero(); - - let rows = &ui[..ui.len() - 1]; - let values = ux.subslice(0..ui.len() - 1); - - for (i, uij) in zip(rows, values.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i.zx(), k1))); - acc2a = acc2a.faer_add(uija.faer_mul(x.read(i.zx(), k2))); - acc3a = acc3a.faer_add(uija.faer_mul(x.read(i.zx(), k3))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a).faer_mul(u_inv)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a).faer_mul(u_inv)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a).faer_mul(u_inv)); - x.write(j, k3, x.read(j, k3).faer_sub(acc3a).faer_mul(u_inv)); + for (i, uij) in iter::zip(rows_tail, values_tail.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); } + + x0.write( + j, + x0.read(j) + .faer_sub(acc0a.faer_add(acc0b).faer_add(acc0c.faer_add(acc0d))) + .faer_mul(u_inv), + ); } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -1372,130 +1396,135 @@ pub fn solve_unit_upper_triangular_in_place( with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let xj = x.read(j, k0); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj))); - } + let mut x = rhs.as_shape_mut(N, K); + let u = u.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let xj0 = x0.read(j); + let xj1 = x1.read(j); + let xj2 = x2.read(j); + let xj3 = x3.read(j); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() + } else { + u.read() + }; + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj0))); + x1.write(i, x1.read(i).faer_sub(E::faer_mul(u, xj1))); + x2.write(i, x2.read(i).faer_sub(E::faer_mul(u, xj2))); + x3.write(i, x3.read(i).faer_sub(E::faer_mul(u, xj3))); } } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let xj0 = x.read(j, k0); - let xj1 = x.read(j, k1); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(E::faer_mul(u, xj1))); - } + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let xj0 = x0.read(j); + let xj1 = x1.read(j); + let xj2 = x2.read(j); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() + } else { + u.read() + }; + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj0))); + x1.write(i, x1.read(i).faer_sub(E::faer_mul(u, xj1))); + x2.write(i, x2.read(i).faer_sub(E::faer_mul(u, xj2))); } } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let xj0 = x.read(j, k0); - let xj1 = x.read(j, k1); - let xj2 = x.read(j, k2); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(E::faer_mul(u, xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(E::faer_mul(u, xj2))); - } + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let xj0 = x0.read(j); + let xj1 = x1.read(j); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() + } else { + u.read() + }; + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj0))); + x1.write(i, x1.read(i).faer_sub(E::faer_mul(u, xj1))); } } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices().rev() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let xj0 = x.read(j, k0); - let xj1 = x.read(j, k1); - let xj2 = x.read(j, k2); - let xj3 = x.read(j, k3); - - for (i, u) in zip( - &ui[..ui.len() - 1], - ux.subslice(0..ui.len() - 1).into_ref_iter(), - ) { - let i = i.zx(); - let u = if conj == Conj::Yes { - u.read().faer_conj() - } else { - u.read() - }; - - x.write(i, k0, x.read(i, k0).faer_sub(E::faer_mul(u, xj0))); - x.write(i, k1, x.read(i, k1).faer_sub(E::faer_mul(u, xj1))); - x.write(i, k2, x.read(i, k2).faer_sub(E::faer_mul(u, xj2))); - x.write(i, k3, x.read(i, k3).faer_sub(E::faer_mul(u, xj3))); - } + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices().rev() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let xj = x0.read(j); + + for (i, u) in iter::zip( + &ui[..ui.len() - 1], + ux.subslice(0..ui.len() - 1).into_ref_iter(), + ) { + let i = i.zx(); + let u = if conj == Conj::Yes { + u.read().faer_conj() + } else { + u.read() + }; + + x0.write(i, x0.read(i).faer_sub(E::faer_mul(u, xj))); } } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } @@ -1519,243 +1548,244 @@ pub fn solve_unit_upper_triangular_transpose_in_place with_dim!(N, rhs.nrows()); with_dim!(K, rhs.ncols()); - let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); - let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); - - let mut k = 0usize; - while k < *K { - let bs = Ord::min(*K - k, 4); - match bs { - 1 => { - let k0 = K.check(k); - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let mut acc0a = E::faer_zero(); - let mut acc0b = E::faer_zero(); - let mut acc0c = E::faer_zero(); - let mut acc0d = E::faer_zero(); - - let a = 0; - let b = 1; - let c = 2; - let d = 3; - - let rows_head = ui[..ui.len() - 1].chunks_exact(4); - let rows_tail = rows_head.remainder(); - let (values_head, values_tail) = - ux.subslice(0..ui.len() - 1).into_chunks_exact(4); - - for (i, uij) in zip(rows_head, values_head) { - let uija = uij.read(a); - let uijb = uij.read(b); - let uijc = uij.read(c); - let uijd = uij.read(d); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - let uijb = if conj == Conj::Yes { - uijb.faer_conj() - } else { - uijb - }; - let uijc = if conj == Conj::Yes { - uijc.faer_conj() - } else { - uijc - }; - let uijd = if conj == Conj::Yes { - uijd.faer_conj() - } else { - uijd - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i[a].zx(), k0))); - acc0b = acc0b.faer_add(uijb.faer_mul(x.read(i[b].zx(), k0))); - acc0c = acc0c.faer_add(uijc.faer_mul(x.read(i[c].zx(), k0))); - acc0d = acc0d.faer_add(uijd.faer_mul(x.read(i[d].zx(), k0))); - } - - for (i, uij) in zip(rows_tail, values_tail.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - } - - x.write( - j, - k0, - x.read(j, k0) - .faer_sub(acc0a.faer_add(acc0b).faer_add(acc0c.faer_add(acc0d))), - ); + let mut x = rhs.as_shape_mut(N, K); + let u = u.as_shape(N, N); + + let mut k = bound::IdxInc::::zero(); + while let Some(k0) = K.try_check(*k) { + let k1 = K.try_check(*k + 1); + let k2 = K.try_check(*k + 2); + let k3 = K.try_check(*k + 3); + + if let Some(k3) = k3 { + let mut x = x.rb_mut().subcols_range_mut(k..k3.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = + (x.next(), x.next(), x.next(), x.next()) + else { + panic!() + }; + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + let mut acc3a = E::faer_zero(); + + let a = 0; + + let rows_head = ui[..ui.len() - 1].chunks_exact(4); + let rows_tail = rows_head.remainder(); + let (values_head, values_tail) = ux.subslice(0..ui.len() - 1).into_chunks_exact(4); + + for (i, uij) in iter::zip(rows_head, values_head) { + let uija = uij.read(a); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i[a].zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i[a].zx()))); + acc2a = acc2a.faer_add(uija.faer_mul(x2.read(i[a].zx()))); + acc3a = acc3a.faer_add(uija.faer_mul(x3.read(i[a].zx()))); + } + + for (i, uij) in iter::zip(rows_tail, values_tail.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i.zx()))); + acc2a = acc2a.faer_add(uija.faer_mul(x2.read(i.zx()))); + acc3a = acc3a.faer_add(uija.faer_mul(x3.read(i.zx()))); } + + x0.write(j, x0.read(j).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_sub(acc1a)); + x2.write(j, x2.read(j).faer_sub(acc2a)); + x3.write(j, x3.read(j).faer_sub(acc3a)); } - 2 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let mut acc0a = E::faer_zero(); - let mut acc0b = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc1b = E::faer_zero(); - - let a = 0; - let b = 1; - - let rows_head = ui[..ui.len() - 1].chunks_exact(4); - let rows_tail = rows_head.remainder(); - let (values_head, values_tail) = - ux.subslice(0..ui.len() - 1).into_chunks_exact(4); - - for (i, uij) in zip(rows_head, values_head) { - let uija = uij.read(a); - let uijb = uij.read(b); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - let uijb = if conj == Conj::Yes { - uijb.faer_conj() - } else { - uijb - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i[a].zx(), k0))); - acc0b = acc0b.faer_add(uijb.faer_mul(x.read(i[b].zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i[a].zx(), k1))); - acc1b = acc1b.faer_add(uijb.faer_mul(x.read(i[b].zx(), k1))); - } - - for (i, uij) in zip(rows_tail, values_tail.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i.zx(), k1))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a.faer_add(acc0b))); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a.faer_add(acc1b))); + k = k3.next(); + } else if let Some(k2) = k2 { + let mut x = x.rb_mut().subcols_range_mut(k..k2.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let mut acc0a = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc2a = E::faer_zero(); + + let a = 0; + + let rows_head = ui[..ui.len() - 1].chunks_exact(4); + let rows_tail = rows_head.remainder(); + let (values_head, values_tail) = ux.subslice(0..ui.len() - 1).into_chunks_exact(4); + + for (i, uij) in iter::zip(rows_head, values_head) { + let uija = uij.read(a); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i[a].zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i[a].zx()))); + acc2a = acc2a.faer_add(uija.faer_mul(x2.read(i[a].zx()))); } + + for (i, uij) in iter::zip(rows_tail, values_tail.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i.zx()))); + acc2a = acc2a.faer_add(uija.faer_mul(x2.read(i.zx()))); + } + + x0.write(j, x0.read(j).faer_sub(acc0a)); + x1.write(j, x1.read(j).faer_sub(acc1a)); + x2.write(j, x2.read(j).faer_sub(acc2a)); } - 3 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - - let a = 0; - - let rows_head = ui[..ui.len() - 1].chunks_exact(4); - let rows_tail = rows_head.remainder(); - let (values_head, values_tail) = - ux.subslice(0..ui.len() - 1).into_chunks_exact(4); - - for (i, uij) in zip(rows_head, values_head) { - let uija = uij.read(a); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i[a].zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i[a].zx(), k1))); - acc2a = acc2a.faer_add(uija.faer_mul(x.read(i[a].zx(), k2))); - } - - for (i, uij) in zip(rows_tail, values_tail.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i.zx(), k1))); - acc2a = acc2a.faer_add(uija.faer_mul(x.read(i.zx(), k2))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a)); + k = k2.next(); + } else if let Some(k1) = k1 { + let mut x = x.rb_mut().subcols_range_mut(k..k1.next()).col_iter_mut(); + let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { + panic!() + }; + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let mut acc0a = E::faer_zero(); + let mut acc0b = E::faer_zero(); + let mut acc1a = E::faer_zero(); + let mut acc1b = E::faer_zero(); + + let a = 0; + let b = 1; + + let rows_head = ui[..ui.len() - 1].chunks_exact(4); + let rows_tail = rows_head.remainder(); + let (values_head, values_tail) = ux.subslice(0..ui.len() - 1).into_chunks_exact(4); + + for (i, uij) in iter::zip(rows_head, values_head) { + let uija = uij.read(a); + let uijb = uij.read(b); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + let uijb = if conj == Conj::Yes { + uijb.faer_conj() + } else { + uijb + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i[a].zx()))); + acc0b = acc0b.faer_add(uijb.faer_mul(x0.read(i[b].zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i[a].zx()))); + acc1b = acc1b.faer_add(uijb.faer_mul(x1.read(i[b].zx()))); } + + for (i, uij) in iter::zip(rows_tail, values_tail.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); + acc1a = acc1a.faer_add(uija.faer_mul(x1.read(i.zx()))); + } + + x0.write(j, x0.read(j).faer_sub(acc0a.faer_add(acc0b))); + x1.write(j, x1.read(j).faer_sub(acc1a.faer_add(acc1b))); } - 4 => { - let k0 = K.check(k); - let k1 = K.check(k + 1); - let k2 = K.check(k + 2); - let k3 = K.check(k + 3); - - for j in N.indices() { - let ui = u.row_indices_of_col_raw(j); - let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); - - let mut acc0a = E::faer_zero(); - let mut acc1a = E::faer_zero(); - let mut acc2a = E::faer_zero(); - let mut acc3a = E::faer_zero(); - - let a = 0; - - let rows_head = ui[..ui.len() - 1].chunks_exact(4); - let rows_tail = rows_head.remainder(); - let (values_head, values_tail) = - ux.subslice(0..ui.len() - 1).into_chunks_exact(4); - - for (i, uij) in zip(rows_head, values_head) { - let uija = uij.read(a); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i[a].zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i[a].zx(), k1))); - acc2a = acc2a.faer_add(uija.faer_mul(x.read(i[a].zx(), k2))); - acc3a = acc3a.faer_add(uija.faer_mul(x.read(i[a].zx(), k3))); - } - - for (i, uij) in zip(rows_tail, values_tail.into_ref_iter()) { - let uija = uij.read(); - let uija = if conj == Conj::Yes { - uija.faer_conj() - } else { - uija - }; - acc0a = acc0a.faer_add(uija.faer_mul(x.read(i.zx(), k0))); - acc1a = acc1a.faer_add(uija.faer_mul(x.read(i.zx(), k1))); - acc2a = acc2a.faer_add(uija.faer_mul(x.read(i.zx(), k2))); - acc3a = acc3a.faer_add(uija.faer_mul(x.read(i.zx(), k3))); - } - - x.write(j, k0, x.read(j, k0).faer_sub(acc0a)); - x.write(j, k1, x.read(j, k1).faer_sub(acc1a)); - x.write(j, k2, x.read(j, k2).faer_sub(acc2a)); - x.write(j, k3, x.read(j, k3).faer_sub(acc3a)); + k = k1.next(); + } else { + let mut x0 = x.rb_mut().col_mut(k0); + + for j in N.indices() { + let ui = u.row_indices_of_col_raw(j); + let ux = SliceGroup::<'_, E>::new(u.values_of_col(j)); + + let mut acc0a = E::faer_zero(); + let mut acc0b = E::faer_zero(); + let mut acc0c = E::faer_zero(); + let mut acc0d = E::faer_zero(); + + let a = 0; + let b = 1; + let c = 2; + let d = 3; + + let rows_head = ui[..ui.len() - 1].chunks_exact(4); + let rows_tail = rows_head.remainder(); + let (values_head, values_tail) = ux.subslice(0..ui.len() - 1).into_chunks_exact(4); + + for (i, uij) in iter::zip(rows_head, values_head) { + let uija = uij.read(a); + let uijb = uij.read(b); + let uijc = uij.read(c); + let uijd = uij.read(d); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + let uijb = if conj == Conj::Yes { + uijb.faer_conj() + } else { + uijb + }; + let uijc = if conj == Conj::Yes { + uijc.faer_conj() + } else { + uijc + }; + let uijd = if conj == Conj::Yes { + uijd.faer_conj() + } else { + uijd + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i[a].zx()))); + acc0b = acc0b.faer_add(uijb.faer_mul(x0.read(i[b].zx()))); + acc0c = acc0c.faer_add(uijc.faer_mul(x0.read(i[c].zx()))); + acc0d = acc0d.faer_add(uijd.faer_mul(x0.read(i[d].zx()))); + } + + for (i, uij) in iter::zip(rows_tail, values_tail.into_ref_iter()) { + let uija = uij.read(); + let uija = if conj == Conj::Yes { + uija.faer_conj() + } else { + uija + }; + acc0a = acc0a.faer_add(uija.faer_mul(x0.read(i.zx()))); } + + x0.write( + j, + x0.read(j) + .faer_sub(acc0a.faer_add(acc0b).faer_add(acc0c.faer_add(acc0d))), + ); } - _ => unreachable!(), + k = k0.next(); } - k += bs; } } diff --git a/src/sparse/mod.rs b/src/sparse/mod.rs index ca96b619..1acf3e77 100644 --- a/src/sparse/mod.rs +++ b/src/sparse/mod.rs @@ -35,7 +35,7 @@ use faer_entity::*; use reborrow::*; mod ghost { - pub use crate::utils::constrained::{perm::*, sparse::*, *}; + pub use crate::utils::bound::*; } const TOP_BIT: usize = 1usize << (usize::BITS - 1); @@ -224,7 +224,7 @@ pub use csr::*; /// Useful sparse matrix primitives. pub mod utils { use super::*; - use crate::{assert, debug_assert}; + use crate::{assert, debug_assert, perm::PermRef, utils::bound::Dim}; /// Sorts `row_indices` and `values` simultaneously so that `row_indices` is nonincreasing. pub fn sort_indices( @@ -301,19 +301,19 @@ pub mod utils { new_values: SliceGroupMut<'out, E>, new_col_ptrs: &'out mut [I], new_row_indices: &'out mut [I], - A: ghost::SparseColMatRef<'n, 'n, '_, I, E>, - perm: ghost::PermRef<'n, '_, I>, + A: SparseColMatRef<'_, I, E, Dim<'n>, Dim<'n>>, + perm: PermRef<'_, I, Dim<'n>>, in_side: Side, out_side: Side, sort: bool, stack: &mut PodStack, - ) -> ghost::SparseColMatMut<'n, 'n, 'out, I, E> { + ) -> SparseColMatMut<'out, I, E, Dim<'n>, Dim<'n>> { let N = A.ncols(); let n = *A.ncols(); // (1) assert!(new_col_ptrs.len() == n + 1); - let (_, perm_inv) = perm.arrays(); + let (_, perm_inv) = perm.bound_arrays(); let (current_row_position, _) = stack.make_raw::(n); let current_row_position = ghost::Array::from_mut(current_row_position, N); @@ -538,20 +538,11 @@ pub mod utils { // 1. new_values.len() == new_row_indices.len() // 2. all written row indices are less than n unsafe { - ghost::SparseColMatMut::new( - SparseColMatMut::new( - SymbolicSparseColMatRef::new_unchecked( - n, - n, - new_col_ptrs, - None, - new_row_indices, - ), - new_values.into_inner(), - ), - N, - N, + SparseColMatMut::new( + SymbolicSparseColMatRef::new_unchecked(n, n, new_col_ptrs, None, new_row_indices), + new_values.into_inner(), ) + .as_shape_mut(N, N) } } @@ -559,29 +550,26 @@ pub mod utils { pub unsafe fn ghost_permute_hermitian_unsorted_symbolic<'n, 'out, I: Index>( new_col_ptrs: &'out mut [I], new_row_indices: &'out mut [I], - A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, - perm: ghost::PermRef<'n, '_, I>, + A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>, + perm: PermRef<'_, I, Dim<'n>>, in_side: Side, out_side: Side, stack: &mut PodStack, - ) -> ghost::SymbolicSparseColMatRef<'n, 'n, 'out, I> { - let old_values = &*Symbolic::materialize(A.into_inner().row_indices().len()); + ) -> SymbolicSparseColMatRef<'out, I, Dim<'n>, Dim<'n>> { + let old_values = &*Symbolic::materialize(A.row_indices().len()); let new_values = Symbolic::materialize(new_row_indices.len()); - *ghost_permute_hermitian_unsorted( + ghost_permute_hermitian_unsorted( SliceGroupMut::<'_, Symbolic>::new(new_values), new_col_ptrs, new_row_indices, - ghost::SparseColMatRef::new( - SparseColMatRef::new(A.into_inner(), old_values), - A.nrows(), - A.ncols(), - ), + SparseColMatRef::new(A, old_values), perm, in_side, out_side, false, stack, ) + .symbolic() } /// Computes the self-adjoint permutation $P A P^\top$ of the matrix `A` without sorting the row @@ -605,14 +593,14 @@ pub mod utils { SliceGroupMut::new(new_values), new_col_ptrs, new_row_indices, - ghost::SparseColMatRef::new(A, N, N), - ghost::PermRef::new(perm, N), + A.as_shape(N, N), + perm.as_shape(N), in_side, out_side, false, stack, ) - .into_inner() + .as_shape_mut(A.nrows(), A.nrows()) } /// Computes the size and alignment of the workspace required to compute a two-sided permutation @@ -646,37 +634,34 @@ pub mod utils { SliceGroupMut::new(new_values), new_col_ptrs, new_row_indices, - ghost::SparseColMatRef::new(A, N, N), - ghost::PermRef::new(perm, N), + A.as_shape(N, N), + perm.as_shape(N), in_side, out_side, true, stack, ) } - .into_inner() + .as_shape_mut(A.nrows(), A.nrows()) } #[doc(hidden)] pub fn ghost_adjoint_symbolic<'m, 'n, 'a, I: Index>( new_col_ptrs: &'a mut [I], new_row_indices: &'a mut [I], - A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, + A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>, stack: &mut PodStack, - ) -> ghost::SymbolicSparseColMatRef<'n, 'm, 'a, I> { - let old_values = &*Symbolic::materialize(A.into_inner().row_indices().len()); + ) -> SymbolicSparseColMatRef<'a, I, Dim<'n>, Dim<'m>> { + let old_values = &*Symbolic::materialize(A.row_indices().len()); let new_values = Symbolic::materialize(new_row_indices.len()); - *ghost_adjoint( + ghost_adjoint( new_col_ptrs, new_row_indices, SliceGroupMut::<'_, Symbolic>::new(new_values), - ghost::SparseColMatRef::new( - SparseColMatRef::new(A.into_inner(), old_values), - A.nrows(), - A.ncols(), - ), + SparseColMatRef::new(A, old_values), stack, ) + .symbolic() } #[doc(hidden)] @@ -684,9 +669,9 @@ pub mod utils { new_col_ptrs: &'a mut [I], new_row_indices: &'a mut [I], new_values: SliceGroupMut<'a, E>, - A: ghost::SparseColMatRef<'m, 'n, '_, I, E>, + A: SparseColMatRef<'_, I, E, Dim<'m>, Dim<'n>>, stack: &mut PodStack, - ) -> ghost::SparseColMatMut<'n, 'm, 'a, I, E> { + ) -> SparseColMatMut<'a, I, E, Dim<'n>, Dim<'m>> { let M = A.nrows(); let N = A.ncols(); assert!(new_col_ptrs.len() == *M + 1); @@ -743,22 +728,12 @@ pub mod utils { // SAFETY: // 0. new_col_ptrs is non-decreasing // 1. all written row indices are less than n - ghost::SparseColMatMut::new( - unsafe { - SparseColMatMut::new( - SymbolicSparseColMatRef::new_unchecked( - *N, - *M, - new_col_ptrs, - None, - new_row_indices, - ), - new_values.into_inner(), - ) - }, - N, - M, - ) + unsafe { + SparseColMatMut::new( + SymbolicSparseColMatRef::new_unchecked(N, M, new_col_ptrs, None, new_row_indices), + new_values.into_inner(), + ) + } } #[doc(hidden)] @@ -766,9 +741,9 @@ pub mod utils { new_col_ptrs: &'a mut [I], new_row_indices: &'a mut [I], new_values: SliceGroupMut<'a, E>, - A: ghost::SparseColMatRef<'m, 'n, '_, I, E>, + A: SparseColMatRef<'_, I, E, Dim<'m>, Dim<'n>>, stack: &mut PodStack, - ) -> ghost::SparseColMatMut<'n, 'm, 'a, I, E> { + ) -> SparseColMatMut<'a, I, E, Dim<'n>, Dim<'m>> { let M = A.nrows(); let N = A.ncols(); assert!(new_col_ptrs.len() == *M + 1); @@ -825,22 +800,12 @@ pub mod utils { // SAFETY: // 0. new_col_ptrs is non-decreasing // 1. all written row indices are less than n - ghost::SparseColMatMut::new( - unsafe { - SparseColMatMut::new( - SymbolicSparseColMatRef::new_unchecked( - *N, - *M, - new_col_ptrs, - None, - new_row_indices, - ), - new_values.into_inner(), - ) - }, - N, - M, - ) + unsafe { + SparseColMatMut::new( + SymbolicSparseColMatRef::new_unchecked(N, M, new_col_ptrs, None, new_row_indices), + new_values.into_inner(), + ) + } } /// Computes the transpose of the matrix `A` and returns a view over it. @@ -856,22 +821,20 @@ pub mod utils { A: SparseColMatRef<'_, I, E, R, C>, stack: &mut PodStack, ) -> SparseColMatMut<'a, I, E, C, R> { - let M = A.nrows(); - let N = A.ncols(); + let m = A.nrows(); + let n = A.ncols(); let A = A.as_dyn(); - { - with_dim!(M, A.nrows()); - with_dim!(N, A.ncols()); - ghost_transpose( - new_col_ptrs, - new_row_indices, - SliceGroupMut::new(new_values), - ghost::SparseColMatRef::new(A, M, N), - stack, - ) - .into_inner() - } - .as_shape_mut(N, M) + + with_dim!(M, A.nrows()); + with_dim!(N, A.ncols()); + ghost_transpose( + new_col_ptrs, + new_row_indices, + SliceGroupMut::new(new_values), + A.as_shape(M, N), + stack, + ) + .as_shape_mut(n, m) } /// Computes the adjoint of the matrix `A` and returns a view over it. @@ -887,22 +850,20 @@ pub mod utils { A: SparseColMatRef<'_, I, E, R, C>, stack: &mut PodStack, ) -> SparseColMatMut<'a, I, E, C, R> { - let M = A.nrows(); - let N = A.ncols(); + let m = A.nrows(); + let n = A.ncols(); let A = A.as_dyn(); - { - with_dim!(M, A.nrows()); - with_dim!(N, A.ncols()); - ghost_adjoint( - new_col_ptrs, - new_row_indices, - SliceGroupMut::new(new_values), - ghost::SparseColMatRef::new(A, M, N), - stack, - ) - .into_inner() - } - .as_shape_mut(N, M) + + with_dim!(M, A.nrows()); + with_dim!(N, A.ncols()); + ghost_adjoint( + new_col_ptrs, + new_row_indices, + SliceGroupMut::new(new_values), + A.as_shape(M, N), + stack, + ) + .as_shape_mut(n, m) } /// Computes the adjoint of the symbolic matrix `A` and returns a view over it. @@ -917,21 +878,14 @@ pub mod utils { A: SymbolicSparseColMatRef<'_, I, R, C>, stack: &mut PodStack, ) -> SymbolicSparseColMatRef<'a, I, C, R> { - let M = A.nrows(); - let N = A.ncols(); + let m = A.nrows(); + let n = A.ncols(); let A = A.as_dyn(); - { - with_dim!(M, A.nrows()); - with_dim!(N, A.ncols()); - ghost_adjoint_symbolic( - new_col_ptrs, - new_row_indices, - ghost::SymbolicSparseColMatRef::new(A, M, N), - stack, - ) - .into_inner() - } - .as_shape(N, M) + + with_dim!(M, A.nrows()); + with_dim!(N, A.ncols()); + ghost_adjoint_symbolic(new_col_ptrs, new_row_indices, A.as_shape(M, N), stack) + .as_shape(n, m) } } diff --git a/src/utils/bound.rs b/src/utils/bound.rs index fd2fd98e..bfd652c3 100644 --- a/src/utils/bound.rs +++ b/src/utils/bound.rs @@ -1,7 +1,8 @@ -use core::{fmt, marker::PhantomData}; +use crate::{utils::slice, Entity, Index, Shape, ShapeIdx, SignedIndex, Unbind}; +use core::{fmt, marker::PhantomData, ops::Range}; +use faer_entity::*; use generativity::Guard; - -use crate::{Index, Shape, ShapeIdx, SignedIndex, Unbind}; +use reborrow::*; type Invariant<'a> = fn(&'a ()) -> &'a (); @@ -252,7 +253,7 @@ impl<'n, I: Index> Idx<'n, I> { /// Zero-extends the internal value into a `usize`. #[inline(always)] - pub fn zx(self) -> Idx<'n, usize> { + pub fn zx(self) -> Idx<'n> { Idx { unbound: self.unbound.zx(), __marker: PhantomData, @@ -261,6 +262,17 @@ impl<'n, I: Index> Idx<'n, I> { } impl<'n, I: Index> IdxInc<'n, I> { + /// Create new branded value with an arbitrary brand. + /// # Safety + /// See struct safety invariant. + #[inline(always)] + pub const fn zero() -> Self { + Self { + unbound: unsafe { core::mem::zeroed() }, + __marker: PhantomData, + } + } + /// Create new branded value with an arbitrary brand. /// # Safety /// See struct safety invariant. @@ -312,7 +324,7 @@ impl<'n, I: Index> IdxInc<'n, I> { /// Zero-extends the internal value into a `usize`. #[inline(always)] - pub fn zx(self) -> IdxInc<'n, usize> { + pub fn zx(self) -> IdxInc<'n> { IdxInc { unbound: self.unbound.zx(), __marker: PhantomData, @@ -392,36 +404,6 @@ impl<'n, I: Index> From> for IdxInc<'n, I> { } impl<'size> Dim<'size> { - #[deprecated] - /// Create a new [`Dim`] with a lifetime tied to `n`. - #[track_caller] - #[inline] - pub fn with(n: usize, f: impl for<'n> FnOnce(Dim<'n>) -> R) -> R { - f(unsafe { Dim::new_unbound(n) }) - } - - #[deprecated] - /// Create two new [`Dim`] with lifetimes tied to `m` and `n`. - #[track_caller] - #[inline] - pub fn with2(m: usize, n: usize, f: impl for<'m, 'n> FnOnce(Dim<'m>, Dim<'n>) -> R) -> R { - unsafe { f(Dim::new_unbound(m), Dim::new_unbound(n)) } - } - - #[deprecated] - /// Create a new [`Dim`] tied to the lifetime `'n`. - #[inline] - pub unsafe fn new_raw_unchecked(n: usize) -> Self { - Dim::new_unbound(n) - } - - #[deprecated] - /// Returns the unconstrained value. - #[inline] - pub fn into_inner(self) -> usize { - self.unbound - } - /// Check that the index is bounded by `self`, or panic otherwise. #[track_caller] #[inline] @@ -429,7 +411,6 @@ impl<'size> Dim<'size> { Idx::new_checked(idx, self) } - #[deprecated] /// Check that the index is bounded by `self`, or return `None` otherwise. #[inline] pub fn try_check(self, idx: I) -> Option> { @@ -441,7 +422,7 @@ impl<'size> Dim<'size> { } } -impl<'n> Idx<'n, usize> { +impl<'n> Idx<'n> { /// Truncate `self` to a smaller type `I`. pub fn truncate(self) -> Idx<'n, I> { unsafe { Idx::new_unbound(I::truncate(self.unbound())) } @@ -449,21 +430,6 @@ impl<'n> Idx<'n, usize> { } impl<'n, I: Index> Idx<'n, I> { - /// Returns a new index without asserting that it's bounded by the value tied to the - /// lifetime `'n`. - #[inline] - #[deprecated] - pub unsafe fn new_raw_unchecked(idx: I) -> Self { - Self::new_unbound(idx) - } - - /// Returns the unconstrained value. - #[inline] - #[deprecated] - pub fn into_inner(self) -> I { - self.unbound - } - /// Returns the index, bounded inclusively by the value tied to `'n`. #[inline] pub const fn to_inclusive(self) -> IdxInc<'n, I> { @@ -507,27 +473,10 @@ impl<'n, I: Index> Idx<'n, I> { } } -impl<'n, I: Index> IdxInc<'n, I> { - /// Returns a constrained inclusive index, assuming that it's bounded (inclusively) by - /// the size tied to `'n`. - #[inline] - #[deprecated] - pub unsafe fn new_raw_unchecked(idx: I) -> Self { - Self::new_unbound(idx) - } - - /// Returns the unconstrained value. - #[inline] - #[deprecated] - pub fn into_inner(self) -> I { - self.unbound - } -} - /// `I` value smaller than the size corresponding to the lifetime `'n`, or `None`. #[derive(Copy, Clone, PartialEq, Eq)] #[repr(transparent)] -pub struct MaybeIdx<'n, I: Index> { +pub struct MaybeIdx<'n, I: Index = usize> { unbound: I, __marker: PhantomData>, } @@ -564,16 +513,6 @@ impl<'n, I: Index> MaybeIdx<'n, I> { } } - /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. - #[deprecated] - #[inline] - pub unsafe fn new_raw_unchecked(idx: I) -> Self { - Self { - unbound: idx, - __marker: PhantomData, - } - } - /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. #[inline] pub unsafe fn new_unbound(idx: I) -> Self { @@ -583,13 +522,6 @@ impl<'n, I: Index> MaybeIdx<'n, I> { } } - /// Returns the inner value. - #[inline] - #[deprecated] - pub fn into_inner(self) -> I { - self.unbound - } - /// Returns the inner value. #[inline] pub fn unbound(self) -> I { @@ -608,7 +540,7 @@ impl<'n, I: Index> MaybeIdx<'n, I> { /// Sign extend the value. #[inline] - pub fn sx(self) -> MaybeIdx<'n, usize> { + pub fn sx(self) -> MaybeIdx<'n> { unsafe { MaybeIdx::new_unbound(self.unbound.to_signed().sx()) } } @@ -686,3 +618,270 @@ impl core::ops::Deref for IdxInc<'_, I> { &self.unbound } } + +/// Array of length equal to the value tied to `'n`. +#[derive(PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Array<'n, T> { + __marker: PhantomData>, + unbound: [T], +} + +/// Immutable array group of length equal to the value tied to `'n`. +pub struct ArrayGroup<'n, 'a, E: Entity> { + __marker: PhantomData>, + unbound: slice::SliceGroup<'a, E>, +} +/// Mutable array group of length equal to the value tied to `'n`. +pub struct ArrayGroupMut<'n, 'a, E: Entity> { + __marker: PhantomData>, + unbound: slice::SliceGroupMut<'a, E>, +} + +impl core::fmt::Debug for ArrayGroup<'_, '_, E> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.unbound.fmt(f) + } +} +impl core::fmt::Debug for ArrayGroupMut<'_, '_, E> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.unbound.fmt(f) + } +} + +impl Copy for ArrayGroup<'_, '_, E> {} +impl Clone for ArrayGroup<'_, '_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroup<'n, 'a, E> { + type Target = ArrayGroup<'n, 'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroup<'n, 'a, E> { + type Target = ArrayGroup<'n, 'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroupMut<'n, 'a, E> { + type Target = ArrayGroupMut<'n, 'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + ArrayGroupMut { + __marker: PhantomData, + unbound: self.unbound.rb_mut(), + } + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroupMut<'n, 'a, E> { + type Target = ArrayGroup<'n, 'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + ArrayGroup { + __marker: PhantomData, + unbound: self.unbound.rb(), + } + } +} + +impl<'n, 'a, E: Entity> ArrayGroupMut<'n, 'a, E> { + /// Returns an array group with length after checking that its length matches + /// the value tied to `'n`. + #[inline] + pub fn new(slice: GroupFor, len: Dim<'n>) -> Self { + let slice = slice::SliceGroupMut::<'_, E>::new(slice); + assert!(slice.len() == len.unbound()); + ArrayGroupMut { + __marker: PhantomData, + unbound: slice, + } + } + + /// Returns the unconstrained slice. + #[inline] + pub fn into_slice(self) -> GroupFor { + self.unbound.into_inner() + } + + /// Returns a subslice at from the range start to its end. + #[inline] + pub fn subslice(self, range: Range>) -> GroupFor { + unsafe { + slice::SliceGroupMut::<'_, E>::new(self.into_slice()) + .subslice_unchecked(range.start.unbound()..range.end.unbound()) + .into_inner() + } + } + + /// Read the element at position `j`. + #[inline] + pub fn read(&self, j: Idx<'n>) -> E { + self.rb().read(j) + } + + /// Write `value` to the location at position `j`. + #[inline] + pub fn write(&mut self, j: Idx<'n>, value: E) { + unsafe { + slice::SliceGroupMut::new(self.rb_mut().into_slice()) + .write_unchecked(j.unbound(), value) + } + } +} + +impl<'n, 'a, E: Entity> ArrayGroup<'n, 'a, E> { + /// Returns an array group with length after checking that its length matches + /// the value tied to `'n`. + #[inline] + pub fn new(slice: GroupFor, len: Dim<'n>) -> Self { + let slice = slice::SliceGroup::<'_, E>::new(slice); + assert!(slice.rb().len() == len.unbound()); + ArrayGroup { + __marker: PhantomData, + unbound: slice, + } + } + + /// Returns the unconstrained slice. + #[inline] + pub fn into_slice(self) -> GroupFor { + self.unbound.into_inner() + } + + /// Returns a subslice at from the range start to its end. + #[inline] + pub fn subslice(self, range: Range>) -> GroupFor { + unsafe { + slice::SliceGroup::<'_, E>::new(self.into_slice()) + .subslice_unchecked(range.start.unbound()..range.end.unbound()) + .into_inner() + } + } + + /// Read the element at position `j`. + #[inline] + pub fn read(&self, j: Idx<'n>) -> E { + unsafe { slice::SliceGroup::new(self.into_slice()).read_unchecked(j.unbound()) } + } +} + +impl<'n, T> Array<'n, T> { + /// Returns a constrained array after checking that its length matches `size`. + #[inline] + #[track_caller] + pub fn from_ref<'a>(slice: &'a [T], size: Dim<'n>) -> &'a Self { + assert!(slice.len() == size.unbound()); + unsafe { &*(slice as *const [T] as *const Self) } + } + + /// Returns a constrained array after checking that its length matches `size`. + #[inline] + #[track_caller] + pub fn from_mut<'a>(slice: &'a mut [T], size: Dim<'n>) -> &'a mut Self { + assert!(slice.len() == size.unbound()); + unsafe { &mut *(slice as *mut [T] as *mut Self) } + } + + /// Returns the unconstrained slice. + #[inline] + #[track_caller] + pub fn as_ref(&self) -> &[T] { + unsafe { &*(self as *const _ as *const _) } + } + + /// Returns the unconstrained slice. + #[inline] + #[track_caller] + pub fn as_mut<'a>(&mut self) -> &'a mut [T] { + unsafe { &mut *(self as *mut _ as *mut _) } + } + + /// Returns the length of `self`. + #[inline] + pub fn len(&self) -> Dim<'n> { + unsafe { Dim::new_unbound(self.unbound.len()) } + } +} + +impl core::fmt::Debug for Array<'_, T> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.unbound.fmt(f) + } +} + +impl<'n, T> core::ops::Index>> for Array<'n, T> { + type Output = [T]; + #[track_caller] + fn index(&self, idx: Range>) -> &Self::Output { + #[cfg(debug_assertions)] + { + &self.unbound[idx.start.unbound()..idx.end.unbound()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.0 + .inner + .get_unchecked(idx.start.unbound()..idx.end.unbound()) + } + } +} +impl<'n, T> core::ops::IndexMut>> for Array<'n, T> { + #[track_caller] + fn index_mut(&mut self, idx: Range>) -> &mut Self::Output { + #[cfg(debug_assertions)] + { + &mut self.unbound[idx.start.unbound()..idx.end.unbound()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.0 + .inner + .get_unchecked_mut(idx.start.unbound()..idx.end.unbound()) + } + } +} +impl<'n, T> core::ops::Index> for Array<'n, T> { + type Output = T; + #[track_caller] + fn index(&self, idx: Idx<'n>) -> &Self::Output { + #[cfg(debug_assertions)] + { + &self.unbound[idx.unbound()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.unbound.get_unchecked(idx.unbound()) + } + } +} +impl<'n, T> core::ops::IndexMut> for Array<'n, T> { + #[track_caller] + fn index_mut(&mut self, idx: Idx<'n>) -> &mut Self::Output { + #[cfg(debug_assertions)] + { + &mut self.unbound[idx.unbound()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.unbound.get_unchecked_mut(idx.unbound()) + } + } +} diff --git a/src/utils/constrained/mat.rs b/src/utils/constrained/mat.rs deleted file mode 100644 index 46e4e9cd..00000000 --- a/src/utils/constrained/mat.rs +++ /dev/null @@ -1,201 +0,0 @@ -use super::*; -use crate::{assert, mat}; - -/// Immutable dense matrix view with dimensions equal to the values tied to `('nrows, 'ncols)`. -#[repr(transparent)] -pub struct MatRef<'nrows, 'ncols, 'a, E: Entity>( - Branded<'ncols, Branded<'nrows, mat::MatRef<'a, E>>>, -); -/// Mutable dense matrix view with dimensions equal to the values tied to `('nrows, 'ncols)`. -#[repr(transparent)] -pub struct MatMut<'nrows, 'ncols, 'a, E: Entity>( - Branded<'ncols, Branded<'nrows, mat::MatMut<'a, E>>>, -); - -impl<'nrows, 'ncols, 'a, E: Entity> MatRef<'nrows, 'ncols, 'a, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - #[inline] - #[track_caller] - pub fn new(inner: mat::MatRef<'a, E>, nrows: Dim<'nrows>, ncols: Dim<'ncols>) -> Self { - assert!(all( - inner.nrows() == nrows.unbound(), - inner.ncols() == ncols.unbound(), - )); - Self(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> Dim<'nrows> { - unsafe { Dim::new_unbound(self.0.inner.inner.nrows()) } - } - - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> Dim<'ncols> { - unsafe { Dim::new_unbound(self.0.inner.inner.ncols()) } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> mat::MatRef<'a, E> { - self.0.inner.inner - } - - /// Returns the element at position `(i, j)`. - #[inline] - #[track_caller] - pub fn read(&self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>) -> E { - unsafe { self.0.inner.inner.read_unchecked(i.unbound(), j.unbound()) } - } -} - -impl<'nrows, 'ncols, 'a, E: Entity> MatMut<'nrows, 'ncols, 'a, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - #[inline] - #[track_caller] - pub fn new(inner: mat::MatMut<'a, E>, nrows: Dim<'nrows>, ncols: Dim<'ncols>) -> Self { - assert!(all( - inner.nrows() == nrows.unbound(), - inner.ncols() == ncols.unbound(), - )); - Self(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> Dim<'nrows> { - unsafe { Dim::new_unbound(self.0.inner.inner.nrows()) } - } - - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> Dim<'ncols> { - unsafe { Dim::new_unbound(self.0.inner.inner.ncols()) } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> mat::MatMut<'a, E> { - self.0.inner.inner - } - - /// Returns the element at position `(i, j)`. - #[inline] - #[track_caller] - pub fn read(&self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>) -> E { - unsafe { self.0.inner.inner.read_unchecked(i.unbound(), j.unbound()) } - } - - /// Writes `value` to the location at position `(i, j)`. - #[inline] - #[track_caller] - pub fn write(&mut self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>, value: E) { - unsafe { - self.0 - .inner - .inner - .write_unchecked(i.unbound(), j.unbound(), value) - }; - } -} - -impl Clone for MatRef<'_, '_, '_, E> { - #[inline] - fn clone(&self) -> Self { - *self - } -} -impl Copy for MatRef<'_, '_, '_, E> {} - -impl<'nrows, 'ncols, 'a, E: Entity> IntoConst for MatRef<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'a, E>; - #[inline] - fn into_const(self) -> Self::Target { - self - } -} -impl<'nrows, 'ncols, 'a, 'short, E: Entity> Reborrow<'short> for MatRef<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } -} -impl<'nrows, 'ncols, 'a, 'short, E: Entity> ReborrowMut<'short> for MatRef<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } -} - -impl<'nrows, 'ncols, 'a, E: Entity> IntoConst for MatMut<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'a, E>; - #[inline] - fn into_const(self) -> Self::Target { - let inner = self.0.inner.inner.into_const(); - MatRef(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } -} -impl<'nrows, 'ncols, 'a, 'short, E: Entity> Reborrow<'short> for MatMut<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb(&'short self) -> Self::Target { - let inner = self.0.inner.inner.rb(); - MatRef(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } -} -impl<'nrows, 'ncols, 'a, 'short, E: Entity> ReborrowMut<'short> for MatMut<'nrows, 'ncols, 'a, E> { - type Target = MatMut<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - let inner = self.0.inner.inner.rb_mut(); - MatMut(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } -} - -impl Debug for MatRef<'_, '_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.inner.fmt(f) - } -} -impl Debug for MatMut<'_, '_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.inner.fmt(f) - } -} diff --git a/src/utils/constrained/mod.rs b/src/utils/constrained/mod.rs deleted file mode 100644 index b207f9a3..00000000 --- a/src/utils/constrained/mod.rs +++ /dev/null @@ -1,283 +0,0 @@ -use super::*; -use crate::{assert, Index}; -use core::{fmt::Debug, marker::PhantomData, ops::Range}; -use faer_entity::*; -use reborrow::*; - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -#[repr(transparent)] -struct Branded<'a, T: ?Sized> { - __marker: PhantomData &'a ()>, - inner: T, -} - -pub(crate) type Dim<'n> = crate::utils::bound::Dim<'n>; -pub(crate) type Idx<'n, I> = crate::utils::bound::Idx<'n, I>; -pub(crate) type IdxInclusive<'n, I> = crate::utils::bound::IdxInc<'n, I>; -pub(crate) type MaybeIdx<'n, I> = crate::utils::bound::MaybeIdx<'n, I>; - -/// Array of length equal to the value tied to `'n`. -#[derive(PartialEq, Eq, PartialOrd, Ord)] -#[repr(transparent)] -pub struct Array<'n, T>(Branded<'n, [T]>); - -/// Dense matrices with compile-time access checks. -pub mod mat; -/// Permutations with compile-time checks. -pub mod perm; -/// Sparse matrices with compile-time access checks. -#[cfg(feature = "sparse")] -pub mod sparse; - -/// Immutable array group of length equal to the value tied to `'n`. -pub struct ArrayGroup<'n, 'a, E: Entity>(Branded<'n, slice::SliceGroup<'a, E>>); -/// Mutable array group of length equal to the value tied to `'n`. -pub struct ArrayGroupMut<'n, 'a, E: Entity>(Branded<'n, slice::SliceGroupMut<'a, E>>); - -impl Debug for ArrayGroup<'_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } -} -impl Debug for ArrayGroupMut<'_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } -} - -impl Copy for ArrayGroup<'_, '_, E> {} -impl Clone for ArrayGroup<'_, '_, E> { - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroup<'n, 'a, E> { - type Target = ArrayGroup<'n, 'short, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } -} - -impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroup<'n, 'a, E> { - type Target = ArrayGroup<'n, 'short, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } -} - -impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroupMut<'n, 'a, E> { - type Target = ArrayGroupMut<'n, 'short, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - ArrayGroupMut(Branded { - __marker: PhantomData, - inner: self.0.inner.rb_mut(), - }) - } -} - -impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroupMut<'n, 'a, E> { - type Target = ArrayGroup<'n, 'short, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - ArrayGroup(Branded { - __marker: PhantomData, - inner: self.0.inner.rb(), - }) - } -} - -impl<'n, 'a, E: Entity> ArrayGroupMut<'n, 'a, E> { - /// Returns an array group with length after checking that its length matches - /// the value tied to `'n`. - #[inline] - pub fn new(slice: GroupFor, len: Dim<'n>) -> Self { - let slice = slice::SliceGroupMut::<'_, E>::new(slice); - assert!(slice.rb().len() == len.unbound()); - ArrayGroupMut(Branded { - __marker: PhantomData, - inner: slice, - }) - } - - /// Returns the unconstrained slice. - #[inline] - pub fn into_slice(self) -> GroupFor { - self.0.inner.into_inner() - } - - /// Returns a subslice at from the range start to its end. - #[inline] - pub fn subslice(self, range: Range>) -> GroupFor { - unsafe { - slice::SliceGroupMut::<'_, E>::new(self.into_slice()) - .subslice_unchecked(range.start.unbound()..range.end.unbound()) - .into_inner() - } - } - - /// Read the element at position `j`. - #[inline] - pub fn read(&self, j: Idx<'n, usize>) -> E { - self.rb().read(j) - } - - /// Write `value` to the location at position `j`. - #[inline] - pub fn write(&mut self, j: Idx<'n, usize>, value: E) { - unsafe { - slice::SliceGroupMut::new(self.rb_mut().into_slice()) - .write_unchecked(j.unbound(), value) - } - } -} - -impl<'n, 'a, E: Entity> ArrayGroup<'n, 'a, E> { - /// Returns an array group with length after checking that its length matches - /// the value tied to `'n`. - #[inline] - pub fn new(slice: GroupFor, len: Dim<'n>) -> Self { - let slice = slice::SliceGroup::<'_, E>::new(slice); - assert!(slice.rb().len() == len.unbound()); - ArrayGroup(Branded { - __marker: PhantomData, - inner: slice, - }) - } - - /// Returns the unconstrained slice. - #[inline] - pub fn into_slice(self) -> GroupFor { - self.0.inner.into_inner() - } - - /// Returns a subslice at from the range start to its end. - #[inline] - pub fn subslice(self, range: Range>) -> GroupFor { - unsafe { - slice::SliceGroup::<'_, E>::new(self.into_slice()) - .subslice_unchecked(range.start.unbound()..range.end.unbound()) - .into_inner() - } - } - - /// Read the element at position `j`. - #[inline] - pub fn read(&self, j: Idx<'n, usize>) -> E { - unsafe { slice::SliceGroup::new(self.into_slice()).read_unchecked(j.unbound()) } - } -} - -impl<'n, T> Array<'n, T> { - /// Returns a constrained array after checking that its length matches `size`. - #[inline] - #[track_caller] - pub fn from_ref<'a>(slice: &'a [T], size: Dim<'n>) -> &'a Self { - assert!(slice.len() == size.unbound()); - unsafe { &*(slice as *const [T] as *const Self) } - } - - /// Returns a constrained array after checking that its length matches `size`. - #[inline] - #[track_caller] - pub fn from_mut<'a>(slice: &'a mut [T], size: Dim<'n>) -> &'a mut Self { - assert!(slice.len() == size.unbound()); - unsafe { &mut *(slice as *mut [T] as *mut Self) } - } - - /// Returns the unconstrained slice. - #[inline] - #[track_caller] - pub fn as_ref(&self) -> &[T] { - unsafe { &*(self as *const _ as *const _) } - } - - /// Returns the unconstrained slice. - #[inline] - #[track_caller] - pub fn as_mut<'a>(&mut self) -> &'a mut [T] { - unsafe { &mut *(self as *mut _ as *mut _) } - } - - /// Returns the length of `self`. - #[inline] - pub fn len(&self) -> Dim<'n> { - unsafe { Dim::new_unbound(self.0.inner.len()) } - } -} - -impl Debug for Array<'_, T> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } -} - -impl<'n, T> core::ops::Index>> for Array<'n, T> { - type Output = [T]; - #[track_caller] - fn index(&self, idx: Range>) -> &Self::Output { - #[cfg(debug_assertions)] - { - &self.0.inner[idx.start.unbound()..idx.end.unbound()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0 - .inner - .get_unchecked(idx.start.unbound()..idx.end.unbound()) - } - } -} -impl<'n, T> core::ops::IndexMut>> for Array<'n, T> { - #[track_caller] - fn index_mut(&mut self, idx: Range>) -> &mut Self::Output { - #[cfg(debug_assertions)] - { - &mut self.0.inner[idx.start.unbound()..idx.end.unbound()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0 - .inner - .get_unchecked_mut(idx.start.unbound()..idx.end.unbound()) - } - } -} -impl<'n, T> core::ops::Index> for Array<'n, T> { - type Output = T; - #[track_caller] - fn index(&self, idx: Idx<'n, usize>) -> &Self::Output { - #[cfg(debug_assertions)] - { - &self.0.inner[idx.unbound()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0.inner.get_unchecked(idx.unbound()) - } - } -} -impl<'n, T> core::ops::IndexMut> for Array<'n, T> { - #[track_caller] - fn index_mut(&mut self, idx: Idx<'n, usize>) -> &mut Self::Output { - #[cfg(debug_assertions)] - { - &mut self.0.inner[idx.unbound()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0.inner.get_unchecked_mut(idx.unbound()) - } - } -} diff --git a/src/utils/constrained/perm.rs b/src/utils/constrained/perm.rs deleted file mode 100644 index d512f564..00000000 --- a/src/utils/constrained/perm.rs +++ /dev/null @@ -1,70 +0,0 @@ -use super::*; -use crate::{assert, perm}; - -/// Permutation of length equal to the value tied to `'n`. -#[repr(transparent)] -pub struct PermRef<'n, 'a, I: Index>(Branded<'n, perm::PermRef<'a, I>>); - -impl<'n, 'a, I: Index> PermRef<'n, 'a, I> { - /// Returns a new permutation after checking that it matches the size tied to `'n`. - #[inline] - #[track_caller] - pub fn new(perm: perm::PermRef<'a, I>, size: Dim<'n>) -> Self { - let (fwd, inv) = perm.arrays(); - assert!(all( - fwd.len() == size.unbound(), - inv.len() == size.unbound(), - )); - Self(Branded { - __marker: PhantomData, - inner: perm, - }) - } - - /// Returns the inverse permutation. - #[inline] - pub fn inverse(self) -> PermRef<'n, 'a, I> { - PermRef(Branded { - __marker: PhantomData, - inner: self.0.inner.inverse(), - }) - } - - /// Returns the forward and inverse permutation indices. - #[inline] - pub fn arrays(self) -> (&'a Array<'n, Idx<'n, I>>, &'a Array<'n, Idx<'n, I>>) { - unsafe { - let (fwd, inv) = self.0.inner.arrays(); - let fwd = &*(fwd as *const [I] as *const Array<'n, Idx<'n, I>>); - let inv = &*(inv as *const [I] as *const Array<'n, Idx<'n, I>>); - (fwd, inv) - } - } - - /// Returns the unconstrained permutation. - #[inline] - pub fn into_inner(self) -> perm::PermRef<'a, I> { - self.0.inner - } - - /// Returns the length of the permutation. - #[inline] - pub fn len(&self) -> Dim<'n> { - unsafe { Dim::new_unbound(self.into_inner().len()) } - } -} - -impl Clone for PermRef<'_, '_, I> { - #[inline] - fn clone(&self) -> Self { - *self - } -} -impl Copy for PermRef<'_, '_, I> {} - -impl core::fmt::Debug for PermRef<'_, '_, I> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } -} diff --git a/src/utils/constrained/sparse.rs b/src/utils/constrained/sparse.rs deleted file mode 100644 index 0ccea59c..00000000 --- a/src/utils/constrained/sparse.rs +++ /dev/null @@ -1,292 +0,0 @@ -use super::*; -use crate::{assert, sparse::__get_unchecked, utils::slice::*}; - -/// Symbolic structure view with dimensions equal to the values tied to `('nrows, 'ncols)`, -/// in column-major order. -#[repr(transparent)] -pub struct SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I: Index>( - Branded<'ncols, Branded<'nrows, crate::sparse::SymbolicSparseColMatRef<'a, I>>>, -); -/// Immutable sparse matrix view with dimensions equal to the values tied to `('nrows, -/// 'ncols)`, in column-major order. -pub struct SparseColMatRef<'nrows, 'ncols, 'a, I: Index, E: Entity> { - symbolic: SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>, - values: SliceGroup<'a, E>, -} -/// Mutable sparse matrix view with dimensions equal to the values tied to `('nrows, -/// 'ncols)`, in column-major order. -pub struct SparseColMatMut<'nrows, 'ncols, 'a, I: Index, E: Entity> { - symbolic: SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>, - values: SliceGroupMut<'a, E>, -} - -impl<'nrows, 'ncols, 'a, I: Index> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { - /// Returns a new symbolic structure after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - #[inline] - pub fn new( - inner: crate::sparse::SymbolicSparseColMatRef<'a, I>, - nrows: Dim<'nrows>, - ncols: Dim<'ncols>, - ) -> Self { - assert!(all( - inner.nrows() == nrows.unbound(), - inner.ncols() == ncols.unbound(), - )); - Self(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - - /// Returns the unconstrained symbolic structure. - #[inline] - pub fn into_inner(self) -> crate::sparse::SymbolicSparseColMatRef<'a, I> { - self.0.inner.inner - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> Dim<'nrows> { - unsafe { Dim::new_unbound(self.0.inner.inner.nrows()) } - } - - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> Dim<'ncols> { - unsafe { Dim::new_unbound(self.0.inner.inner.ncols()) } - } - - #[inline] - #[track_caller] - #[doc(hidden)] - pub fn col_range(&self, j: Idx<'ncols, usize>) -> Range { - unsafe { self.into_inner().col_range_unchecked(j.unbound()) } - } - - /// Returns the row indices in column `j`. - #[inline] - #[track_caller] - pub fn row_indices_of_col_raw(&self, j: Idx<'ncols, usize>) -> &'a [Idx<'nrows, I>] { - unsafe { - &*(__get_unchecked(self.into_inner().row_indices(), self.col_range(j)) as *const [I] - as *const [Idx<'_, I>]) - } - } - - /// Returns the row indices in column `j`. - #[inline] - #[track_caller] - pub fn row_indices_of_col( - &self, - j: Idx<'ncols, usize>, - ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator> { - unsafe { - __get_unchecked( - self.into_inner().row_indices(), - self.into_inner().col_range_unchecked(j.unbound()), - ) - .iter() - .map( - #[inline(always)] - move |&row| Idx::new_unbound(row.zx()), - ) - } - } -} - -impl<'nrows, 'ncols, 'a, I: Index, E: Entity> SparseColMatRef<'nrows, 'ncols, 'a, I, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - pub fn new( - inner: crate::sparse::SparseColMatRef<'a, I, E>, - nrows: Dim<'nrows>, - ncols: Dim<'ncols>, - ) -> Self { - assert!(all( - inner.nrows() == nrows.unbound(), - inner.ncols() == ncols.unbound(), - )); - Self { - symbolic: SymbolicSparseColMatRef::new(inner.symbolic(), nrows, ncols), - values: SliceGroup::new(inner.values()), - } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> crate::sparse::SparseColMatRef<'a, I, E> { - crate::sparse::SparseColMatRef::new(self.symbolic.into_inner(), self.values.into_inner()) - } - - /// Returns the values in column `j`. - #[inline] - pub fn values_of_col(&self, j: Idx<'ncols, usize>) -> GroupFor { - unsafe { - self.values - .subslice_unchecked(self.col_range(j)) - .into_inner() - } - } - - /// Returns the symbolic structure of the matrix. - #[inline] - pub fn symbolic(&self) -> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { - self.symbolic - } -} - -impl<'nrows, 'ncols, 'a, I: Index, E: Entity> SparseColMatMut<'nrows, 'ncols, 'a, I, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - pub fn new( - inner: crate::sparse::SparseColMatMut<'a, I, E>, - nrows: Dim<'nrows>, - ncols: Dim<'ncols>, - ) -> Self { - assert!(all( - inner.nrows() == nrows.unbound(), - inner.ncols() == ncols.unbound(), - )); - Self { - symbolic: SymbolicSparseColMatRef::new(inner.symbolic(), nrows, ncols), - values: SliceGroupMut::new(inner.values_mut()), - } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> crate::sparse::SparseColMatMut<'a, I, E> { - crate::sparse::SparseColMatMut::new(self.symbolic.into_inner(), self.values.into_inner()) - } - - /// Returns the values in column `j`. - #[inline] - pub fn values_of_col_mut(&mut self, j: Idx<'ncols, usize>) -> GroupFor { - unsafe { - let range = self.col_range(j); - self.values.rb_mut().subslice_unchecked(range).into_inner() - } - } - - /// Returns the symbolic structure of the matrix. - #[inline] - pub fn symbolic(&self) -> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { - self.symbolic - } -} - -impl Copy for SparseColMatRef<'_, '_, '_, I, E> {} -impl Clone for SparseColMatRef<'_, '_, '_, I, E> { - #[inline] - fn clone(&self) -> Self { - *self - } -} -impl Copy for SymbolicSparseColMatRef<'_, '_, '_, I> {} -impl Clone for SymbolicSparseColMatRef<'_, '_, '_, I> { - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl<'nrows, 'ncols, 'a, I: Index, E: Entity> core::ops::Deref - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> -{ - type Target = SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.symbolic - } -} - -impl<'nrows, 'ncols, 'a, I: Index, E: Entity> core::ops::Deref - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> -{ - type Target = SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.symbolic - } -} - -impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> ReborrowMut<'short> - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> -{ - type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } -} - -impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> Reborrow<'short> - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> -{ - type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } -} - -impl<'nrows, 'ncols, 'a, I: Index, E: Entity> IntoConst - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> -{ - type Target = SparseColMatRef<'nrows, 'ncols, 'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - self - } -} - -impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> ReborrowMut<'short> - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> -{ - type Target = SparseColMatMut<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - SparseColMatMut::<'nrows, 'ncols, 'short, I, E> { - symbolic: self.symbolic, - values: self.values.rb_mut(), - } - } -} - -impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> Reborrow<'short> - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> -{ - type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - SparseColMatRef::<'nrows, 'ncols, 'short, I, E> { - symbolic: self.symbolic, - values: self.values.rb(), - } - } -} - -impl<'nrows, 'ncols, 'a, I: Index, E: Entity> IntoConst - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> -{ - type Target = SparseColMatRef<'nrows, 'ncols, 'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - SparseColMatRef::<'nrows, 'ncols, 'a, I, E> { - symbolic: self.symbolic, - values: self.values.into_const(), - } - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a76238c3..57e25f27 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -45,8 +45,6 @@ impl DivCeil for usize { } } -/// Index and matrix types with compile time checks, whichh can replace bound checks at runtime. -pub mod constrained; /// Simd operations for a specific type satisfying [`ComplexField`](crate::ComplexField). pub mod simd; /// Slice types for [entities](crate::Entity).