diff options
Diffstat (limited to 'matrix-stdlib/src/math.rs')
-rw-r--r-- | matrix-stdlib/src/math.rs | 518 |
1 files changed, 518 insertions, 0 deletions
diff --git a/matrix-stdlib/src/math.rs b/matrix-stdlib/src/math.rs new file mode 100644 index 0000000..3226af5 --- /dev/null +++ b/matrix-stdlib/src/math.rs @@ -0,0 +1,518 @@ +use core::f64; +use std::f64::{consts::{PI, E, TAU}, NAN, INFINITY}; + +use matrix::{vm::Vm, value::{Value, Matrix}, Result, unpack_args, Rational64, Complex64}; +use matrix_macros::native_func; +use crate::{error, VmArgs}; + +#[native_func(1)] +fn trans(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + let mat = match value { + Value::Matrix(m) => m, + Value::List(l) => Matrix::from_list(l.to_vec()).into(), + _ => return error!("trans must be given a matrix") + }; + let values = mat + .cols() + .reduce(|mut a, b| {a.extend(b); a}) + .unwrap() + .into_iter() + .map(|e| e.clone()) + .collect(); + Ok(Value::Matrix(Matrix::new(mat.codomain, mat.domain, values).into())) +} + +fn mat_gauss_row_operation( + r1: usize, + r2: usize, + scale: Value, + mat: &mut Matrix +) -> Result<()> { + for col in 0..mat.domain { + let r1v = mat.get(r1, col)?; + let r2v = mat.get(r2, col)?; + let res = (r1v - (r2v * scale.clone())?)?; + mat.set(r1, col, res)?; + } + Ok(()) +} + +fn mat_swap_rows( + r1: usize, + r2: usize, + mat: &mut Matrix +) -> Result<()> { + let cols = mat.domain; + for col in 0..cols { + let a = mat.get(r1, col)?; + let b = mat.get(r2, col)?; + mat.set(r2, col, a)?; + mat.set(r1, col, b)?; + } + Ok(()) +} + +fn mat_find_non_zero_col( + mat: &Matrix +) -> Option<usize> { + for (i,col) in mat.cols().enumerate() { + for val in col.iter() { + if **val != Value::Int(0) { + return Some(i) + } + } + } + return None +} + +fn mat_scale_pivot_row( + row: usize, + mat: &mut Matrix +) -> Result<()> { + let scale = mat.get(row, row)?; + if scale.is_zero() { + return Ok(()) + } + for col in 0..mat.domain { + let res = (mat.get(row, col)?.clone() / scale.clone())?; + mat.set(row, col, res)?; + } + Ok(()) +} + +fn mat_get_non_zero_pivot_row( + row: usize, + mat: &mut Matrix, +) -> Result<()> { + let col = row; + let test = mat.get(row, col)?; + if test.is_zero() { + for r in row..mat.codomain { + let cur = mat.get(r, col)?; + if !cur.is_zero() { + mat_swap_rows(row, r, mat)?; + break; + } + } + } + mat_scale_pivot_row(row, mat)?; + Ok(()) +} + +fn mat_rref(mat: Matrix) -> Result<Matrix> { + let mut mat = mat; + let Some(start) = mat_find_non_zero_col(&mat) else { + return Ok(mat) + }; + let end = mat.domain.min(mat.codomain); + for col in start..end { + let pivot_row = col; + mat_get_non_zero_pivot_row(pivot_row, &mut mat)?; + if mat.get(pivot_row, col)?.is_zero() { + break + } + for row in 0..mat.codomain { + if row == pivot_row { continue; }; + let scale = mat.get(row, col)?; + mat_gauss_row_operation(row, pivot_row, scale, &mut mat)?; + } + } + Ok(mat) +} + +#[native_func(1)] +fn rref(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + let mat = match value { + Value::Matrix(m) => m, + Value::List(l) => Matrix::from_list(l.to_vec()).into(), + _ => return error!("trans must be given a matrix") + }; + Ok(Value::Matrix(mat_rref(mat.into_inner())?.into())) +} + +fn mat_det(mat: Matrix) -> Result<Value> { + if mat.domain == 1 { + return Ok(mat.get(0,0)?) + } + if mat.domain == 2 { + let a = mat.get(0,0)? * mat.get(1,1)?; + let b = mat.get(0,1)? * mat.get(1,0)?; + return Ok((a? - b?)?) + } + let mut res = Value::Int(0); + for col in 0..mat.domain { + let sub_values = mat.rows() + .skip(1) + .map(|r| + r.into_iter() + .enumerate() + .filter(|(idx,_)| *idx != col) + .map(|(_, v)| v.clone()) + .collect::<Vec<Value>>() + ) + .reduce(|mut a, b| {a.extend(b); a}) + .unwrap(); + let sub = Matrix::new(mat.domain - 1, mat.domain - 1, sub_values); + let val = mat.get(0, col)?; + let part = (val * mat_det(sub)?)?; + if col % 2 == 0 { + res = (res + part)?; + } else { + res = (res - part)?; + } + } + Ok(res) +} + +#[native_func(1)] +fn det(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + let mat = match value { + Value::Matrix(m) if m.domain == m.codomain => m, + Value::List(l) if l.len() == 1 => Matrix::from_list(l.to_vec()).into(), + _ => return error!("det requires a square matrix") + }; + let mat = mat.into_inner(); + Ok(mat_det(mat)?) +} + +fn mat_ident(dim: usize) -> Matrix { + let len = dim * dim; + let mut values = vec![Value::Int(0); len]; + let mut idx = 0; + loop { + if idx >= len { break }; + values[idx] = Value::Int(1); + idx += dim + 1; + } + Matrix::new(dim, dim, values) +} + +#[native_func(1)] +fn ident(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + let dim = match value { + Value::Int(i) if i > 0 => i, + Value::Ratio(r) + if *r.denom() == 1 && + *r.numer() > 0 + => *r.numer(), + _ => return error!("ident requries a positive [Int] dimension") + }; + Ok(Value::Matrix(mat_ident(dim as usize).into())) +} + +fn mat_splith(mat: Matrix) -> (Matrix, Matrix) { + let mut m1 = Vec::new(); + let mut m2 = Vec::new(); + + mat.rows() + .for_each(|r| { + let split = r.len() / 2; + r.into_iter().enumerate().for_each(|(i, v)| { + if i < split { + m1.push(v.clone()); + } else { + m2.push(v.clone()); + } + }) + }); + + let m1 = Matrix::new(mat.domain/2, mat.codomain, m1); + let m2 = Matrix::new(mat.domain/2, mat.codomain, m2); + (m1, m2) +} + +#[native_func(1)] +fn inv(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + let mat = match value { + Value::Matrix(m) if m.domain == m.codomain => m, + Value::List(l) if l.len() == 1 => Matrix::from_list(l.to_vec()).into(), + _ => return error!("det requires a square matrix") + }; + let mat = mat.into_inner(); + let ident = mat_ident(mat.domain); + let joined = mat.join_right(&ident)?; + let refed = mat_rref(joined)?; + let (new_ident, new_inv) = mat_splith(refed); + + if new_ident == ident { + Ok(Value::Matrix(new_inv.into())) + } else { + error!("matrix does not have an inverse") + } +} + +macro_rules! mathr { + ($type:ident) => { + #[native_func(1)] + fn $type(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Int(i)), + V::Ratio(r) => Ok(V::Ratio(r.$type())), + V::Float(f) => Ok(V::Float(f.$type())), + v => error!("cannot compute {} on {v}", stringify!($type)) + } + } + }; +} + +macro_rules! trig { + ($type:ident) => { + #[native_func(1)] + fn $type(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value.promote_trig() { + V::Float(f) => Ok(V::Float(f.$type())), + V::Complex(c) => Ok(V::Complex(c.$type())), + v => error!("cannot compute {} on {v}", stringify!($type)) + } + } + }; +} + +macro_rules! trigf { + ($type:ident, $str:ident) => { + #[native_func(1)] + fn $str(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value.promote_trig() { + V::Float(f) => Ok(V::Float(f.$type())), + v => error!("cannot compute {} on {v}", stringify!($str)) + } + } + }; +} + +#[native_func(2)] +fn log(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [base, value] = unpack_args!(args); + match (base.promote_trig(), value.promote_trig()) { + (V::Float(base), V::Float(arg)) => Ok(V::Float(arg.log(base))), + (V::Float(base), V::Complex(arg)) => Ok(V::Complex(arg.log(base))), + (V::Complex(base), V::Float(arg)) => Ok(V::Complex(arg.ln() / base.ln())), + (V::Complex(base), V::Complex(arg)) => Ok(V::Complex(arg.ln() / base.ln())), + (base, arg) => error!("cannot compute log base {base} argument {arg}") + } +} + +#[native_func(1)] +fn abs(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Int(i.abs())), + V::Float(f) => Ok(V::Float(f.abs())), + V::Ratio(r) => Ok(V::Ratio(Rational64::new(r.numer().abs(), r.denom().abs()))), + V::Complex(c) => Ok(V::Float(c.norm())), + arg => error!("cannot compute abs for {arg}") + } +} + +#[native_func(1)] +fn fract(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) => Ok(V::Int(0)), + V::Float(f) => Ok(V::Float(f.fract())), + V::Ratio(r) => Ok(V::Ratio(r.fract())), + arg => error!("cannot compute fract for {arg}") + } +} + +#[native_func(1)] +fn sign(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Int(i.signum())), + V::Ratio(r) => Ok(V::Int(r.numer().signum())), + V::Float(f) => Ok(V::Float(f.signum())), + arg => error!("cannot compute sign for {arg}") + } +} + +#[native_func(1)] +fn int(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Int(i)), + V::Ratio(r) => Ok(V::Int(r.numer() / r.denom())), + V::Float(f) => Ok(V::Int(f as i64)), + V::Complex(c) => Ok(V::Int(c.re as i64)), + arg => error!("cannot cast {arg} to int") + } +} + +#[native_func(1)] +fn ratio(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Ratio(Rational64::new(i, 1))), + V::Ratio(r) => Ok(V::Ratio(r)), + V::Float(f) => Ok(V::Ratio(Rational64::approximate_float(f).unwrap_or(Rational64::new(0, 1)))), + V::Complex(c) => Ok(V::Ratio(Rational64::approximate_float(c.re).unwrap_or(Rational64::new(0, 1)))), + arg => error!("cannot cast {arg} to ratio") + } +} + +#[native_func(1)] +fn float(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Float(i as f64)), + V::Ratio(r) => Ok(V::Float((*r.numer() as f64) / (*r.denom() as f64))), + V::Float(f) => Ok(V::Float(f)), + V::Complex(c) => Ok(V::Float(c.re)), + arg => error!("cannot cast {arg} to float") + } +} + +#[native_func(1)] +fn complex(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Complex(Complex64::new(i as f64, 0.0))), + V::Ratio(r) => Ok(V::Complex(Complex64::new((*r.numer() as f64) / (*r.denom() as f64), 0.0))), + V::Float(f) => Ok(V::Complex(Complex64::new(f, 0.0))), + V::Complex(c) => Ok(V::Complex(c)), + arg => error!("cannot cast {arg} to float") + } +} + +#[native_func(1)] +fn numer(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Int(i)), + V::Ratio(r) => Ok(V::Int(*r.numer())), + _ => error!("numer can only take a integer or ratio") + } +} + +#[native_func(1)] +fn denom(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) => Ok(V::Int(1)), + V::Ratio(r) => Ok(V::Int(*r.denom())), + _ => error!("denom can only take a integer or ratio") + } +} + +#[native_func(1)] +fn re(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) | V::Float(_) => Ok(value), + V::Complex(c) => Ok(V::Float(c.re)), + _ => error!("re can only take a valid number") + } +} + +#[native_func(1)] +fn im(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) | V::Float(_ )=> Ok(V::Int(0)), + V::Complex(c) => Ok(V::Float(c.im)), + _ => error!("re can only take a valid number") + } +} + +mathr!(floor); +mathr!(ceil); +mathr!(round); +mathr!(trunc); +trig!(sqrt); +trig!(cbrt); +trig!(ln); +trig!(log2); +trig!(log10); +trig!(exp); +trig!(exp2); +trig!(sin); +trig!(cos); +trig!(tan); +trig!(sinh); +trig!(cosh); +trig!(tanh); +trig!(asin); +trig!(acos); +trig!(atan); +trig!(asinh); +trig!(acosh); +trig!(atanh); +trigf!(to_degrees, deg); +trigf!(to_radians, rad); + +pub fn load(vm: &mut Vm) { + vm.load_global_fn(trans(), "trans"); + vm.load_global_fn(rref(), "rref"); + vm.load_global_fn(det(), "det"); + vm.load_global_fn(ident(), "ident"); + vm.load_global_fn(inv(), "inv"); + + vm.load_global(Value::Float(PI), "pi"); + vm.load_global(Value::Float(TAU), "tau"); + vm.load_global(Value::Float(E), "e"); + vm.load_global(Value::Float(NAN), "nan"); + vm.load_global(Value::Float(NAN), "NaN"); + vm.load_global(Value::Float(INFINITY), "inf"); + + vm.load_global_fn(int(), "int"); + vm.load_global_fn(ratio(), "ratio"); + vm.load_global_fn(float(), "float"); + vm.load_global_fn(complex(), "complex"); + vm.load_global_fn(abs(), "abs"); + vm.load_global_fn(sign(), "sign"); + vm.load_global_fn(floor(), "floor"); + vm.load_global_fn(ceil(), "ceil"); + vm.load_global_fn(round(), "round"); + vm.load_global_fn(trunc(), "trunc"); + vm.load_global_fn(fract(), "fract"); + vm.load_global_fn(sqrt(), "sqrt"); + vm.load_global_fn(cbrt(), "cbrt"); + vm.load_global_fn(ln(), "ln"); + vm.load_global_fn(log(), "log"); + vm.load_global_fn(log2(), "log2"); + vm.load_global_fn(log10(), "log10"); + vm.load_global_fn(exp(), "exp"); + vm.load_global_fn(exp2(), "exp2"); + vm.load_global_fn(sin(), "sin"); + vm.load_global_fn(cos(), "cos"); + vm.load_global_fn(tan(), "tan"); + vm.load_global_fn(sinh(), "sinh"); + vm.load_global_fn(cosh(), "cosh"); + vm.load_global_fn(tanh(), "tanh"); + vm.load_global_fn(asin(), "asin"); + vm.load_global_fn(acos(), "acos"); + vm.load_global_fn(atan(), "atan"); + vm.load_global_fn(asinh(), "asinh"); + vm.load_global_fn(acosh(), "acosh"); + vm.load_global_fn(atanh(), "atanh"); + vm.load_global_fn(deg(), "deg"); + vm.load_global_fn(rad(), "rad"); + + vm.load_global_fn(denom(), "denom"); + vm.load_global_fn(numer(), "numer"); + vm.load_global_fn(re(), "re"); + vm.load_global_fn(im(), "im"); +} |