use std::ops::{Add, Sub, Mul}; use crate::prelude::*; #[derive(Clone)] pub struct Matrix { pub domain: usize, pub codomain: usize, pub values: Vec } macro_rules! error { ($($arg:tt)*) => { exception!(VALUE_EXCEPTION, $($arg)*) }; } impl Matrix { pub fn new( domain: usize, codomain: usize, values: Vec ) -> Self { Self { domain, codomain, values } } pub fn from_list( values: Vec ) -> Self { Self { domain: values.len(), codomain: 1, values } } pub fn empty( domain: usize, codomain: usize ) -> Self { let values = (0..(domain * codomain)).into_iter() .map(|_| Value::Int(0)) .collect(); Self { domain, codomain, values } } pub fn get(&self, row: usize, col: usize) -> Result { if row >= self.codomain || col >= self.domain { return Err(error!("[{};{}] out of bounds for [Matrix {}x{}]", row, col, self.domain, self.codomain)) } let idx = col + row * self.domain; Ok(self.values[idx].clone()) } pub fn set(&mut self, row: usize, col: usize, val: Value) -> Result<()> { if row >= self.codomain || col >= self.domain { return Err(error!("[{};{}] out of bounds for [Matrix {}x{}]", row, col, self.domain, self.codomain)) } let idx = col + row * self.domain; self.values[idx] = val; Ok(()) } pub fn rows<'a>(&'a self) -> MatrixRows<'a> { MatrixRows { matrix: self, row: 0 } } pub fn cols<'a>(&'a self) -> MatrixCols<'a> { MatrixCols { matrix: self, col: 0 } } // SPLCIE DOMAIN pub fn splice_cols(&self, col_start: usize, col_end: usize) -> Result { if col_start <= col_end || col_end > self.domain { return Err(error!("[_;{}..{}] invalid for [Matrix {}x{}]", col_start, col_end, self.domain, self.codomain)) } let mut cols = Vec::new(); for (i, col) in self.cols().enumerate() { if i >= col_start && i < col_end { cols.push(col); } } let domain = cols.len(); let codomain = cols[0].len(); let mut res = Self::empty(domain, codomain); for i in 0..domain { for j in 0..codomain { res.set(j, i, cols[i][j].clone())?; } } Ok(res) } // SPLICE CODOMAIN pub fn splice_rows(&self, row_start: usize, row_end: usize) -> Result { if row_start <= row_end || row_end > self.codomain { return Err(error!("[{}..{};_] invalid for [Matrix {}x{}]", row_start, row_end, self.domain, self.codomain)) } let mut rows = Vec::new(); for (i, row) in self.rows().enumerate() { if i >= row_start && i < row_end { rows.push(row); } } let domain = rows[0].len(); let codomain = rows.len(); let mut res = Self::empty(domain, codomain); for i in 0..domain { for j in 0..codomain { res.set(j, i, rows[j][i].clone())?; } } Ok(res) } pub fn increment(&self, increment: Value) -> Result { let values = self.values.iter() .map(|v| v.clone() + increment.clone()) .collect::>>()?; Ok(Matrix::new(self.domain, self.codomain, values)) } pub fn decrement(&self, decrement: Value) -> Result { let values = self.values.iter() .map(|v| v.clone() - decrement.clone()) .collect::>>()?; Ok(Matrix::new(self.domain, self.codomain, values)) } pub fn scale(&self, scale: Value) -> Result { let values = self.values.iter() .map(|v| v.clone() * scale.clone()) .collect::>>()?; Ok(Matrix::new(self.domain, self.codomain, values)) } pub fn join_right(&self, other: &Matrix) -> Result { if self.codomain != other.codomain { return Err(error!("matrix codomain's do not match")) } let mut r1 = self.rows(); let mut r2 = other.rows(); let mut rows = Vec::new(); loop { let Some(r1) = r1.next() else { break; }; let Some(r2) = r2.next() else { break; }; let mut row = r1 .into_iter() .map(|v| v.clone()) .collect::>(); row.extend(r2.into_iter().map(|v| v.clone())); rows.push(row); } let values = rows .into_iter() .reduce(|mut a,b| {a.extend(b); a}) .ok_or(error!("matrix row smashed"))?; Ok(Matrix::new(self.domain + other.domain, self.codomain, values)) } pub fn join_bottom(&self, other: &Matrix) -> Result { if self.domain != other.domain { return Err(error!("matrix domain's do not match")) } let mut values = self.values.clone(); values.extend(other.values.clone()); Ok(Matrix::new(self.domain, self.codomain + other.codomain, values)) } } impl PartialEq for Matrix { fn eq(&self, other: &Self) -> bool { if self.domain != other.domain || self.codomain != other.codomain { return false } for i in 0..self.values.len() { if self.values[i] != other.values[i] { return false; } } return true; } } impl Add for Matrix { type Output = Result; fn add(self, rhs: Self) -> Self::Output { if self.domain != rhs.domain || self.codomain != rhs.codomain { return Err(error!("cannot add {self:?} + {rhs:?}")) } let mut res = Matrix::empty(self.domain, self.codomain); for col in 0..self.domain { for row in 0..self.codomain { let add = self.get(row, col)? + rhs.get(row, col)?; res.set(row, col, add?)?; } } Ok(res) } } impl Sub for Matrix { type Output = Result; fn sub(self, rhs: Self) -> Self::Output { if self.domain != rhs.domain || self.codomain != rhs.codomain { return Err(error!("cannot subtract {self:?} - {rhs:?}")) } let mut res = Matrix::empty(self.domain, self.codomain); for col in 0..self.domain { for row in 0..self.codomain { let sub = self.get(row, col)? - rhs.get(row, col)?; res.set(row, col, sub?)?; } } Ok(res) } } fn dot(lhs: Vec<&Value>, rhs: Vec<&Value>) -> Result { let len = lhs.len(); let mut res = Value::Int(0); for i in 0..len { let val = (lhs[i].clone() * rhs[i].clone())?; res = (res + val)?; } Ok(res) } impl Mul for Matrix { type Output = Result; fn mul(self, rhs: Self) -> Self::Output { if self.domain != rhs.codomain { return Err(error!("cannot multiply {self:?} * {rhs:?}")) } let mut res = Self::empty(rhs.domain, self.codomain); for (i, row) in self.rows().enumerate() { for (j, col) in rhs.cols().enumerate() { let dot = dot(row.clone(), col.clone())?; res.set(i, j, dot)?; } } Ok(res) } } pub struct MatrixRows<'a> { matrix: &'a Matrix, row: usize } impl<'a> Iterator for MatrixRows<'a> { type Item = Vec<&'a Value>; fn next(&mut self) -> Option { if self.row >= self.matrix.codomain { return None } let row_start = self.row * self.matrix.domain; let row_end = row_start + self.matrix.domain; let res = self.matrix.values .iter() .enumerate() .filter(|(idx, _)| *idx >= row_start && *idx < row_end) .map(|v| v.1) .collect(); self.row += 1; Some(res) } } pub struct MatrixCols<'a> { matrix: &'a Matrix, col: usize } impl<'a> Iterator for MatrixCols<'a> { type Item = Vec<&'a Value>; fn next(&mut self) -> Option { if self.col >= self.matrix.domain { return None } let res = self.matrix.values .iter() .enumerate() .filter(|(idx, _)| *idx % self.matrix.domain == self.col) .map(|v| v.1) .collect(); self.col += 1; Some(res) } }