summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock5
-rw-r--r--matrix-stdlib/Cargo.toml1
-rw-r--r--matrix-stdlib/src/lib.rs4
-rw-r--r--matrix/src/compiler.rs3
-rw-r--r--matrix/src/gc.rs38
-rw-r--r--matrix/src/lib.rs3
-rw-r--r--matrix/src/parse.rs2
-rw-r--r--matrix/src/value.rs403
-rw-r--r--matrix/src/vm.rs12
9 files changed, 429 insertions, 42 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 6b4762f..a3a118e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -61,9 +61,9 @@ dependencies = [
[[package]]
name = "anyhow"
-version = "1.0.79"
+version = "1.0.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
+checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
[[package]]
name = "autocfg"
@@ -248,6 +248,7 @@ dependencies = [
name = "matrix-stdlib"
version = "0.1.0"
dependencies = [
+ "anyhow",
"matrix",
"matrix-macros",
]
diff --git a/matrix-stdlib/Cargo.toml b/matrix-stdlib/Cargo.toml
index a892c04..1c6b0ac 100644
--- a/matrix-stdlib/Cargo.toml
+++ b/matrix-stdlib/Cargo.toml
@@ -6,5 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+anyhow = "1"
matrix = { path = "../matrix" }
matrix-macros = { path = "../matrix-macros" }
diff --git a/matrix-stdlib/src/lib.rs b/matrix-stdlib/src/lib.rs
index 312d397..6e0cfc1 100644
--- a/matrix-stdlib/src/lib.rs
+++ b/matrix-stdlib/src/lib.rs
@@ -2,6 +2,6 @@ use matrix::vm::Vm;
mod io;
-pub fn load(compiler: &mut Vm) {
- io::load(compiler);
+pub fn load(vm: &mut Vm) {
+ io::load(vm);
}
diff --git a/matrix/src/compiler.rs b/matrix/src/compiler.rs
index f573437..a516807 100644
--- a/matrix/src/compiler.rs
+++ b/matrix/src/compiler.rs
@@ -221,6 +221,7 @@ impl<'c> Compiler<'c> {
}
fn get_name(&mut self, name: Rc<str>) -> usize {
+ // TODO: find name if already exists
let idx = self.names.borrow().len();
self.names.borrow_mut().push(name);
idx
@@ -463,7 +464,7 @@ impl<'c> Compiler<'c> {
for expr in &mat.2 {
self.compile_expr(expr)?;
}
- self.emit(I::NewMatrix(mat.2.len() as u16, mat.1 as u8));
+ self.emit(I::NewMatrix(mat.2.len() as u16, mat.0 as u8));
},
E::Table(table) => {
for (key, value) in table {
diff --git a/matrix/src/gc.rs b/matrix/src/gc.rs
index 0f31aea..0c60849 100644
--- a/matrix/src/gc.rs
+++ b/matrix/src/gc.rs
@@ -1,4 +1,4 @@
-use std::{ops::{Index, IndexMut, Deref, DerefMut}, marker::PhantomData, ptr::NonNull, fmt::{Debug, Display}};
+use std::{ops::{Index, IndexMut, Deref, DerefMut, Add, Sub, Mul}, marker::PhantomData, ptr::NonNull, fmt::{Debug, Display}};
pub struct Gc<T> {
ptr: NonNull<GcInner<T>>,
@@ -30,6 +30,12 @@ impl <T: Clone> Gc<T> {
Self::new(data)
}
}
+
+ fn data(&self) -> T {
+ unsafe {
+ self.ptr.as_ref().data.clone()
+ }
+ }
}
impl<T> From<T> for Gc<T> {
@@ -124,3 +130,33 @@ impl<T: Display> Display for Gc<T> {
write!(f, "{}", self.deref())
}
}
+
+impl <T: Add + Clone> Add for Gc<T> {
+ type Output = T::Output;
+
+ fn add(self, rhs: Self) -> Self::Output {
+ let a = self.data();
+ let b = rhs.data();
+ a + b
+ }
+}
+
+impl <T: Sub + Clone> Sub for Gc<T> {
+ type Output = T::Output;
+
+ fn sub(self, rhs: Self) -> Self::Output {
+ let a = self.data();
+ let b = rhs.data();
+ a - b
+ }
+}
+
+impl <T: Mul + Clone> Mul for Gc<T> {
+ type Output = T::Output;
+
+ fn mul(self, rhs: Self) -> Self::Output {
+ let a = self.data();
+ let b = rhs.data();
+ a * b
+ }
+}
diff --git a/matrix/src/lib.rs b/matrix/src/lib.rs
index 3c4732b..fbc1aac 100644
--- a/matrix/src/lib.rs
+++ b/matrix/src/lib.rs
@@ -21,6 +21,7 @@ enum ErrorInner {
Value(value::Error),
Compile(compiler::Error),
Runtime(vm::Error),
+ External(anyhow::Error),
}
impl Display for crate::Error {
@@ -32,6 +33,7 @@ impl Display for crate::Error {
Value(err) => write!(f, "{err}"),
Compile(err) => write!(f, "{err}"),
Runtime(err) => write!(f, "{err}"),
+ External(err) => write!(f, "{err}"),
}
}
}
@@ -51,5 +53,6 @@ from_error!(parse::Error, Parse);
from_error!(value::Error, Value);
from_error!(compiler::Error, Compile);
from_error!(vm::Error, Runtime);
+from_error!(anyhow::Error, External);
pub type Result<T> = std::result::Result<T, crate::Error>;
diff --git a/matrix/src/parse.rs b/matrix/src/parse.rs
index b505415..c6c503b 100644
--- a/matrix/src/parse.rs
+++ b/matrix/src/parse.rs
@@ -202,7 +202,7 @@ impl Parser {
let codomain = parts.len();
let domain = parts[0].len();
for (i, part) in parts.iter().enumerate() {
- if part.len() != codomain {
+ if part.len() != domain {
return Err(Error::MatrixInvDomain(i, domain, part.len()).into())
}
}
diff --git a/matrix/src/value.rs b/matrix/src/value.rs
index 1e75902..2c86226 100644
--- a/matrix/src/value.rs
+++ b/matrix/src/value.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, rc::Rc, hash::Hash, fmt::Display, ops::{Add, Neg, Not, Sub, Div, Mul, BitOr, BitAnd, BitXor, Shl, Shr}, cmp::Ordering};
+use std::{collections::HashMap, rc::Rc, hash::Hash, fmt::{Display, Debug}, ops::{Add, Neg, Not, Sub, Div, Mul, BitOr, BitAnd, BitXor, Shl, Shr}, cmp::Ordering};
use num_complex::Complex64;
use num_rational::Rational64;
@@ -7,7 +7,6 @@ use regex::Regex;
use crate::{ast::{Expr, BinaryOp, UnaryOp}, chunk::Function, Result, gc::Gc};
pub type List = Vec<Value>;
-pub type Matrix = (usize, usize, Vec<Value>);
pub type Table = ValueMap;
pub type InlineList = Vec<Expr>;
@@ -42,7 +41,12 @@ pub enum Error {
Modulo(Value, Value),
Exponent(Value, Value),
Compare(Value, Value),
- IndexOutOfBounds(usize, usize),
+ IndexOutOfBounds(i64, usize),
+ MatrixOutOfBounds(usize, usize, usize, usize),
+ MatrixIndexOutOfBounds(i64, usize, usize),
+ MatrixInvSlice(usize, usize),
+ MatrixMultiply(usize, usize),
+ MatrixAdd(usize, usize, usize, usize),
CannotIndex(Value),
BadIndex(Value, Value),
Concat(Value, Value),
@@ -66,6 +70,11 @@ impl Display for self::Error {
Exponent(a, b) => write!(f, "cannot compute {a:?} pow {b:?}"),
Compare(a, b) => write!(f, "cannot compare {a:?} and {b:?}"),
IndexOutOfBounds(a, b) => write!(f, "index {a} out of bounds for list of length {b}"),
+ MatrixOutOfBounds(a, b, c, d) => write!(f, "row col ({a},{b}) is out of bounds for matrix {c}x{d}"),
+ MatrixIndexOutOfBounds(i, d, c) => write!(f, "row index {i} is out of bounds for matrix {d}x{c}"),
+ MatrixMultiply(c, d) => write!(f, "cannot multiply a matrix with the domain of {d} and matrix with the codomain of {c}"),
+ MatrixAdd(a, b, c, d) => write!(f, "cannot add a {a}x{b} matrix and a {c}x{d} matrix"),
+ MatrixInvSlice(s, e) => write!(f, "invalid matrix slice {s}..{e}"),
BadIndex(a, b) => write!(f, "cannot index {a:?} with {b:?}"),
CannotIndex(a) => write!(f, "cannot index {a:?}"),
Concat(a, b) => write!(f, "cannot concat {a:?} and {b:?}"),
@@ -112,9 +121,9 @@ impl Hash for Value {
}
}
Matrix(m) => {
- m.0.hash(state);
- m.1.hash(state);
- for val in m.2.iter() {
+ m.domain.hash(state);
+ m.codomain.hash(state);
+ for val in m.values.iter() {
val.hash(state);
}
},
@@ -139,7 +148,7 @@ impl Value {
}
}
Matrix(m) => {
- for val in m.2.iter() {
+ for val in m.values.iter() {
val.can_hash()?;
}
},
@@ -172,9 +181,9 @@ impl Value {
},
Matrix(m) => {
let mut str = "\n".to_string();
- for row in m.2.chunks(m.0) {
+ for row in m.rows() {
str.push_str(" ");
- for (i, v) in row.iter().enumerate() {
+ for (i, v) in row.into_iter().enumerate() {
if i != 0 {
str.push(' ');
}
@@ -227,9 +236,9 @@ impl Value {
},
Matrix(m) => {
let mut str = "\n".to_string();
- for row in m.2.chunks(m.0) {
+ for row in m.rows() {
str.push_str(" ");
- for (i, v) in row.iter().enumerate() {
+ for (i, v) in row.into_iter().enumerate() {
if i != 0 {
str.push(' ');
}
@@ -270,20 +279,23 @@ fn ratio_to_f64(r: Rational64) -> f64 {
}
fn promote(a: Value, b: Value) -> (Value, Value) {
- use Value::*;
+ use Value as V;
match (&a, &b) {
- (Int(x), Ratio(..)) => (Ratio((*x).into()), b),
- (Int(x), Float(..)) => (Float(*x as f64), b),
- (Int(x), Complex(..)) => (Complex((*x as f64).into()), b),
- (Ratio(x), Float(..)) => (Float(ratio_to_f64(*x)), b),
- (Ratio(x), Complex(..)) => (Complex(ratio_to_f64(*x).into()), b),
- (Float(x), Complex(..)) => (Complex((*x).into()), b),
- (Ratio(..), Int(y)) => (a, Ratio((*y).into())),
- (Float(..), Int(y)) => (a, Float(*y as f64)),
- (Complex(..), Int(y)) => (a, Complex((*y as f64).into())),
- (Float(..), Ratio(y)) => (a, Float(ratio_to_f64(*y))),
- (Complex(..), Ratio(y)) => (a, Complex(ratio_to_f64(*y).into())),
- (Complex(..), Float(y)) => (a, Complex((*y).into())),
+ (V::Int(x), V::Ratio(..)) => (V::Ratio((*x).into()), b),
+ (V::Int(x), V::Float(..)) => (V::Float(*x as f64), b),
+ (V::Int(x), V::Complex(..)) => (V::Complex((*x as f64).into()), b),
+ (V::Ratio(x), V::Float(..)) => (V::Float(ratio_to_f64(*x)), b),
+ (V::Ratio(x), V::Complex(..)) => (V::Complex(ratio_to_f64(*x).into()), b),
+ (V::Float(x), V::Complex(..)) => (V::Complex((*x).into()), b),
+ (V::Ratio(..), V::Int(y)) => (a, V::Ratio((*y).into())),
+ (V::Float(..), V::Int(y)) => (a, V::Float(*y as f64)),
+ (V::Complex(..), V::Int(y)) => (a, V::Complex((*y as f64).into())),
+ (V::Float(..), V::Ratio(y)) => (a, V::Float(ratio_to_f64(*y))),
+ (V::Complex(..), V::Ratio(y)) => (a, V::Complex(ratio_to_f64(*y).into())),
+ (V::Complex(..), V::Float(y)) => (a, V::Complex((*y).into())),
+ (V::List(l1), V::List(l2)) => (V::Matrix(Matrix::from_list(l1.to_vec()).into()), V::Matrix(Matrix::from_list(l2.to_vec()).into())),
+ (_, V::List(l)) => (a, V::Matrix(Matrix::from_list(l.to_vec()).into())),
+ (V::List(l), _) => (V::Matrix(Matrix::from_list(l.to_vec()).into()), b),
_ => (a, b),
}
}
@@ -297,6 +309,7 @@ impl Add for Value {
(Float(x), Float(y)) => Ok(Float(x + y)),
(Ratio(x), Ratio(y)) => Ok(Ratio(x + y)),
(Complex(x), Complex(y)) => Ok(Complex(x + y)),
+ (Matrix(x), Matrix(y)) => Ok(Matrix((x + y)?.into())),
(String(str), value) => Ok(String(Rc::new(
format!("{str}{}", value.boring_print())
))),
@@ -321,6 +334,7 @@ impl Sub for Value {
(Float(x), Float(y)) => Ok(Float(x - y)),
(Ratio(x), Ratio(y)) => Ok(Ratio(x - y)),
(Complex(x), Complex(y)) => Ok(Complex(x - y)),
+ (Matrix(x), Matrix(y)) => Ok(Matrix((x - y)?.into())),
(l, r) => Err(Error::Subtract(l, r).into())
}
}
@@ -335,6 +349,9 @@ impl Mul for Value {
(Float(x), Float(y)) => Ok(Float(x * y)),
(Ratio(x), Ratio(y)) => Ok(Ratio(x * y)),
(Complex(x), Complex(y)) => Ok(Complex(x * y)),
+ (Matrix(x), Matrix(y)) => Ok(Matrix((x * y)?.into())),
+ (Matrix(x), r) => Ok(Matrix(x.scale(r)?.into())),
+ (l, Matrix(y)) => Ok(Matrix(y.scale(l)?.into())),
(l, r) => Err(Error::Multiply(l, r).into())
}
}
@@ -485,7 +502,7 @@ impl PartialEq for Value {
(Complex(a), Float(b)) => *a == Complex64::from(*b),
(String(a), String(b)) => *a == *b,
(List(a), List(b)) => *a == *b,
- (Matrix(a), Matrix(b)) => a.0 == b.0 && *a.2 == *b.2,
+ (Matrix(a), Matrix(b)) => a == b,
_ => false,
}
}
@@ -508,7 +525,7 @@ impl PartialOrd for Value {
(Float(a), Ratio(b)) => a.partial_cmp(&ratio_to_f64(*b)),
(String(a), String(b)) => a.partial_cmp(b),
(List(a), List(b)) => a.partial_cmp(b),
- (Matrix(a), Matrix(b)) => a.2.partial_cmp(&b.2),
+ (Matrix(a), Matrix(b)) => a.domain.partial_cmp(&b.domain),
_ => None,
}
}
@@ -528,10 +545,16 @@ impl Value {
},
(V::List(l), V::Int(i)) => {
if *i < 0 || *i as usize >= l.len() {
- return Err(Error::IndexOutOfBounds(*i as usize, l.len()).into())
+ return Err(Error::IndexOutOfBounds(*i, l.len()).into())
}
Ok(l[*i as usize].clone())
},
+ (V::Matrix(m), V::Int(i)) => {
+ if *i < 0 || *i as usize >= m.values.len() {
+ return Err(Error::MatrixIndexOutOfBounds(*i, m.domain, m.codomain).into())
+ }
+ Ok(m.values[*i as usize].clone())
+ },
_ => return Err(self::Error::BadIndex(self.clone(), index.clone()).into())
}
}
@@ -555,6 +578,40 @@ impl Value {
}
Ok(V::Table(ret.into()))
}
+ V::Matrix(m) => {
+ let err = || Err(self::Error::BadIndex(self.clone(), V::List(indexes.clone().into())).into());
+ if indexes.len() != 2 {
+ return err()
+ }
+ let lhs = indexes[0].clone();
+ let rhs = indexes[1].clone();
+ match (lhs, rhs) {
+ (V::Nil, V::Nil) => {
+ Ok(V::Matrix(m.clone_inside()))
+ },
+ (V::Int(row), V::Nil) => {
+ let Some((_, row)) = m.rows().enumerate().filter(|(idx, _)| *idx as i64 == row).next() else {
+ return err();
+ };
+ let row: Vec<Value> = row.into_iter().map(|e| e.clone()).collect();
+ Ok(V::Matrix(Matrix::new(row.len(), 1, row).into()))
+ },
+ (V::Nil, V::Int(col)) => {
+ let Some((_, col)) = m.cols().enumerate().filter(|(idx, _)| *idx as i64 == col).next() else {
+ return err();
+ };
+ let col: Vec<Value> = col.into_iter().map(|e| e.clone()).collect();
+ Ok(V::Matrix(Matrix::new(1, col.len(), col).into()))
+ },
+ (V::Int(row), V::Int(col)) => {
+ if row < 0 || col < 0 {
+ return err();
+ }
+ m.get(row as usize, col as usize)
+ }
+ _ => return err()
+ }
+ }
_ => return Err(self::Error::CannotIndex(self.clone()).into())
}
}
@@ -578,11 +635,18 @@ impl Value {
},
(V::List(l), V::Int(i)) => {
if *i < 0 || *i as usize >= l.len() {
- return Err(Error::IndexOutOfBounds(*i as usize, l.len()).into())
+ return Err(Error::IndexOutOfBounds(*i, l.len()).into())
}
l[*i as usize] = store;
Ok(())
- }
+ },
+ (V::Matrix(m), V::Int(i)) => {
+ if *i < 0 || *i as usize >= m.values.len() {
+ return Err(Error::MatrixIndexOutOfBounds(*i, m.domain, m.codomain).into())
+ }
+ m.values[*i as usize] = store;
+ Ok(())
+ },
_ => return Err(self::Error::BadIndex(err, index.clone()).into())
}
}
@@ -708,3 +772,284 @@ impl Not for Value {
}
}
+fn dot(lhs: Vec<&Value>, rhs: Vec<&Value>) -> Result<Value> {
+ let len = lhs.len();
+
+ let mut res = Value::Int(0);
+ for i in 0..len {
+ let val = (lhs[i].clone() * rhs[i].clone())?;
+ res = (res + val)?;
+ }
+
+ Ok(res)
+}
+
+#[derive(Clone)]
+pub struct Matrix {
+ pub domain: usize,
+ pub codomain: usize,
+ pub values: Vec<Value>
+}
+
+impl Matrix {
+ pub fn new(
+ domain: usize,
+ codomain: usize,
+ values: Vec<Value>
+ ) -> Self {
+ Self {
+ domain,
+ codomain,
+ values
+ }
+ }
+
+ pub fn from_list(
+ values: Vec<Value>
+ ) -> Self {
+ Self {
+ domain: values.len(),
+ codomain: 1,
+ values
+ }
+ }
+
+ pub fn empty(
+ domain: usize,
+ codomain: usize
+ ) -> Self {
+ let values = (0..(domain * codomain)).into_iter()
+ .map(|_| Value::Int(0))
+ .collect();
+ Self {
+ domain,
+ codomain,
+ values
+ }
+ }
+
+ pub fn get(&self, row: usize, col: usize) -> Result<Value> {
+ if row >= self.codomain || col >= self.domain {
+ return Err(self::Error::MatrixOutOfBounds(row, col, self.domain, self.codomain).into())
+ }
+ let idx = col + row * self.codomain;
+ Ok(self.values[idx].clone())
+ }
+
+
+
+ pub fn set(&mut self, row: usize, col: usize, val: Value) -> Result<()> {
+ if row >= self.codomain || col >= self.domain {
+ return Err(self::Error::MatrixOutOfBounds(row, col, self.domain, self.codomain).into())
+ }
+ let idx = col + row * self.codomain;
+ self.values[idx] = val;
+ Ok(())
+ }
+
+ pub fn rows<'a>(&'a self) -> MatrixRows<'a> {
+ MatrixRows {
+ matrix: self,
+ row: 0
+ }
+ }
+
+ pub fn cols<'a>(&'a self) -> MatrixCols<'a> {
+ MatrixCols {
+ matrix: self,
+ col: 0
+ }
+ }
+
+ // SPLCIE DOMAIN
+ pub fn splice_cols(&self, col_start: usize, col_end: usize) -> Result<Self> {
+ if col_start <= col_end {
+ return Err(self::Error::MatrixInvSlice(col_start, col_end).into());
+ }
+
+ let mut cols = Vec::new();
+
+ for (i, col) in self.cols().enumerate() {
+ if i >= col_start && i < col_end {
+ cols.push(col);
+ }
+ }
+
+ let domain = cols.len();
+ let codomain = cols[0].len();
+ let mut res = Self::empty(domain, codomain);
+
+ for i in 0..domain {
+ for j in 0..codomain {
+ res.set(j, i, cols[i][j].clone())?;
+ }
+ }
+
+ Ok(res)
+ }
+
+ // SPLICE CODOMAIN
+ pub fn splice_rows(&self, row_start: usize, row_end: usize) -> Result<Self> {
+ if row_start <= row_end {
+ return Err(self::Error::MatrixInvSlice(row_start, row_end).into());
+ }
+
+ let mut rows = Vec::new();
+
+ for (i, row) in self.rows().enumerate() {
+ if i >= row_start && i < row_end {
+ rows.push(row);
+ }
+ }
+
+ let domain = rows[0].len();
+ let codomain = rows.len();
+ let mut res = Self::empty(domain, codomain);
+
+ for i in 0..domain {
+ for j in 0..codomain {
+ res.set(j, i, rows[j][i].clone())?;
+ }
+ }
+
+ Ok(res)
+ }
+
+ pub fn scale(&self, scale: Value) -> Result<Self> {
+ let values = self.values.iter()
+ .map(|v| v.clone() * scale.clone())
+ .collect::<Result<Vec<Value>>>()?;
+ Ok(Matrix::new(self.domain, self.codomain, values))
+ }
+}
+
+impl Debug for Matrix {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "[Matrix {}x{}]", self.domain, self.codomain)
+ }
+}
+
+impl PartialEq for Matrix {
+ fn eq(&self, other: &Self) -> bool {
+ if self.domain != other.domain || self.codomain != other.codomain {
+ return false
+ }
+
+ for i in 0..self.values.len() {
+ if self.values[i] != other.values[i] {
+ return false;
+ }
+ }
+
+ return true;
+ }
+}
+
+impl Add for Matrix {
+ type Output = Result<Self>;
+
+ fn add(self, rhs: Self) -> Self::Output {
+ if self.domain != rhs.domain || self.codomain != rhs.codomain {
+ return Err(self::Error::MatrixAdd(self.domain, self.codomain, rhs.domain, rhs.codomain).into())
+ }
+ let mut res = Matrix::empty(self.domain, self.codomain);
+ for col in 0..self.domain {
+ for row in 0..self.codomain {
+ let add = self.get(row, col)? + rhs.get(row, col)?;
+ res.set(row, col, add?)?;
+ }
+ }
+ Ok(res)
+ }
+}
+
+impl Sub for Matrix {
+ type Output = Result<Self>;
+
+ fn sub(self, rhs: Self) -> Self::Output {
+ if self.domain != rhs.domain || self.codomain != rhs.codomain {
+ return Err(self::Error::MatrixAdd(self.domain, self.codomain, rhs.domain, rhs.codomain).into())
+ }
+ let mut res = Matrix::empty(self.domain, self.codomain);
+ for col in 0..self.domain {
+ for row in 0..self.codomain {
+ let sub = self.get(row, col)? - rhs.get(row, col)?;
+ res.set(row, col, sub?)?;
+ }
+ }
+ Ok(res)
+ }
+}
+
+impl Mul for Matrix {
+ type Output = Result<Self>;
+
+ fn mul(self, rhs: Self) -> Self::Output {
+ if self.domain != rhs.codomain {
+ return Err(self::Error::MatrixMultiply(self.domain, rhs.codomain).into())
+ }
+ let mut res = Self::empty(rhs.domain, self.codomain);
+ for (i, row) in self.rows().enumerate() {
+ for (j, col) in rhs.cols().enumerate() {
+ let dot = dot(row.clone(), col.clone())?;
+ res.set(i, j, dot)?;
+ }
+ }
+ Ok(res)
+ }
+}
+
+pub struct MatrixRows<'a> {
+ matrix: &'a Matrix,
+ row: usize
+}
+
+impl<'a> Iterator for MatrixRows<'a> {
+ type Item = Vec<&'a Value>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.row >= self.matrix.codomain {
+ return None
+ }
+
+ let row_start = self.row * self.matrix.domain;
+ let row_end = row_start + self.matrix.domain;
+
+ let res = self.matrix.values
+ .iter()
+ .enumerate()
+ .filter(|(idx, _)| *idx >= row_start && *idx < row_end)
+ .map(|v| v.1)
+ .collect();
+
+ self.row += 1;
+
+ Some(res)
+ }
+}
+
+pub struct MatrixCols<'a> {
+ matrix: &'a Matrix,
+ col: usize
+}
+
+impl<'a> Iterator for MatrixCols<'a> {
+ type Item = Vec<&'a Value>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.col >= self.matrix.domain {
+ return None
+ }
+
+ let res = self.matrix.values
+ .iter()
+ .enumerate()
+ .filter(|(idx, _)| *idx % self.matrix.domain == self.col)
+ .map(|v| v.1)
+ .collect();
+
+ self.col += 1;
+
+ Some(res)
+ }
+}
diff --git a/matrix/src/vm.rs b/matrix/src/vm.rs
index dc6931f..bb105d8 100644
--- a/matrix/src/vm.rs
+++ b/matrix/src/vm.rs
@@ -1,5 +1,5 @@
use std::{rc::Rc, fmt::{Debug, Display}, usize, ops::{Index, IndexMut}, collections::HashMap, cell::RefCell, sync::{atomic::{AtomicUsize, Ordering}, Arc}};
-use crate::{value::{Value, self, ValueMap}, gc::Gc, chunk::{Function, Instruction, Chunk, InnerFunction}, Result, compiler::NamesTable};
+use crate::{value::{Value, self, ValueMap, Matrix}, gc::Gc, chunk::{Function, Instruction, Chunk, InnerFunction}, Result, compiler::NamesTable};
#[derive(Debug)]
pub enum Error {
@@ -233,11 +233,11 @@ impl Vm {
}
self.push(Value::Table(table.into()))
},
- NewMatrix(items, codomain) => {
- let list = self.stack.split_off(self.stack.len() - items as usize);
- let codomain = codomain as usize;
- let domain = list.len() / codomain;
- self.push(Value::Matrix(Gc::new((domain, codomain, list.inner))));
+ NewMatrix(items, domain) => {
+ let values = self.stack.split_off(self.stack.len() - items as usize).inner;
+ let domain = domain as usize;
+ let codomain = values.len() / domain;
+ self.push(Value::Matrix(Gc::new(Matrix::new(domain, codomain, values))));
}
Jump(idx) => {
if self.check_interupt() {