matrix math
This commit is contained in:
parent
a9c5ffe62f
commit
1ebe51c7b3
9 changed files with 429 additions and 42 deletions
5
Cargo.lock
generated
5
Cargo.lock
generated
|
@ -61,9 +61,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.79"
|
||||
version = "1.0.80"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
|
||||
checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
|
@ -248,6 +248,7 @@ dependencies = [
|
|||
name = "matrix-stdlib"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"matrix",
|
||||
"matrix-macros",
|
||||
]
|
||||
|
|
|
@ -6,5 +6,6 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
matrix = { path = "../matrix" }
|
||||
matrix-macros = { path = "../matrix-macros" }
|
||||
|
|
|
@ -2,6 +2,6 @@ use matrix::vm::Vm;
|
|||
|
||||
mod io;
|
||||
|
||||
pub fn load(compiler: &mut Vm) {
|
||||
io::load(compiler);
|
||||
pub fn load(vm: &mut Vm) {
|
||||
io::load(vm);
|
||||
}
|
||||
|
|
|
@ -221,6 +221,7 @@ impl<'c> Compiler<'c> {
|
|||
}
|
||||
|
||||
fn get_name(&mut self, name: Rc<str>) -> usize {
|
||||
// TODO: find name if already exists
|
||||
let idx = self.names.borrow().len();
|
||||
self.names.borrow_mut().push(name);
|
||||
idx
|
||||
|
@ -463,7 +464,7 @@ impl<'c> Compiler<'c> {
|
|||
for expr in &mat.2 {
|
||||
self.compile_expr(expr)?;
|
||||
}
|
||||
self.emit(I::NewMatrix(mat.2.len() as u16, mat.1 as u8));
|
||||
self.emit(I::NewMatrix(mat.2.len() as u16, mat.0 as u8));
|
||||
},
|
||||
E::Table(table) => {
|
||||
for (key, value) in table {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{ops::{Index, IndexMut, Deref, DerefMut}, marker::PhantomData, ptr::NonNull, fmt::{Debug, Display}};
|
||||
use std::{ops::{Index, IndexMut, Deref, DerefMut, Add, Sub, Mul}, marker::PhantomData, ptr::NonNull, fmt::{Debug, Display}};
|
||||
|
||||
pub struct Gc<T> {
|
||||
ptr: NonNull<GcInner<T>>,
|
||||
|
@ -30,6 +30,12 @@ impl <T: Clone> Gc<T> {
|
|||
Self::new(data)
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> T {
|
||||
unsafe {
|
||||
self.ptr.as_ref().data.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Gc<T> {
|
||||
|
@ -124,3 +130,33 @@ impl<T: Display> Display for Gc<T> {
|
|||
write!(f, "{}", self.deref())
|
||||
}
|
||||
}
|
||||
|
||||
impl <T: Add + Clone> Add for Gc<T> {
|
||||
type Output = T::Output;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let a = self.data();
|
||||
let b = rhs.data();
|
||||
a + b
|
||||
}
|
||||
}
|
||||
|
||||
impl <T: Sub + Clone> Sub for Gc<T> {
|
||||
type Output = T::Output;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
let a = self.data();
|
||||
let b = rhs.data();
|
||||
a - b
|
||||
}
|
||||
}
|
||||
|
||||
impl <T: Mul + Clone> Mul for Gc<T> {
|
||||
type Output = T::Output;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
let a = self.data();
|
||||
let b = rhs.data();
|
||||
a * b
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ enum ErrorInner {
|
|||
Value(value::Error),
|
||||
Compile(compiler::Error),
|
||||
Runtime(vm::Error),
|
||||
External(anyhow::Error),
|
||||
}
|
||||
|
||||
impl Display for crate::Error {
|
||||
|
@ -32,6 +33,7 @@ impl Display for crate::Error {
|
|||
Value(err) => write!(f, "{err}"),
|
||||
Compile(err) => write!(f, "{err}"),
|
||||
Runtime(err) => write!(f, "{err}"),
|
||||
External(err) => write!(f, "{err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -51,5 +53,6 @@ from_error!(parse::Error, Parse);
|
|||
from_error!(value::Error, Value);
|
||||
from_error!(compiler::Error, Compile);
|
||||
from_error!(vm::Error, Runtime);
|
||||
from_error!(anyhow::Error, External);
|
||||
|
||||
pub type Result<T> = std::result::Result<T, crate::Error>;
|
||||
|
|
|
@ -202,7 +202,7 @@ impl Parser {
|
|||
let codomain = parts.len();
|
||||
let domain = parts[0].len();
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if part.len() != codomain {
|
||||
if part.len() != domain {
|
||||
return Err(Error::MatrixInvDomain(i, domain, part.len()).into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{collections::HashMap, rc::Rc, hash::Hash, fmt::Display, ops::{Add, Neg, Not, Sub, Div, Mul, BitOr, BitAnd, BitXor, Shl, Shr}, cmp::Ordering};
|
||||
use std::{collections::HashMap, rc::Rc, hash::Hash, fmt::{Display, Debug}, ops::{Add, Neg, Not, Sub, Div, Mul, BitOr, BitAnd, BitXor, Shl, Shr}, cmp::Ordering};
|
||||
|
||||
use num_complex::Complex64;
|
||||
use num_rational::Rational64;
|
||||
|
@ -7,7 +7,6 @@ use regex::Regex;
|
|||
use crate::{ast::{Expr, BinaryOp, UnaryOp}, chunk::Function, Result, gc::Gc};
|
||||
|
||||
pub type List = Vec<Value>;
|
||||
pub type Matrix = (usize, usize, Vec<Value>);
|
||||
pub type Table = ValueMap;
|
||||
|
||||
pub type InlineList = Vec<Expr>;
|
||||
|
@ -42,7 +41,12 @@ pub enum Error {
|
|||
Modulo(Value, Value),
|
||||
Exponent(Value, Value),
|
||||
Compare(Value, Value),
|
||||
IndexOutOfBounds(usize, usize),
|
||||
IndexOutOfBounds(i64, usize),
|
||||
MatrixOutOfBounds(usize, usize, usize, usize),
|
||||
MatrixIndexOutOfBounds(i64, usize, usize),
|
||||
MatrixInvSlice(usize, usize),
|
||||
MatrixMultiply(usize, usize),
|
||||
MatrixAdd(usize, usize, usize, usize),
|
||||
CannotIndex(Value),
|
||||
BadIndex(Value, Value),
|
||||
Concat(Value, Value),
|
||||
|
@ -66,6 +70,11 @@ impl Display for self::Error {
|
|||
Exponent(a, b) => write!(f, "cannot compute {a:?} pow {b:?}"),
|
||||
Compare(a, b) => write!(f, "cannot compare {a:?} and {b:?}"),
|
||||
IndexOutOfBounds(a, b) => write!(f, "index {a} out of bounds for list of length {b}"),
|
||||
MatrixOutOfBounds(a, b, c, d) => write!(f, "row col ({a},{b}) is out of bounds for matrix {c}x{d}"),
|
||||
MatrixIndexOutOfBounds(i, d, c) => write!(f, "row index {i} is out of bounds for matrix {d}x{c}"),
|
||||
MatrixMultiply(c, d) => write!(f, "cannot multiply a matrix with the domain of {d} and matrix with the codomain of {c}"),
|
||||
MatrixAdd(a, b, c, d) => write!(f, "cannot add a {a}x{b} matrix and a {c}x{d} matrix"),
|
||||
MatrixInvSlice(s, e) => write!(f, "invalid matrix slice {s}..{e}"),
|
||||
BadIndex(a, b) => write!(f, "cannot index {a:?} with {b:?}"),
|
||||
CannotIndex(a) => write!(f, "cannot index {a:?}"),
|
||||
Concat(a, b) => write!(f, "cannot concat {a:?} and {b:?}"),
|
||||
|
@ -112,9 +121,9 @@ impl Hash for Value {
|
|||
}
|
||||
}
|
||||
Matrix(m) => {
|
||||
m.0.hash(state);
|
||||
m.1.hash(state);
|
||||
for val in m.2.iter() {
|
||||
m.domain.hash(state);
|
||||
m.codomain.hash(state);
|
||||
for val in m.values.iter() {
|
||||
val.hash(state);
|
||||
}
|
||||
},
|
||||
|
@ -139,7 +148,7 @@ impl Value {
|
|||
}
|
||||
}
|
||||
Matrix(m) => {
|
||||
for val in m.2.iter() {
|
||||
for val in m.values.iter() {
|
||||
val.can_hash()?;
|
||||
}
|
||||
},
|
||||
|
@ -172,9 +181,9 @@ impl Value {
|
|||
},
|
||||
Matrix(m) => {
|
||||
let mut str = "\n".to_string();
|
||||
for row in m.2.chunks(m.0) {
|
||||
for row in m.rows() {
|
||||
str.push_str(" ");
|
||||
for (i, v) in row.iter().enumerate() {
|
||||
for (i, v) in row.into_iter().enumerate() {
|
||||
if i != 0 {
|
||||
str.push(' ');
|
||||
}
|
||||
|
@ -227,9 +236,9 @@ impl Value {
|
|||
},
|
||||
Matrix(m) => {
|
||||
let mut str = "\n".to_string();
|
||||
for row in m.2.chunks(m.0) {
|
||||
for row in m.rows() {
|
||||
str.push_str(" ");
|
||||
for (i, v) in row.iter().enumerate() {
|
||||
for (i, v) in row.into_iter().enumerate() {
|
||||
if i != 0 {
|
||||
str.push(' ');
|
||||
}
|
||||
|
@ -270,20 +279,23 @@ fn ratio_to_f64(r: Rational64) -> f64 {
|
|||
}
|
||||
|
||||
fn promote(a: Value, b: Value) -> (Value, Value) {
|
||||
use Value::*;
|
||||
use Value as V;
|
||||
match (&a, &b) {
|
||||
(Int(x), Ratio(..)) => (Ratio((*x).into()), b),
|
||||
(Int(x), Float(..)) => (Float(*x as f64), b),
|
||||
(Int(x), Complex(..)) => (Complex((*x as f64).into()), b),
|
||||
(Ratio(x), Float(..)) => (Float(ratio_to_f64(*x)), b),
|
||||
(Ratio(x), Complex(..)) => (Complex(ratio_to_f64(*x).into()), b),
|
||||
(Float(x), Complex(..)) => (Complex((*x).into()), b),
|
||||
(Ratio(..), Int(y)) => (a, Ratio((*y).into())),
|
||||
(Float(..), Int(y)) => (a, Float(*y as f64)),
|
||||
(Complex(..), Int(y)) => (a, Complex((*y as f64).into())),
|
||||
(Float(..), Ratio(y)) => (a, Float(ratio_to_f64(*y))),
|
||||
(Complex(..), Ratio(y)) => (a, Complex(ratio_to_f64(*y).into())),
|
||||
(Complex(..), Float(y)) => (a, Complex((*y).into())),
|
||||
(V::Int(x), V::Ratio(..)) => (V::Ratio((*x).into()), b),
|
||||
(V::Int(x), V::Float(..)) => (V::Float(*x as f64), b),
|
||||
(V::Int(x), V::Complex(..)) => (V::Complex((*x as f64).into()), b),
|
||||
(V::Ratio(x), V::Float(..)) => (V::Float(ratio_to_f64(*x)), b),
|
||||
(V::Ratio(x), V::Complex(..)) => (V::Complex(ratio_to_f64(*x).into()), b),
|
||||
(V::Float(x), V::Complex(..)) => (V::Complex((*x).into()), b),
|
||||
(V::Ratio(..), V::Int(y)) => (a, V::Ratio((*y).into())),
|
||||
(V::Float(..), V::Int(y)) => (a, V::Float(*y as f64)),
|
||||
(V::Complex(..), V::Int(y)) => (a, V::Complex((*y as f64).into())),
|
||||
(V::Float(..), V::Ratio(y)) => (a, V::Float(ratio_to_f64(*y))),
|
||||
(V::Complex(..), V::Ratio(y)) => (a, V::Complex(ratio_to_f64(*y).into())),
|
||||
(V::Complex(..), V::Float(y)) => (a, V::Complex((*y).into())),
|
||||
(V::List(l1), V::List(l2)) => (V::Matrix(Matrix::from_list(l1.to_vec()).into()), V::Matrix(Matrix::from_list(l2.to_vec()).into())),
|
||||
(_, V::List(l)) => (a, V::Matrix(Matrix::from_list(l.to_vec()).into())),
|
||||
(V::List(l), _) => (V::Matrix(Matrix::from_list(l.to_vec()).into()), b),
|
||||
_ => (a, b),
|
||||
}
|
||||
}
|
||||
|
@ -297,6 +309,7 @@ impl Add for Value {
|
|||
(Float(x), Float(y)) => Ok(Float(x + y)),
|
||||
(Ratio(x), Ratio(y)) => Ok(Ratio(x + y)),
|
||||
(Complex(x), Complex(y)) => Ok(Complex(x + y)),
|
||||
(Matrix(x), Matrix(y)) => Ok(Matrix((x + y)?.into())),
|
||||
(String(str), value) => Ok(String(Rc::new(
|
||||
format!("{str}{}", value.boring_print())
|
||||
))),
|
||||
|
@ -321,6 +334,7 @@ impl Sub for Value {
|
|||
(Float(x), Float(y)) => Ok(Float(x - y)),
|
||||
(Ratio(x), Ratio(y)) => Ok(Ratio(x - y)),
|
||||
(Complex(x), Complex(y)) => Ok(Complex(x - y)),
|
||||
(Matrix(x), Matrix(y)) => Ok(Matrix((x - y)?.into())),
|
||||
(l, r) => Err(Error::Subtract(l, r).into())
|
||||
}
|
||||
}
|
||||
|
@ -335,6 +349,9 @@ impl Mul for Value {
|
|||
(Float(x), Float(y)) => Ok(Float(x * y)),
|
||||
(Ratio(x), Ratio(y)) => Ok(Ratio(x * y)),
|
||||
(Complex(x), Complex(y)) => Ok(Complex(x * y)),
|
||||
(Matrix(x), Matrix(y)) => Ok(Matrix((x * y)?.into())),
|
||||
(Matrix(x), r) => Ok(Matrix(x.scale(r)?.into())),
|
||||
(l, Matrix(y)) => Ok(Matrix(y.scale(l)?.into())),
|
||||
(l, r) => Err(Error::Multiply(l, r).into())
|
||||
}
|
||||
}
|
||||
|
@ -485,7 +502,7 @@ impl PartialEq for Value {
|
|||
(Complex(a), Float(b)) => *a == Complex64::from(*b),
|
||||
(String(a), String(b)) => *a == *b,
|
||||
(List(a), List(b)) => *a == *b,
|
||||
(Matrix(a), Matrix(b)) => a.0 == b.0 && *a.2 == *b.2,
|
||||
(Matrix(a), Matrix(b)) => a == b,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
@ -508,7 +525,7 @@ impl PartialOrd for Value {
|
|||
(Float(a), Ratio(b)) => a.partial_cmp(&ratio_to_f64(*b)),
|
||||
(String(a), String(b)) => a.partial_cmp(b),
|
||||
(List(a), List(b)) => a.partial_cmp(b),
|
||||
(Matrix(a), Matrix(b)) => a.2.partial_cmp(&b.2),
|
||||
(Matrix(a), Matrix(b)) => a.domain.partial_cmp(&b.domain),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
@ -528,10 +545,16 @@ impl Value {
|
|||
},
|
||||
(V::List(l), V::Int(i)) => {
|
||||
if *i < 0 || *i as usize >= l.len() {
|
||||
return Err(Error::IndexOutOfBounds(*i as usize, l.len()).into())
|
||||
return Err(Error::IndexOutOfBounds(*i, l.len()).into())
|
||||
}
|
||||
Ok(l[*i as usize].clone())
|
||||
},
|
||||
(V::Matrix(m), V::Int(i)) => {
|
||||
if *i < 0 || *i as usize >= m.values.len() {
|
||||
return Err(Error::MatrixIndexOutOfBounds(*i, m.domain, m.codomain).into())
|
||||
}
|
||||
Ok(m.values[*i as usize].clone())
|
||||
},
|
||||
_ => return Err(self::Error::BadIndex(self.clone(), index.clone()).into())
|
||||
}
|
||||
}
|
||||
|
@ -555,6 +578,40 @@ impl Value {
|
|||
}
|
||||
Ok(V::Table(ret.into()))
|
||||
}
|
||||
V::Matrix(m) => {
|
||||
let err = || Err(self::Error::BadIndex(self.clone(), V::List(indexes.clone().into())).into());
|
||||
if indexes.len() != 2 {
|
||||
return err()
|
||||
}
|
||||
let lhs = indexes[0].clone();
|
||||
let rhs = indexes[1].clone();
|
||||
match (lhs, rhs) {
|
||||
(V::Nil, V::Nil) => {
|
||||
Ok(V::Matrix(m.clone_inside()))
|
||||
},
|
||||
(V::Int(row), V::Nil) => {
|
||||
let Some((_, row)) = m.rows().enumerate().filter(|(idx, _)| *idx as i64 == row).next() else {
|
||||
return err();
|
||||
};
|
||||
let row: Vec<Value> = row.into_iter().map(|e| e.clone()).collect();
|
||||
Ok(V::Matrix(Matrix::new(row.len(), 1, row).into()))
|
||||
},
|
||||
(V::Nil, V::Int(col)) => {
|
||||
let Some((_, col)) = m.cols().enumerate().filter(|(idx, _)| *idx as i64 == col).next() else {
|
||||
return err();
|
||||
};
|
||||
let col: Vec<Value> = col.into_iter().map(|e| e.clone()).collect();
|
||||
Ok(V::Matrix(Matrix::new(1, col.len(), col).into()))
|
||||
},
|
||||
(V::Int(row), V::Int(col)) => {
|
||||
if row < 0 || col < 0 {
|
||||
return err();
|
||||
}
|
||||
m.get(row as usize, col as usize)
|
||||
}
|
||||
_ => return err()
|
||||
}
|
||||
}
|
||||
_ => return Err(self::Error::CannotIndex(self.clone()).into())
|
||||
}
|
||||
}
|
||||
|
@ -578,11 +635,18 @@ impl Value {
|
|||
},
|
||||
(V::List(l), V::Int(i)) => {
|
||||
if *i < 0 || *i as usize >= l.len() {
|
||||
return Err(Error::IndexOutOfBounds(*i as usize, l.len()).into())
|
||||
return Err(Error::IndexOutOfBounds(*i, l.len()).into())
|
||||
}
|
||||
l[*i as usize] = store;
|
||||
Ok(())
|
||||
},
|
||||
(V::Matrix(m), V::Int(i)) => {
|
||||
if *i < 0 || *i as usize >= m.values.len() {
|
||||
return Err(Error::MatrixIndexOutOfBounds(*i, m.domain, m.codomain).into())
|
||||
}
|
||||
m.values[*i as usize] = store;
|
||||
Ok(())
|
||||
},
|
||||
_ => return Err(self::Error::BadIndex(err, index.clone()).into())
|
||||
}
|
||||
}
|
||||
|
@ -708,3 +772,284 @@ impl Not for Value {
|
|||
}
|
||||
}
|
||||
|
||||
fn dot(lhs: Vec<&Value>, rhs: Vec<&Value>) -> Result<Value> {
|
||||
let len = lhs.len();
|
||||
|
||||
let mut res = Value::Int(0);
|
||||
for i in 0..len {
|
||||
let val = (lhs[i].clone() * rhs[i].clone())?;
|
||||
res = (res + val)?;
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Matrix {
|
||||
pub domain: usize,
|
||||
pub codomain: usize,
|
||||
pub values: Vec<Value>
|
||||
}
|
||||
|
||||
impl Matrix {
|
||||
pub fn new(
|
||||
domain: usize,
|
||||
codomain: usize,
|
||||
values: Vec<Value>
|
||||
) -> Self {
|
||||
Self {
|
||||
domain,
|
||||
codomain,
|
||||
values
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_list(
|
||||
values: Vec<Value>
|
||||
) -> Self {
|
||||
Self {
|
||||
domain: values.len(),
|
||||
codomain: 1,
|
||||
values
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty(
|
||||
domain: usize,
|
||||
codomain: usize
|
||||
) -> Self {
|
||||
let values = (0..(domain * codomain)).into_iter()
|
||||
.map(|_| Value::Int(0))
|
||||
.collect();
|
||||
Self {
|
||||
domain,
|
||||
codomain,
|
||||
values
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, row: usize, col: usize) -> Result<Value> {
|
||||
if row >= self.codomain || col >= self.domain {
|
||||
return Err(self::Error::MatrixOutOfBounds(row, col, self.domain, self.codomain).into())
|
||||
}
|
||||
let idx = col + row * self.codomain;
|
||||
Ok(self.values[idx].clone())
|
||||
}
|
||||
|
||||
|
||||
|
||||
pub fn set(&mut self, row: usize, col: usize, val: Value) -> Result<()> {
|
||||
if row >= self.codomain || col >= self.domain {
|
||||
return Err(self::Error::MatrixOutOfBounds(row, col, self.domain, self.codomain).into())
|
||||
}
|
||||
let idx = col + row * self.codomain;
|
||||
self.values[idx] = val;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn rows<'a>(&'a self) -> MatrixRows<'a> {
|
||||
MatrixRows {
|
||||
matrix: self,
|
||||
row: 0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cols<'a>(&'a self) -> MatrixCols<'a> {
|
||||
MatrixCols {
|
||||
matrix: self,
|
||||
col: 0
|
||||
}
|
||||
}
|
||||
|
||||
// SPLCIE DOMAIN
|
||||
pub fn splice_cols(&self, col_start: usize, col_end: usize) -> Result<Self> {
|
||||
if col_start <= col_end {
|
||||
return Err(self::Error::MatrixInvSlice(col_start, col_end).into());
|
||||
}
|
||||
|
||||
let mut cols = Vec::new();
|
||||
|
||||
for (i, col) in self.cols().enumerate() {
|
||||
if i >= col_start && i < col_end {
|
||||
cols.push(col);
|
||||
}
|
||||
}
|
||||
|
||||
let domain = cols.len();
|
||||
let codomain = cols[0].len();
|
||||
let mut res = Self::empty(domain, codomain);
|
||||
|
||||
for i in 0..domain {
|
||||
for j in 0..codomain {
|
||||
res.set(j, i, cols[i][j].clone())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
// SPLICE CODOMAIN
|
||||
pub fn splice_rows(&self, row_start: usize, row_end: usize) -> Result<Self> {
|
||||
if row_start <= row_end {
|
||||
return Err(self::Error::MatrixInvSlice(row_start, row_end).into());
|
||||
}
|
||||
|
||||
let mut rows = Vec::new();
|
||||
|
||||
for (i, row) in self.rows().enumerate() {
|
||||
if i >= row_start && i < row_end {
|
||||
rows.push(row);
|
||||
}
|
||||
}
|
||||
|
||||
let domain = rows[0].len();
|
||||
let codomain = rows.len();
|
||||
let mut res = Self::empty(domain, codomain);
|
||||
|
||||
for i in 0..domain {
|
||||
for j in 0..codomain {
|
||||
res.set(j, i, rows[j][i].clone())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub fn scale(&self, scale: Value) -> Result<Self> {
|
||||
let values = self.values.iter()
|
||||
.map(|v| v.clone() * scale.clone())
|
||||
.collect::<Result<Vec<Value>>>()?;
|
||||
Ok(Matrix::new(self.domain, self.codomain, values))
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Matrix {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[Matrix {}x{}]", self.domain, self.codomain)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Matrix {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.domain != other.domain || self.codomain != other.codomain {
|
||||
return false
|
||||
}
|
||||
|
||||
for i in 0..self.values.len() {
|
||||
if self.values[i] != other.values[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for Matrix {
|
||||
type Output = Result<Self>;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
if self.domain != rhs.domain || self.codomain != rhs.codomain {
|
||||
return Err(self::Error::MatrixAdd(self.domain, self.codomain, rhs.domain, rhs.codomain).into())
|
||||
}
|
||||
let mut res = Matrix::empty(self.domain, self.codomain);
|
||||
for col in 0..self.domain {
|
||||
for row in 0..self.codomain {
|
||||
let add = self.get(row, col)? + rhs.get(row, col)?;
|
||||
res.set(row, col, add?)?;
|
||||
}
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for Matrix {
|
||||
type Output = Result<Self>;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
if self.domain != rhs.domain || self.codomain != rhs.codomain {
|
||||
return Err(self::Error::MatrixAdd(self.domain, self.codomain, rhs.domain, rhs.codomain).into())
|
||||
}
|
||||
let mut res = Matrix::empty(self.domain, self.codomain);
|
||||
for col in 0..self.domain {
|
||||
for row in 0..self.codomain {
|
||||
let sub = self.get(row, col)? - rhs.get(row, col)?;
|
||||
res.set(row, col, sub?)?;
|
||||
}
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for Matrix {
|
||||
type Output = Result<Self>;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
if self.domain != rhs.codomain {
|
||||
return Err(self::Error::MatrixMultiply(self.domain, rhs.codomain).into())
|
||||
}
|
||||
let mut res = Self::empty(rhs.domain, self.codomain);
|
||||
for (i, row) in self.rows().enumerate() {
|
||||
for (j, col) in rhs.cols().enumerate() {
|
||||
let dot = dot(row.clone(), col.clone())?;
|
||||
res.set(i, j, dot)?;
|
||||
}
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MatrixRows<'a> {
|
||||
matrix: &'a Matrix,
|
||||
row: usize
|
||||
}
|
||||
|
||||
impl<'a> Iterator for MatrixRows<'a> {
|
||||
type Item = Vec<&'a Value>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.row >= self.matrix.codomain {
|
||||
return None
|
||||
}
|
||||
|
||||
let row_start = self.row * self.matrix.domain;
|
||||
let row_end = row_start + self.matrix.domain;
|
||||
|
||||
let res = self.matrix.values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| *idx >= row_start && *idx < row_end)
|
||||
.map(|v| v.1)
|
||||
.collect();
|
||||
|
||||
self.row += 1;
|
||||
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MatrixCols<'a> {
|
||||
matrix: &'a Matrix,
|
||||
col: usize
|
||||
}
|
||||
|
||||
impl<'a> Iterator for MatrixCols<'a> {
|
||||
type Item = Vec<&'a Value>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.col >= self.matrix.domain {
|
||||
return None
|
||||
}
|
||||
|
||||
let res = self.matrix.values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| *idx % self.matrix.domain == self.col)
|
||||
.map(|v| v.1)
|
||||
.collect();
|
||||
|
||||
self.col += 1;
|
||||
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::{rc::Rc, fmt::{Debug, Display}, usize, ops::{Index, IndexMut}, collections::HashMap, cell::RefCell, sync::{atomic::{AtomicUsize, Ordering}, Arc}};
|
||||
use crate::{value::{Value, self, ValueMap}, gc::Gc, chunk::{Function, Instruction, Chunk, InnerFunction}, Result, compiler::NamesTable};
|
||||
use crate::{value::{Value, self, ValueMap, Matrix}, gc::Gc, chunk::{Function, Instruction, Chunk, InnerFunction}, Result, compiler::NamesTable};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
|
@ -233,11 +233,11 @@ impl Vm {
|
|||
}
|
||||
self.push(Value::Table(table.into()))
|
||||
},
|
||||
NewMatrix(items, codomain) => {
|
||||
let list = self.stack.split_off(self.stack.len() - items as usize);
|
||||
let codomain = codomain as usize;
|
||||
let domain = list.len() / codomain;
|
||||
self.push(Value::Matrix(Gc::new((domain, codomain, list.inner))));
|
||||
NewMatrix(items, domain) => {
|
||||
let values = self.stack.split_off(self.stack.len() - items as usize).inner;
|
||||
let domain = domain as usize;
|
||||
let codomain = values.len() / domain;
|
||||
self.push(Value::Matrix(Gc::new(Matrix::new(domain, codomain, values))));
|
||||
}
|
||||
Jump(idx) => {
|
||||
if self.check_interupt() {
|
||||
|
|
Loading…
Reference in a new issue