matrix math

This commit is contained in:
Freya Murphy 2024-02-23 16:25:02 -05:00
parent a9c5ffe62f
commit 1ebe51c7b3
Signed by: freya
GPG key ID: 744AB800E383AE52
9 changed files with 429 additions and 42 deletions

5
Cargo.lock generated
View file

@ -61,9 +61,9 @@ dependencies = [
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.79" version = "1.0.80"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
@ -248,6 +248,7 @@ dependencies = [
name = "matrix-stdlib" name = "matrix-stdlib"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow",
"matrix", "matrix",
"matrix-macros", "matrix-macros",
] ]

View file

@ -6,5 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
anyhow = "1"
matrix = { path = "../matrix" } matrix = { path = "../matrix" }
matrix-macros = { path = "../matrix-macros" } matrix-macros = { path = "../matrix-macros" }

View file

@ -2,6 +2,6 @@ use matrix::vm::Vm;
mod io; mod io;
pub fn load(compiler: &mut Vm) { pub fn load(vm: &mut Vm) {
io::load(compiler); io::load(vm);
} }

View file

@ -221,6 +221,7 @@ impl<'c> Compiler<'c> {
} }
fn get_name(&mut self, name: Rc<str>) -> usize { fn get_name(&mut self, name: Rc<str>) -> usize {
// TODO: find name if already exists
let idx = self.names.borrow().len(); let idx = self.names.borrow().len();
self.names.borrow_mut().push(name); self.names.borrow_mut().push(name);
idx idx
@ -463,7 +464,7 @@ impl<'c> Compiler<'c> {
for expr in &mat.2 { for expr in &mat.2 {
self.compile_expr(expr)?; self.compile_expr(expr)?;
} }
self.emit(I::NewMatrix(mat.2.len() as u16, mat.1 as u8)); self.emit(I::NewMatrix(mat.2.len() as u16, mat.0 as u8));
}, },
E::Table(table) => { E::Table(table) => {
for (key, value) in table { for (key, value) in table {

View file

@ -1,4 +1,4 @@
use std::{ops::{Index, IndexMut, Deref, DerefMut}, marker::PhantomData, ptr::NonNull, fmt::{Debug, Display}}; use std::{ops::{Index, IndexMut, Deref, DerefMut, Add, Sub, Mul}, marker::PhantomData, ptr::NonNull, fmt::{Debug, Display}};
pub struct Gc<T> { pub struct Gc<T> {
ptr: NonNull<GcInner<T>>, ptr: NonNull<GcInner<T>>,
@ -30,6 +30,12 @@ impl <T: Clone> Gc<T> {
Self::new(data) Self::new(data)
} }
} }
fn data(&self) -> T {
unsafe {
self.ptr.as_ref().data.clone()
}
}
} }
impl<T> From<T> for Gc<T> { impl<T> From<T> for Gc<T> {
@ -124,3 +130,33 @@ impl<T: Display> Display for Gc<T> {
write!(f, "{}", self.deref()) write!(f, "{}", self.deref())
} }
} }
impl <T: Add + Clone> Add for Gc<T> {
type Output = T::Output;
fn add(self, rhs: Self) -> Self::Output {
let a = self.data();
let b = rhs.data();
a + b
}
}
impl <T: Sub + Clone> Sub for Gc<T> {
type Output = T::Output;
fn sub(self, rhs: Self) -> Self::Output {
let a = self.data();
let b = rhs.data();
a - b
}
}
impl <T: Mul + Clone> Mul for Gc<T> {
type Output = T::Output;
fn mul(self, rhs: Self) -> Self::Output {
let a = self.data();
let b = rhs.data();
a * b
}
}

View file

@ -21,6 +21,7 @@ enum ErrorInner {
Value(value::Error), Value(value::Error),
Compile(compiler::Error), Compile(compiler::Error),
Runtime(vm::Error), Runtime(vm::Error),
External(anyhow::Error),
} }
impl Display for crate::Error { impl Display for crate::Error {
@ -32,6 +33,7 @@ impl Display for crate::Error {
Value(err) => write!(f, "{err}"), Value(err) => write!(f, "{err}"),
Compile(err) => write!(f, "{err}"), Compile(err) => write!(f, "{err}"),
Runtime(err) => write!(f, "{err}"), Runtime(err) => write!(f, "{err}"),
External(err) => write!(f, "{err}"),
} }
} }
} }
@ -51,5 +53,6 @@ from_error!(parse::Error, Parse);
from_error!(value::Error, Value); from_error!(value::Error, Value);
from_error!(compiler::Error, Compile); from_error!(compiler::Error, Compile);
from_error!(vm::Error, Runtime); from_error!(vm::Error, Runtime);
from_error!(anyhow::Error, External);
pub type Result<T> = std::result::Result<T, crate::Error>; pub type Result<T> = std::result::Result<T, crate::Error>;

View file

@ -202,7 +202,7 @@ impl Parser {
let codomain = parts.len(); let codomain = parts.len();
let domain = parts[0].len(); let domain = parts[0].len();
for (i, part) in parts.iter().enumerate() { for (i, part) in parts.iter().enumerate() {
if part.len() != codomain { if part.len() != domain {
return Err(Error::MatrixInvDomain(i, domain, part.len()).into()) return Err(Error::MatrixInvDomain(i, domain, part.len()).into())
} }
} }

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, rc::Rc, hash::Hash, fmt::Display, ops::{Add, Neg, Not, Sub, Div, Mul, BitOr, BitAnd, BitXor, Shl, Shr}, cmp::Ordering}; use std::{collections::HashMap, rc::Rc, hash::Hash, fmt::{Display, Debug}, ops::{Add, Neg, Not, Sub, Div, Mul, BitOr, BitAnd, BitXor, Shl, Shr}, cmp::Ordering};
use num_complex::Complex64; use num_complex::Complex64;
use num_rational::Rational64; use num_rational::Rational64;
@ -7,7 +7,6 @@ use regex::Regex;
use crate::{ast::{Expr, BinaryOp, UnaryOp}, chunk::Function, Result, gc::Gc}; use crate::{ast::{Expr, BinaryOp, UnaryOp}, chunk::Function, Result, gc::Gc};
pub type List = Vec<Value>; pub type List = Vec<Value>;
pub type Matrix = (usize, usize, Vec<Value>);
pub type Table = ValueMap; pub type Table = ValueMap;
pub type InlineList = Vec<Expr>; pub type InlineList = Vec<Expr>;
@ -42,7 +41,12 @@ pub enum Error {
Modulo(Value, Value), Modulo(Value, Value),
Exponent(Value, Value), Exponent(Value, Value),
Compare(Value, Value), Compare(Value, Value),
IndexOutOfBounds(usize, usize), IndexOutOfBounds(i64, usize),
MatrixOutOfBounds(usize, usize, usize, usize),
MatrixIndexOutOfBounds(i64, usize, usize),
MatrixInvSlice(usize, usize),
MatrixMultiply(usize, usize),
MatrixAdd(usize, usize, usize, usize),
CannotIndex(Value), CannotIndex(Value),
BadIndex(Value, Value), BadIndex(Value, Value),
Concat(Value, Value), Concat(Value, Value),
@ -66,6 +70,11 @@ impl Display for self::Error {
Exponent(a, b) => write!(f, "cannot compute {a:?} pow {b:?}"), Exponent(a, b) => write!(f, "cannot compute {a:?} pow {b:?}"),
Compare(a, b) => write!(f, "cannot compare {a:?} and {b:?}"), Compare(a, b) => write!(f, "cannot compare {a:?} and {b:?}"),
IndexOutOfBounds(a, b) => write!(f, "index {a} out of bounds for list of length {b}"), IndexOutOfBounds(a, b) => write!(f, "index {a} out of bounds for list of length {b}"),
MatrixOutOfBounds(a, b, c, d) => write!(f, "row col ({a},{b}) is out of bounds for matrix {c}x{d}"),
MatrixIndexOutOfBounds(i, d, c) => write!(f, "row index {i} is out of bounds for matrix {d}x{c}"),
MatrixMultiply(c, d) => write!(f, "cannot multiply a matrix with the domain of {d} and matrix with the codomain of {c}"),
MatrixAdd(a, b, c, d) => write!(f, "cannot add a {a}x{b} matrix and a {c}x{d} matrix"),
MatrixInvSlice(s, e) => write!(f, "invalid matrix slice {s}..{e}"),
BadIndex(a, b) => write!(f, "cannot index {a:?} with {b:?}"), BadIndex(a, b) => write!(f, "cannot index {a:?} with {b:?}"),
CannotIndex(a) => write!(f, "cannot index {a:?}"), CannotIndex(a) => write!(f, "cannot index {a:?}"),
Concat(a, b) => write!(f, "cannot concat {a:?} and {b:?}"), Concat(a, b) => write!(f, "cannot concat {a:?} and {b:?}"),
@ -112,9 +121,9 @@ impl Hash for Value {
} }
} }
Matrix(m) => { Matrix(m) => {
m.0.hash(state); m.domain.hash(state);
m.1.hash(state); m.codomain.hash(state);
for val in m.2.iter() { for val in m.values.iter() {
val.hash(state); val.hash(state);
} }
}, },
@ -139,7 +148,7 @@ impl Value {
} }
} }
Matrix(m) => { Matrix(m) => {
for val in m.2.iter() { for val in m.values.iter() {
val.can_hash()?; val.can_hash()?;
} }
}, },
@ -172,9 +181,9 @@ impl Value {
}, },
Matrix(m) => { Matrix(m) => {
let mut str = "\n".to_string(); let mut str = "\n".to_string();
for row in m.2.chunks(m.0) { for row in m.rows() {
str.push_str(" "); str.push_str(" ");
for (i, v) in row.iter().enumerate() { for (i, v) in row.into_iter().enumerate() {
if i != 0 { if i != 0 {
str.push(' '); str.push(' ');
} }
@ -227,9 +236,9 @@ impl Value {
}, },
Matrix(m) => { Matrix(m) => {
let mut str = "\n".to_string(); let mut str = "\n".to_string();
for row in m.2.chunks(m.0) { for row in m.rows() {
str.push_str(" "); str.push_str(" ");
for (i, v) in row.iter().enumerate() { for (i, v) in row.into_iter().enumerate() {
if i != 0 { if i != 0 {
str.push(' '); str.push(' ');
} }
@ -270,20 +279,23 @@ fn ratio_to_f64(r: Rational64) -> f64 {
} }
fn promote(a: Value, b: Value) -> (Value, Value) { fn promote(a: Value, b: Value) -> (Value, Value) {
use Value::*; use Value as V;
match (&a, &b) { match (&a, &b) {
(Int(x), Ratio(..)) => (Ratio((*x).into()), b), (V::Int(x), V::Ratio(..)) => (V::Ratio((*x).into()), b),
(Int(x), Float(..)) => (Float(*x as f64), b), (V::Int(x), V::Float(..)) => (V::Float(*x as f64), b),
(Int(x), Complex(..)) => (Complex((*x as f64).into()), b), (V::Int(x), V::Complex(..)) => (V::Complex((*x as f64).into()), b),
(Ratio(x), Float(..)) => (Float(ratio_to_f64(*x)), b), (V::Ratio(x), V::Float(..)) => (V::Float(ratio_to_f64(*x)), b),
(Ratio(x), Complex(..)) => (Complex(ratio_to_f64(*x).into()), b), (V::Ratio(x), V::Complex(..)) => (V::Complex(ratio_to_f64(*x).into()), b),
(Float(x), Complex(..)) => (Complex((*x).into()), b), (V::Float(x), V::Complex(..)) => (V::Complex((*x).into()), b),
(Ratio(..), Int(y)) => (a, Ratio((*y).into())), (V::Ratio(..), V::Int(y)) => (a, V::Ratio((*y).into())),
(Float(..), Int(y)) => (a, Float(*y as f64)), (V::Float(..), V::Int(y)) => (a, V::Float(*y as f64)),
(Complex(..), Int(y)) => (a, Complex((*y as f64).into())), (V::Complex(..), V::Int(y)) => (a, V::Complex((*y as f64).into())),
(Float(..), Ratio(y)) => (a, Float(ratio_to_f64(*y))), (V::Float(..), V::Ratio(y)) => (a, V::Float(ratio_to_f64(*y))),
(Complex(..), Ratio(y)) => (a, Complex(ratio_to_f64(*y).into())), (V::Complex(..), V::Ratio(y)) => (a, V::Complex(ratio_to_f64(*y).into())),
(Complex(..), Float(y)) => (a, Complex((*y).into())), (V::Complex(..), V::Float(y)) => (a, V::Complex((*y).into())),
(V::List(l1), V::List(l2)) => (V::Matrix(Matrix::from_list(l1.to_vec()).into()), V::Matrix(Matrix::from_list(l2.to_vec()).into())),
(_, V::List(l)) => (a, V::Matrix(Matrix::from_list(l.to_vec()).into())),
(V::List(l), _) => (V::Matrix(Matrix::from_list(l.to_vec()).into()), b),
_ => (a, b), _ => (a, b),
} }
} }
@ -297,6 +309,7 @@ impl Add for Value {
(Float(x), Float(y)) => Ok(Float(x + y)), (Float(x), Float(y)) => Ok(Float(x + y)),
(Ratio(x), Ratio(y)) => Ok(Ratio(x + y)), (Ratio(x), Ratio(y)) => Ok(Ratio(x + y)),
(Complex(x), Complex(y)) => Ok(Complex(x + y)), (Complex(x), Complex(y)) => Ok(Complex(x + y)),
(Matrix(x), Matrix(y)) => Ok(Matrix((x + y)?.into())),
(String(str), value) => Ok(String(Rc::new( (String(str), value) => Ok(String(Rc::new(
format!("{str}{}", value.boring_print()) format!("{str}{}", value.boring_print())
))), ))),
@ -321,6 +334,7 @@ impl Sub for Value {
(Float(x), Float(y)) => Ok(Float(x - y)), (Float(x), Float(y)) => Ok(Float(x - y)),
(Ratio(x), Ratio(y)) => Ok(Ratio(x - y)), (Ratio(x), Ratio(y)) => Ok(Ratio(x - y)),
(Complex(x), Complex(y)) => Ok(Complex(x - y)), (Complex(x), Complex(y)) => Ok(Complex(x - y)),
(Matrix(x), Matrix(y)) => Ok(Matrix((x - y)?.into())),
(l, r) => Err(Error::Subtract(l, r).into()) (l, r) => Err(Error::Subtract(l, r).into())
} }
} }
@ -335,6 +349,9 @@ impl Mul for Value {
(Float(x), Float(y)) => Ok(Float(x * y)), (Float(x), Float(y)) => Ok(Float(x * y)),
(Ratio(x), Ratio(y)) => Ok(Ratio(x * y)), (Ratio(x), Ratio(y)) => Ok(Ratio(x * y)),
(Complex(x), Complex(y)) => Ok(Complex(x * y)), (Complex(x), Complex(y)) => Ok(Complex(x * y)),
(Matrix(x), Matrix(y)) => Ok(Matrix((x * y)?.into())),
(Matrix(x), r) => Ok(Matrix(x.scale(r)?.into())),
(l, Matrix(y)) => Ok(Matrix(y.scale(l)?.into())),
(l, r) => Err(Error::Multiply(l, r).into()) (l, r) => Err(Error::Multiply(l, r).into())
} }
} }
@ -485,7 +502,7 @@ impl PartialEq for Value {
(Complex(a), Float(b)) => *a == Complex64::from(*b), (Complex(a), Float(b)) => *a == Complex64::from(*b),
(String(a), String(b)) => *a == *b, (String(a), String(b)) => *a == *b,
(List(a), List(b)) => *a == *b, (List(a), List(b)) => *a == *b,
(Matrix(a), Matrix(b)) => a.0 == b.0 && *a.2 == *b.2, (Matrix(a), Matrix(b)) => a == b,
_ => false, _ => false,
} }
} }
@ -508,7 +525,7 @@ impl PartialOrd for Value {
(Float(a), Ratio(b)) => a.partial_cmp(&ratio_to_f64(*b)), (Float(a), Ratio(b)) => a.partial_cmp(&ratio_to_f64(*b)),
(String(a), String(b)) => a.partial_cmp(b), (String(a), String(b)) => a.partial_cmp(b),
(List(a), List(b)) => a.partial_cmp(b), (List(a), List(b)) => a.partial_cmp(b),
(Matrix(a), Matrix(b)) => a.2.partial_cmp(&b.2), (Matrix(a), Matrix(b)) => a.domain.partial_cmp(&b.domain),
_ => None, _ => None,
} }
} }
@ -528,10 +545,16 @@ impl Value {
}, },
(V::List(l), V::Int(i)) => { (V::List(l), V::Int(i)) => {
if *i < 0 || *i as usize >= l.len() { if *i < 0 || *i as usize >= l.len() {
return Err(Error::IndexOutOfBounds(*i as usize, l.len()).into()) return Err(Error::IndexOutOfBounds(*i, l.len()).into())
} }
Ok(l[*i as usize].clone()) Ok(l[*i as usize].clone())
}, },
(V::Matrix(m), V::Int(i)) => {
if *i < 0 || *i as usize >= m.values.len() {
return Err(Error::MatrixIndexOutOfBounds(*i, m.domain, m.codomain).into())
}
Ok(m.values[*i as usize].clone())
},
_ => return Err(self::Error::BadIndex(self.clone(), index.clone()).into()) _ => return Err(self::Error::BadIndex(self.clone(), index.clone()).into())
} }
} }
@ -555,6 +578,40 @@ impl Value {
} }
Ok(V::Table(ret.into())) Ok(V::Table(ret.into()))
} }
V::Matrix(m) => {
let err = || Err(self::Error::BadIndex(self.clone(), V::List(indexes.clone().into())).into());
if indexes.len() != 2 {
return err()
}
let lhs = indexes[0].clone();
let rhs = indexes[1].clone();
match (lhs, rhs) {
(V::Nil, V::Nil) => {
Ok(V::Matrix(m.clone_inside()))
},
(V::Int(row), V::Nil) => {
let Some((_, row)) = m.rows().enumerate().filter(|(idx, _)| *idx as i64 == row).next() else {
return err();
};
let row: Vec<Value> = row.into_iter().map(|e| e.clone()).collect();
Ok(V::Matrix(Matrix::new(row.len(), 1, row).into()))
},
(V::Nil, V::Int(col)) => {
let Some((_, col)) = m.cols().enumerate().filter(|(idx, _)| *idx as i64 == col).next() else {
return err();
};
let col: Vec<Value> = col.into_iter().map(|e| e.clone()).collect();
Ok(V::Matrix(Matrix::new(1, col.len(), col).into()))
},
(V::Int(row), V::Int(col)) => {
if row < 0 || col < 0 {
return err();
}
m.get(row as usize, col as usize)
}
_ => return err()
}
}
_ => return Err(self::Error::CannotIndex(self.clone()).into()) _ => return Err(self::Error::CannotIndex(self.clone()).into())
} }
} }
@ -578,11 +635,18 @@ impl Value {
}, },
(V::List(l), V::Int(i)) => { (V::List(l), V::Int(i)) => {
if *i < 0 || *i as usize >= l.len() { if *i < 0 || *i as usize >= l.len() {
return Err(Error::IndexOutOfBounds(*i as usize, l.len()).into()) return Err(Error::IndexOutOfBounds(*i, l.len()).into())
} }
l[*i as usize] = store; l[*i as usize] = store;
Ok(()) Ok(())
},
(V::Matrix(m), V::Int(i)) => {
if *i < 0 || *i as usize >= m.values.len() {
return Err(Error::MatrixIndexOutOfBounds(*i, m.domain, m.codomain).into())
} }
m.values[*i as usize] = store;
Ok(())
},
_ => return Err(self::Error::BadIndex(err, index.clone()).into()) _ => return Err(self::Error::BadIndex(err, index.clone()).into())
} }
} }
@ -708,3 +772,284 @@ impl Not for Value {
} }
} }
fn dot(lhs: Vec<&Value>, rhs: Vec<&Value>) -> Result<Value> {
let len = lhs.len();
let mut res = Value::Int(0);
for i in 0..len {
let val = (lhs[i].clone() * rhs[i].clone())?;
res = (res + val)?;
}
Ok(res)
}
#[derive(Clone)]
pub struct Matrix {
pub domain: usize,
pub codomain: usize,
pub values: Vec<Value>
}
impl Matrix {
pub fn new(
domain: usize,
codomain: usize,
values: Vec<Value>
) -> Self {
Self {
domain,
codomain,
values
}
}
pub fn from_list(
values: Vec<Value>
) -> Self {
Self {
domain: values.len(),
codomain: 1,
values
}
}
pub fn empty(
domain: usize,
codomain: usize
) -> Self {
let values = (0..(domain * codomain)).into_iter()
.map(|_| Value::Int(0))
.collect();
Self {
domain,
codomain,
values
}
}
pub fn get(&self, row: usize, col: usize) -> Result<Value> {
if row >= self.codomain || col >= self.domain {
return Err(self::Error::MatrixOutOfBounds(row, col, self.domain, self.codomain).into())
}
let idx = col + row * self.codomain;
Ok(self.values[idx].clone())
}
pub fn set(&mut self, row: usize, col: usize, val: Value) -> Result<()> {
if row >= self.codomain || col >= self.domain {
return Err(self::Error::MatrixOutOfBounds(row, col, self.domain, self.codomain).into())
}
let idx = col + row * self.codomain;
self.values[idx] = val;
Ok(())
}
pub fn rows<'a>(&'a self) -> MatrixRows<'a> {
MatrixRows {
matrix: self,
row: 0
}
}
pub fn cols<'a>(&'a self) -> MatrixCols<'a> {
MatrixCols {
matrix: self,
col: 0
}
}
// SPLCIE DOMAIN
pub fn splice_cols(&self, col_start: usize, col_end: usize) -> Result<Self> {
if col_start <= col_end {
return Err(self::Error::MatrixInvSlice(col_start, col_end).into());
}
let mut cols = Vec::new();
for (i, col) in self.cols().enumerate() {
if i >= col_start && i < col_end {
cols.push(col);
}
}
let domain = cols.len();
let codomain = cols[0].len();
let mut res = Self::empty(domain, codomain);
for i in 0..domain {
for j in 0..codomain {
res.set(j, i, cols[i][j].clone())?;
}
}
Ok(res)
}
// SPLICE CODOMAIN
pub fn splice_rows(&self, row_start: usize, row_end: usize) -> Result<Self> {
if row_start <= row_end {
return Err(self::Error::MatrixInvSlice(row_start, row_end).into());
}
let mut rows = Vec::new();
for (i, row) in self.rows().enumerate() {
if i >= row_start && i < row_end {
rows.push(row);
}
}
let domain = rows[0].len();
let codomain = rows.len();
let mut res = Self::empty(domain, codomain);
for i in 0..domain {
for j in 0..codomain {
res.set(j, i, rows[j][i].clone())?;
}
}
Ok(res)
}
pub fn scale(&self, scale: Value) -> Result<Self> {
let values = self.values.iter()
.map(|v| v.clone() * scale.clone())
.collect::<Result<Vec<Value>>>()?;
Ok(Matrix::new(self.domain, self.codomain, values))
}
}
impl Debug for Matrix {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[Matrix {}x{}]", self.domain, self.codomain)
}
}
impl PartialEq for Matrix {
fn eq(&self, other: &Self) -> bool {
if self.domain != other.domain || self.codomain != other.codomain {
return false
}
for i in 0..self.values.len() {
if self.values[i] != other.values[i] {
return false;
}
}
return true;
}
}
impl Add for Matrix {
type Output = Result<Self>;
fn add(self, rhs: Self) -> Self::Output {
if self.domain != rhs.domain || self.codomain != rhs.codomain {
return Err(self::Error::MatrixAdd(self.domain, self.codomain, rhs.domain, rhs.codomain).into())
}
let mut res = Matrix::empty(self.domain, self.codomain);
for col in 0..self.domain {
for row in 0..self.codomain {
let add = self.get(row, col)? + rhs.get(row, col)?;
res.set(row, col, add?)?;
}
}
Ok(res)
}
}
impl Sub for Matrix {
type Output = Result<Self>;
fn sub(self, rhs: Self) -> Self::Output {
if self.domain != rhs.domain || self.codomain != rhs.codomain {
return Err(self::Error::MatrixAdd(self.domain, self.codomain, rhs.domain, rhs.codomain).into())
}
let mut res = Matrix::empty(self.domain, self.codomain);
for col in 0..self.domain {
for row in 0..self.codomain {
let sub = self.get(row, col)? - rhs.get(row, col)?;
res.set(row, col, sub?)?;
}
}
Ok(res)
}
}
impl Mul for Matrix {
type Output = Result<Self>;
fn mul(self, rhs: Self) -> Self::Output {
if self.domain != rhs.codomain {
return Err(self::Error::MatrixMultiply(self.domain, rhs.codomain).into())
}
let mut res = Self::empty(rhs.domain, self.codomain);
for (i, row) in self.rows().enumerate() {
for (j, col) in rhs.cols().enumerate() {
let dot = dot(row.clone(), col.clone())?;
res.set(i, j, dot)?;
}
}
Ok(res)
}
}
pub struct MatrixRows<'a> {
matrix: &'a Matrix,
row: usize
}
impl<'a> Iterator for MatrixRows<'a> {
type Item = Vec<&'a Value>;
fn next(&mut self) -> Option<Self::Item> {
if self.row >= self.matrix.codomain {
return None
}
let row_start = self.row * self.matrix.domain;
let row_end = row_start + self.matrix.domain;
let res = self.matrix.values
.iter()
.enumerate()
.filter(|(idx, _)| *idx >= row_start && *idx < row_end)
.map(|v| v.1)
.collect();
self.row += 1;
Some(res)
}
}
pub struct MatrixCols<'a> {
matrix: &'a Matrix,
col: usize
}
impl<'a> Iterator for MatrixCols<'a> {
type Item = Vec<&'a Value>;
fn next(&mut self) -> Option<Self::Item> {
if self.col >= self.matrix.domain {
return None
}
let res = self.matrix.values
.iter()
.enumerate()
.filter(|(idx, _)| *idx % self.matrix.domain == self.col)
.map(|v| v.1)
.collect();
self.col += 1;
Some(res)
}
}

View file

@ -1,5 +1,5 @@
use std::{rc::Rc, fmt::{Debug, Display}, usize, ops::{Index, IndexMut}, collections::HashMap, cell::RefCell, sync::{atomic::{AtomicUsize, Ordering}, Arc}}; use std::{rc::Rc, fmt::{Debug, Display}, usize, ops::{Index, IndexMut}, collections::HashMap, cell::RefCell, sync::{atomic::{AtomicUsize, Ordering}, Arc}};
use crate::{value::{Value, self, ValueMap}, gc::Gc, chunk::{Function, Instruction, Chunk, InnerFunction}, Result, compiler::NamesTable}; use crate::{value::{Value, self, ValueMap, Matrix}, gc::Gc, chunk::{Function, Instruction, Chunk, InnerFunction}, Result, compiler::NamesTable};
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
@ -233,11 +233,11 @@ impl Vm {
} }
self.push(Value::Table(table.into())) self.push(Value::Table(table.into()))
}, },
NewMatrix(items, codomain) => { NewMatrix(items, domain) => {
let list = self.stack.split_off(self.stack.len() - items as usize); let values = self.stack.split_off(self.stack.len() - items as usize).inner;
let codomain = codomain as usize; let domain = domain as usize;
let domain = list.len() / codomain; let codomain = values.len() / domain;
self.push(Value::Matrix(Gc::new((domain, codomain, list.inner)))); self.push(Value::Matrix(Gc::new(Matrix::new(domain, codomain, values))));
} }
Jump(idx) => { Jump(idx) => {
if self.check_interupt() { if self.check_interupt() {