diff options
Diffstat (limited to 'matrix-lang/src/value/matrix.rs')
-rw-r--r-- | matrix-lang/src/value/matrix.rs | 337 |
1 files changed, 337 insertions, 0 deletions
diff --git a/matrix-lang/src/value/matrix.rs b/matrix-lang/src/value/matrix.rs new file mode 100644 index 0000000..91e3ec2 --- /dev/null +++ b/matrix-lang/src/value/matrix.rs @@ -0,0 +1,337 @@ +use std::ops::{Add, Sub, Mul}; + +use crate::prelude::*; + +#[derive(Clone)] +pub struct Matrix { + pub domain: usize, + pub codomain: usize, + pub values: Vec<Value> +} + +macro_rules! error { + ($($arg:tt)*) => { + exception!(VALUE_EXCEPTION, $($arg)*) + }; +} + +impl Matrix { + pub fn new( + domain: usize, + codomain: usize, + values: Vec<Value> + ) -> Self { + Self { + domain, + codomain, + values + } + } + + pub fn from_list( + values: Vec<Value> + ) -> Self { + Self { + domain: values.len(), + codomain: 1, + values + } + } + + pub fn empty( + domain: usize, + codomain: usize + ) -> Self { + let values = (0..(domain * codomain)).into_iter() + .map(|_| Value::Int(0)) + .collect(); + Self { + domain, + codomain, + values + } + } + + pub fn get(&self, row: usize, col: usize) -> Result<Value> { + if row >= self.codomain || col >= self.domain { + return Err(error!("[{};{}] out of bounds for [Matrix {}x{}]", row, col, self.domain, self.codomain)) + } + let idx = col + row * self.domain; + Ok(self.values[idx].clone()) + } + + + + pub fn set(&mut self, row: usize, col: usize, val: Value) -> Result<()> { + if row >= self.codomain || col >= self.domain { + return Err(error!("[{};{}] out of bounds for [Matrix {}x{}]", row, col, self.domain, self.codomain)) + } + let idx = col + row * self.domain; + self.values[idx] = val; + Ok(()) + } + + pub fn rows<'a>(&'a self) -> MatrixRows<'a> { + MatrixRows { + matrix: self, + row: 0 + } + } + + pub fn cols<'a>(&'a self) -> MatrixCols<'a> { + MatrixCols { + matrix: self, + col: 0 + } + } + + // SPLCIE DOMAIN + pub fn splice_cols(&self, col_start: usize, col_end: usize) -> Result<Self> { + if col_start <= col_end || col_end > self.domain { + return Err(error!("[_;{}..{}] invalid for [Matrix {}x{}]", col_start, col_end, self.domain, self.codomain)) + } + + let mut cols = Vec::new(); + + for (i, col) in self.cols().enumerate() { + if i >= col_start && i < col_end { + cols.push(col); + } + } + + let domain = cols.len(); + let codomain = cols[0].len(); + let mut res = Self::empty(domain, codomain); + + for i in 0..domain { + for j in 0..codomain { + res.set(j, i, cols[i][j].clone())?; + } + } + + Ok(res) + } + + // SPLICE CODOMAIN + pub fn splice_rows(&self, row_start: usize, row_end: usize) -> Result<Self> { + if row_start <= row_end || row_end > self.codomain { + return Err(error!("[{}..{};_] invalid for [Matrix {}x{}]", row_start, row_end, self.domain, self.codomain)) + } + + let mut rows = Vec::new(); + + for (i, row) in self.rows().enumerate() { + if i >= row_start && i < row_end { + rows.push(row); + } + } + + let domain = rows[0].len(); + let codomain = rows.len(); + let mut res = Self::empty(domain, codomain); + + for i in 0..domain { + for j in 0..codomain { + res.set(j, i, rows[j][i].clone())?; + } + } + + Ok(res) + } + + pub fn increment(&self, increment: Value) -> Result<Self> { + let values = self.values.iter() + .map(|v| v.clone() + increment.clone()) + .collect::<Result<Vec<Value>>>()?; + Ok(Matrix::new(self.domain, self.codomain, values)) + } + + pub fn decrement(&self, decrement: Value) -> Result<Self> { + let values = self.values.iter() + .map(|v| v.clone() - decrement.clone()) + .collect::<Result<Vec<Value>>>()?; + Ok(Matrix::new(self.domain, self.codomain, values)) + } + + pub fn scale(&self, scale: Value) -> Result<Self> { + let values = self.values.iter() + .map(|v| v.clone() * scale.clone()) + .collect::<Result<Vec<Value>>>()?; + Ok(Matrix::new(self.domain, self.codomain, values)) + } + + pub fn join_right(&self, other: &Matrix) -> Result<Self> { + if self.codomain != other.codomain { + return Err(error!("matrix codomain's do not match")) + } + let mut r1 = self.rows(); + let mut r2 = other.rows(); + let mut rows = Vec::new(); + loop { + let Some(r1) = r1.next() else { break; }; + let Some(r2) = r2.next() else { break; }; + + let mut row = r1 + .into_iter() + .map(|v| v.clone()) + .collect::<Vec<Value>>(); + + row.extend(r2.into_iter().map(|v| v.clone())); + + rows.push(row); + } + + let values = rows + .into_iter() + .reduce(|mut a,b| {a.extend(b); a}) + .unwrap(); + + Ok(Matrix::new(self.domain + other.domain, self.codomain, values)) + } + + pub fn join_bottom(&self, other: &Matrix) -> Result<Self> { + if self.domain != other.domain { + return Err(error!("matrix domain's do not match")) + } + let mut values = self.values.clone(); + values.extend(other.values.clone()); + Ok(Matrix::new(self.domain, self.codomain + other.codomain, values)) + } +} + +impl PartialEq for Matrix { + fn eq(&self, other: &Self) -> bool { + if self.domain != other.domain || self.codomain != other.codomain { + return false + } + + for i in 0..self.values.len() { + if self.values[i] != other.values[i] { + return false; + } + } + + return true; + } +} + +impl Add for Matrix { + type Output = Result<Self>; + + fn add(self, rhs: Self) -> Self::Output { + if self.domain != rhs.domain || self.codomain != rhs.codomain { + return Err(error!("cannot add {self:?} + {rhs:?}")) + } + let mut res = Matrix::empty(self.domain, self.codomain); + for col in 0..self.domain { + for row in 0..self.codomain { + let add = self.get(row, col)? + rhs.get(row, col)?; + res.set(row, col, add?)?; + } + } + Ok(res) + } +} + +impl Sub for Matrix { + type Output = Result<Self>; + + fn sub(self, rhs: Self) -> Self::Output { + if self.domain != rhs.domain || self.codomain != rhs.codomain { + return Err(error!("cannot subtract {self:?} - {rhs:?}")) + } + let mut res = Matrix::empty(self.domain, self.codomain); + for col in 0..self.domain { + for row in 0..self.codomain { + let sub = self.get(row, col)? - rhs.get(row, col)?; + res.set(row, col, sub?)?; + } + } + Ok(res) + } +} + +fn dot(lhs: Vec<&Value>, rhs: Vec<&Value>) -> Result<Value> { + let len = lhs.len(); + + let mut res = Value::Int(0); + for i in 0..len { + let val = (lhs[i].clone() * rhs[i].clone())?; + res = (res + val)?; + } + + Ok(res) +} + +impl Mul for Matrix { + type Output = Result<Self>; + + fn mul(self, rhs: Self) -> Self::Output { + if self.domain != rhs.codomain { + return Err(error!("cannot multiply {self:?} * {rhs:?}")) + } + let mut res = Self::empty(rhs.domain, self.codomain); + for (i, row) in self.rows().enumerate() { + for (j, col) in rhs.cols().enumerate() { + let dot = dot(row.clone(), col.clone())?; + res.set(i, j, dot)?; + } + } + Ok(res) + } +} + +pub struct MatrixRows<'a> { + matrix: &'a Matrix, + row: usize +} + +impl<'a> Iterator for MatrixRows<'a> { + type Item = Vec<&'a Value>; + + fn next(&mut self) -> Option<Self::Item> { + if self.row >= self.matrix.codomain { + return None + } + + let row_start = self.row * self.matrix.domain; + let row_end = row_start + self.matrix.domain; + + let res = self.matrix.values + .iter() + .enumerate() + .filter(|(idx, _)| *idx >= row_start && *idx < row_end) + .map(|v| v.1) + .collect(); + + self.row += 1; + + Some(res) + } +} + +pub struct MatrixCols<'a> { + matrix: &'a Matrix, + col: usize +} + +impl<'a> Iterator for MatrixCols<'a> { + type Item = Vec<&'a Value>; + + fn next(&mut self) -> Option<Self::Item> { + if self.col >= self.matrix.domain { + return None + } + + let res = self.matrix.values + .iter() + .enumerate() + .filter(|(idx, _)| *idx % self.matrix.domain == self.col) + .map(|v| v.1) + .collect(); + + self.col += 1; + + Some(res) + } +} |