summaryrefslogtreecommitdiff
path: root/matrix-lang/src/value/matrix.rs
diff options
context:
space:
mode:
Diffstat (limited to 'matrix-lang/src/value/matrix.rs')
-rw-r--r--matrix-lang/src/value/matrix.rs337
1 files changed, 337 insertions, 0 deletions
diff --git a/matrix-lang/src/value/matrix.rs b/matrix-lang/src/value/matrix.rs
new file mode 100644
index 0000000..91e3ec2
--- /dev/null
+++ b/matrix-lang/src/value/matrix.rs
@@ -0,0 +1,337 @@
+use std::ops::{Add, Sub, Mul};
+
+use crate::prelude::*;
+
+#[derive(Clone)]
+pub struct Matrix {
+ pub domain: usize,
+ pub codomain: usize,
+ pub values: Vec<Value>
+}
+
+macro_rules! error {
+ ($($arg:tt)*) => {
+ exception!(VALUE_EXCEPTION, $($arg)*)
+ };
+}
+
+impl Matrix {
+ pub fn new(
+ domain: usize,
+ codomain: usize,
+ values: Vec<Value>
+ ) -> Self {
+ Self {
+ domain,
+ codomain,
+ values
+ }
+ }
+
+ pub fn from_list(
+ values: Vec<Value>
+ ) -> Self {
+ Self {
+ domain: values.len(),
+ codomain: 1,
+ values
+ }
+ }
+
+ pub fn empty(
+ domain: usize,
+ codomain: usize
+ ) -> Self {
+ let values = (0..(domain * codomain)).into_iter()
+ .map(|_| Value::Int(0))
+ .collect();
+ Self {
+ domain,
+ codomain,
+ values
+ }
+ }
+
+ pub fn get(&self, row: usize, col: usize) -> Result<Value> {
+ if row >= self.codomain || col >= self.domain {
+ return Err(error!("[{};{}] out of bounds for [Matrix {}x{}]", row, col, self.domain, self.codomain))
+ }
+ let idx = col + row * self.domain;
+ Ok(self.values[idx].clone())
+ }
+
+
+
+ pub fn set(&mut self, row: usize, col: usize, val: Value) -> Result<()> {
+ if row >= self.codomain || col >= self.domain {
+ return Err(error!("[{};{}] out of bounds for [Matrix {}x{}]", row, col, self.domain, self.codomain))
+ }
+ let idx = col + row * self.domain;
+ self.values[idx] = val;
+ Ok(())
+ }
+
+ pub fn rows<'a>(&'a self) -> MatrixRows<'a> {
+ MatrixRows {
+ matrix: self,
+ row: 0
+ }
+ }
+
+ pub fn cols<'a>(&'a self) -> MatrixCols<'a> {
+ MatrixCols {
+ matrix: self,
+ col: 0
+ }
+ }
+
+ // SPLCIE DOMAIN
+ pub fn splice_cols(&self, col_start: usize, col_end: usize) -> Result<Self> {
+ if col_start <= col_end || col_end > self.domain {
+ return Err(error!("[_;{}..{}] invalid for [Matrix {}x{}]", col_start, col_end, self.domain, self.codomain))
+ }
+
+ let mut cols = Vec::new();
+
+ for (i, col) in self.cols().enumerate() {
+ if i >= col_start && i < col_end {
+ cols.push(col);
+ }
+ }
+
+ let domain = cols.len();
+ let codomain = cols[0].len();
+ let mut res = Self::empty(domain, codomain);
+
+ for i in 0..domain {
+ for j in 0..codomain {
+ res.set(j, i, cols[i][j].clone())?;
+ }
+ }
+
+ Ok(res)
+ }
+
+ // SPLICE CODOMAIN
+ pub fn splice_rows(&self, row_start: usize, row_end: usize) -> Result<Self> {
+ if row_start <= row_end || row_end > self.codomain {
+ return Err(error!("[{}..{};_] invalid for [Matrix {}x{}]", row_start, row_end, self.domain, self.codomain))
+ }
+
+ let mut rows = Vec::new();
+
+ for (i, row) in self.rows().enumerate() {
+ if i >= row_start && i < row_end {
+ rows.push(row);
+ }
+ }
+
+ let domain = rows[0].len();
+ let codomain = rows.len();
+ let mut res = Self::empty(domain, codomain);
+
+ for i in 0..domain {
+ for j in 0..codomain {
+ res.set(j, i, rows[j][i].clone())?;
+ }
+ }
+
+ Ok(res)
+ }
+
+ pub fn increment(&self, increment: Value) -> Result<Self> {
+ let values = self.values.iter()
+ .map(|v| v.clone() + increment.clone())
+ .collect::<Result<Vec<Value>>>()?;
+ Ok(Matrix::new(self.domain, self.codomain, values))
+ }
+
+ pub fn decrement(&self, decrement: Value) -> Result<Self> {
+ let values = self.values.iter()
+ .map(|v| v.clone() - decrement.clone())
+ .collect::<Result<Vec<Value>>>()?;
+ Ok(Matrix::new(self.domain, self.codomain, values))
+ }
+
+ pub fn scale(&self, scale: Value) -> Result<Self> {
+ let values = self.values.iter()
+ .map(|v| v.clone() * scale.clone())
+ .collect::<Result<Vec<Value>>>()?;
+ Ok(Matrix::new(self.domain, self.codomain, values))
+ }
+
+ pub fn join_right(&self, other: &Matrix) -> Result<Self> {
+ if self.codomain != other.codomain {
+ return Err(error!("matrix codomain's do not match"))
+ }
+ let mut r1 = self.rows();
+ let mut r2 = other.rows();
+ let mut rows = Vec::new();
+ loop {
+ let Some(r1) = r1.next() else { break; };
+ let Some(r2) = r2.next() else { break; };
+
+ let mut row = r1
+ .into_iter()
+ .map(|v| v.clone())
+ .collect::<Vec<Value>>();
+
+ row.extend(r2.into_iter().map(|v| v.clone()));
+
+ rows.push(row);
+ }
+
+ let values = rows
+ .into_iter()
+ .reduce(|mut a,b| {a.extend(b); a})
+ .unwrap();
+
+ Ok(Matrix::new(self.domain + other.domain, self.codomain, values))
+ }
+
+ pub fn join_bottom(&self, other: &Matrix) -> Result<Self> {
+ if self.domain != other.domain {
+ return Err(error!("matrix domain's do not match"))
+ }
+ let mut values = self.values.clone();
+ values.extend(other.values.clone());
+ Ok(Matrix::new(self.domain, self.codomain + other.codomain, values))
+ }
+}
+
+impl PartialEq for Matrix {
+ fn eq(&self, other: &Self) -> bool {
+ if self.domain != other.domain || self.codomain != other.codomain {
+ return false
+ }
+
+ for i in 0..self.values.len() {
+ if self.values[i] != other.values[i] {
+ return false;
+ }
+ }
+
+ return true;
+ }
+}
+
+impl Add for Matrix {
+ type Output = Result<Self>;
+
+ fn add(self, rhs: Self) -> Self::Output {
+ if self.domain != rhs.domain || self.codomain != rhs.codomain {
+ return Err(error!("cannot add {self:?} + {rhs:?}"))
+ }
+ let mut res = Matrix::empty(self.domain, self.codomain);
+ for col in 0..self.domain {
+ for row in 0..self.codomain {
+ let add = self.get(row, col)? + rhs.get(row, col)?;
+ res.set(row, col, add?)?;
+ }
+ }
+ Ok(res)
+ }
+}
+
+impl Sub for Matrix {
+ type Output = Result<Self>;
+
+ fn sub(self, rhs: Self) -> Self::Output {
+ if self.domain != rhs.domain || self.codomain != rhs.codomain {
+ return Err(error!("cannot subtract {self:?} - {rhs:?}"))
+ }
+ let mut res = Matrix::empty(self.domain, self.codomain);
+ for col in 0..self.domain {
+ for row in 0..self.codomain {
+ let sub = self.get(row, col)? - rhs.get(row, col)?;
+ res.set(row, col, sub?)?;
+ }
+ }
+ Ok(res)
+ }
+}
+
+fn dot(lhs: Vec<&Value>, rhs: Vec<&Value>) -> Result<Value> {
+ let len = lhs.len();
+
+ let mut res = Value::Int(0);
+ for i in 0..len {
+ let val = (lhs[i].clone() * rhs[i].clone())?;
+ res = (res + val)?;
+ }
+
+ Ok(res)
+}
+
+impl Mul for Matrix {
+ type Output = Result<Self>;
+
+ fn mul(self, rhs: Self) -> Self::Output {
+ if self.domain != rhs.codomain {
+ return Err(error!("cannot multiply {self:?} * {rhs:?}"))
+ }
+ let mut res = Self::empty(rhs.domain, self.codomain);
+ for (i, row) in self.rows().enumerate() {
+ for (j, col) in rhs.cols().enumerate() {
+ let dot = dot(row.clone(), col.clone())?;
+ res.set(i, j, dot)?;
+ }
+ }
+ Ok(res)
+ }
+}
+
+pub struct MatrixRows<'a> {
+ matrix: &'a Matrix,
+ row: usize
+}
+
+impl<'a> Iterator for MatrixRows<'a> {
+ type Item = Vec<&'a Value>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.row >= self.matrix.codomain {
+ return None
+ }
+
+ let row_start = self.row * self.matrix.domain;
+ let row_end = row_start + self.matrix.domain;
+
+ let res = self.matrix.values
+ .iter()
+ .enumerate()
+ .filter(|(idx, _)| *idx >= row_start && *idx < row_end)
+ .map(|v| v.1)
+ .collect();
+
+ self.row += 1;
+
+ Some(res)
+ }
+}
+
+pub struct MatrixCols<'a> {
+ matrix: &'a Matrix,
+ col: usize
+}
+
+impl<'a> Iterator for MatrixCols<'a> {
+ type Item = Vec<&'a Value>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.col >= self.matrix.domain {
+ return None
+ }
+
+ let res = self.matrix.values
+ .iter()
+ .enumerate()
+ .filter(|(idx, _)| *idx % self.matrix.domain == self.col)
+ .map(|v| v.1)
+ .collect();
+
+ self.col += 1;
+
+ Some(res)
+ }
+}