Skip to content

Commit

Permalink
feat: support visit_seq for Deserialize
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Sep 18, 2024
1 parent bb213d1 commit 6913304
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 102 deletions.
3 changes: 2 additions & 1 deletion src/col/colown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,8 @@ impl<E: Entity> Clone for Col<E> {
E::faer_from_units(E::faer_deref(this.get_unchecked(i)))
})
}
}fn clone_from(&mut self,other:&Self){
}
fn clone_from(&mut self, other: &Self) {
self.resize_with(0, |_| E::zeroed());
self.resize_with(
other.nrows(),
Expand Down
6 changes: 3 additions & 3 deletions src/mat/matown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ impl<E: Entity> Mat<E> {
this: &mut Mat<E>,
other: MatRef<'_, ViewE>,
) {
let (rows, cols)=other.shape();
let (rows, cols) = other.shape();
this.resize_with(0, 0, |_, _| E::zeroed());
this.resize_with(
rows,
Expand Down Expand Up @@ -2070,9 +2070,9 @@ impl<E: Entity> Clone for Mat<E> {
})
}
}

fn clone_from(&mut self, other: &Self) {
let (rows, cols)=other.shape();
let (rows, cols) = other.shape();
self.resize_with(0, 0, |_, _| E::zeroed());
self.resize_with(
rows,
Expand Down
4 changes: 2 additions & 2 deletions src/row/rowown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,8 +1030,8 @@ impl<E: Entity> Clone for Row<E> {
})
}
}
fn clone_from(&mut self, other: &Self){

fn clone_from(&mut self, other: &Self) {
self.resize_with(0, |_| E::zeroed());
self.resize_with(
other.nrows(),
Expand Down
201 changes: 108 additions & 93 deletions src/serde/mat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,113 +88,128 @@ where
}
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "data"];
struct MatVisitor<E: Entity>(PhantomData<E>);
impl<'a, E: Entity + Deserialize<'a>> Visitor<'a> for MatVisitor<E> {
type Value = Mat<E>;
enum MatrixOrVec<E: Entity> {
Matrix(Mat<E>),
Vec(Vec<E>),
}
impl<E: Entity> MatrixOrVec<E> {
fn into_mat(self, nrows: usize, ncols: usize) -> Mat<E> {
match self {
MatrixOrVec::Matrix(m) => m,
MatrixOrVec::Vec(v) => Mat::from_fn(nrows, ncols, |i, j| v[i * ncols + j]),
}
}
}
struct MatrixOrVecDeserializer<'a, E: Entity + Deserialize<'a>> {
marker: PhantomData<&'a E>,
nrows: Option<usize>,
ncols: Option<usize>,
}
impl<'a, E: Entity + Deserialize<'a>> MatrixOrVecDeserializer<'a, E> {
fn new(nrows: Option<usize>, ncols: Option<usize>) -> Self {
Self {
marker: PhantomData,
nrows,
ncols,
}
}
}
impl<'a, E: Entity> DeserializeSeed<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'a>,
{
deserializer.deserialize_seq(self)
}
}
impl<'a, E: Entity> Visitor<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;

fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
formatter.write_str("a faer matrix")
formatter.write_str("a sequence")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'a>,
A: SeqAccess<'a>,
{
enum MatrixOrVec<E: Entity> {
Matrix(Mat<E>),
Vec(Vec<E>),
}
impl<E: Entity> MatrixOrVec<E> {
fn into_mat(self, nrows: usize, ncols: usize) -> Mat<E> {
match self {
MatrixOrVec::Matrix(m) => m,
MatrixOrVec::Vec(v) => {
Mat::from_fn(nrows, ncols, |i, j| v[i * ncols + j])
}
match (self.ncols, self.nrows) {
(Some(ncols), Some(nrows)) => {
let mut data = Mat::<E>::with_capacity(nrows, ncols);
unsafe {
data.set_dims(nrows, ncols);
}
let expected_length = nrows * ncols;
for i in 0..expected_length {
let el = seq.next_element::<E>()?.ok_or_else(|| {
serde::de::Error::invalid_length(
i,
&format!("{} elements", expected_length).as_str(),
)
})?;
data.write(i / ncols, i % ncols, el);
}
let mut additional = 0usize;
while let Some(_) = seq.next_element::<E>()? {
additional += 1;
}
if additional > 0 {
return Err(serde::de::Error::invalid_length(
additional + expected_length,
&format!("{} elements", expected_length).as_str(),
));
}
Ok(MatrixOrVec::Matrix(data))
}
}
struct MatrixOrVecDeserializer<'a, E: Entity + Deserialize<'a>> {
marker: PhantomData<&'a E>,
nrows: Option<usize>,
ncols: Option<usize>,
}
impl<'a, E: Entity + Deserialize<'a>> MatrixOrVecDeserializer<'a, E> {
fn new(nrows: Option<usize>, ncols: Option<usize>) -> Self {
Self {
marker: PhantomData,
nrows,
ncols,
_ => {
let mut data = Vec::new();
while let Some(el) = seq.next_element::<E>()? {
data.push(el);
}
Ok(MatrixOrVec::Vec(data))
}
}
impl<'a, E: Entity> DeserializeSeed<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;
}
}
impl<'a, E: Entity + Deserialize<'a>> Visitor<'a> for MatVisitor<E> {
type Value = Mat<E>;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'a>,
{
deserializer.deserialize_seq(self)
}
}
impl<'a, E: Entity> Visitor<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;
fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
formatter.write_str("a faer matrix")
}

fn expecting(
&self,
formatter: &mut alloc::fmt::Formatter,
) -> alloc::fmt::Result {
formatter.write_str("a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'a>,
{
let nrows = seq
.next_element::<usize>()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &"nrows"))?;
let ncols = seq
.next_element::<usize>()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &"ncols"))?;
let data = seq.next_element_seed(MatrixOrVecDeserializer::<E>::new(
Some(nrows),
Some(ncols),
))?;
let mat = data
.ok_or_else(|| serde::de::Error::missing_field("data"))?
.into_mat(nrows, ncols);
Ok(mat)
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'a>,
{
match (self.ncols, self.nrows) {
(Some(ncols), Some(nrows)) => {
let mut data = Mat::<E>::with_capacity(nrows, ncols);
unsafe {
data.set_dims(nrows, ncols);
}
let expected_length = nrows * ncols;
for i in 0..expected_length {
let el = seq.next_element::<E>()?.ok_or_else(|| {
serde::de::Error::invalid_length(
i,
&format!("{} elements", expected_length).as_str(),
)
})?;
data.write(i / ncols, i % ncols, el);
}
let mut additional = 0usize;
while let Some(_) = seq.next_element::<E>()? {
additional += 1;
}
if additional > 0 {
return Err(serde::de::Error::invalid_length(
additional + expected_length,
&format!("{} elements", expected_length).as_str(),
));
}
Ok(MatrixOrVec::Matrix(data))
}
_ => {
let mut data = Vec::new();
while let Some(el) = seq.next_element::<E>()? {
data.push(el);
}
Ok(MatrixOrVec::Vec(data))
}
}
}
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'a>,
{
let mut nrows = None;
let mut ncols = None;
let mut data: Option<MatrixOrVec<E>> = None;
Expand Down
9 changes: 6 additions & 3 deletions src/sparse/linalg/solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@ pub trait SpSolverLstsqCore<E: Entity>: SpSolverCore<E> {
pub trait SpSolver<E: ComplexField>: SpSolverCore<E> {
/// Solves the equation `self * X = rhs` when self is square, and stores the result in `rhs`.
fn solve_in_place(&self, rhs: impl ColBatchMut<E>);
/// Solves the equation `conjugate(self) * X = rhs` when self is square, and stores the result in `rhs`.
/// Solves the equation `conjugate(self) * X = rhs` when self is square, and stores the result
/// in `rhs`.
fn solve_conj_in_place(&self, rhs: impl ColBatchMut<E>);
/// Solves the equation `transpose(self) * X = rhs` when self is square, and stores the result in `rhs`.
/// Solves the equation `transpose(self) * X = rhs` when self is square, and stores the result
/// in `rhs`.
fn solve_transpose_in_place(&self, rhs: impl ColBatchMut<E>);
/// Solves the equation `adjoint(self) * X = rhs` when self is square, and stores the result in `rhs`.
/// Solves the equation `adjoint(self) * X = rhs` when self is square, and stores the result in
/// `rhs`.
fn solve_conj_transpose_in_place(&self, rhs: impl ColBatchMut<E>);
/// Solves the equation `self * X = rhs` when self is square, and returns the result.
fn solve<ViewE: Conjugate<Canonical = E>, B: ColBatch<ViewE>>(&self, rhs: B) -> B::Owned;
Expand Down

0 comments on commit 6913304

Please sign in to comment.