use core::f64; use std::f64::{consts::{PI, E, TAU}, NAN, INFINITY}; use matrix_lang::prelude::*; use matrix_macros::native_func; use crate::{error, VmArgs, unpack_args}; #[native_func(1)] fn trans(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); let mat = match value { Value::Matrix(m) => m, Value::List(l) => Matrix::from_list(l.to_vec()).into(), _ => return error!("trans must be given a matrix") }; let values = mat .cols() .reduce(|mut a, b| {a.extend(b); a}) .unwrap() .into_iter() .map(|e| e.clone()) .collect(); Ok(Value::Matrix(Matrix::new(mat.codomain, mat.domain, values).into())) } fn mat_gauss_row_operation( r1: usize, r2: usize, scale: Value, mat: &mut Matrix ) -> Result<()> { for col in 0..mat.domain { let r1v = mat.get(r1, col)?; let r2v = mat.get(r2, col)?; let res = (r1v - (r2v * scale.clone())?)?; mat.set(r1, col, res)?; } Ok(()) } fn mat_swap_rows( r1: usize, r2: usize, mat: &mut Matrix ) -> Result<()> { let cols = mat.domain; for col in 0..cols { let a = mat.get(r1, col)?; let b = mat.get(r2, col)?; mat.set(r2, col, a)?; mat.set(r1, col, b)?; } Ok(()) } fn mat_find_non_zero_col( mat: &Matrix ) -> Option { for (i,col) in mat.cols().enumerate() { for val in col.iter() { if **val != Value::Int(0) { return Some(i) } } } return None } fn mat_scale_pivot_row( row: usize, mat: &mut Matrix ) -> Result<()> { let scale = mat.get(row, row)?; if scale.is_zero() { return Ok(()) } for col in 0..mat.domain { let res = (mat.get(row, col)?.clone() / scale.clone())?; mat.set(row, col, res)?; } Ok(()) } fn mat_get_non_zero_pivot_row( row: usize, mat: &mut Matrix, ) -> Result<()> { let col = row; let test = mat.get(row, col)?; if test.is_zero() { for r in row..mat.codomain { let cur = mat.get(r, col)?; if !cur.is_zero() { mat_swap_rows(row, r, mat)?; break; } } } mat_scale_pivot_row(row, mat)?; Ok(()) } fn mat_rref(mat: Matrix, full_rref: bool) -> Result { let mut mat = mat; let Some(start) = mat_find_non_zero_col(&mat) else { return Ok(mat) }; let end = mat.domain.min(mat.codomain); for col in start..end { let pivot_row = col; mat_get_non_zero_pivot_row(pivot_row, &mut mat)?; if mat.get(pivot_row, col)?.is_zero() { break } let min = if full_rref { 0 } else { col }; for row in min..mat.codomain { if row == pivot_row { continue; }; let scale = mat.get(row, col)?; mat_gauss_row_operation(row, pivot_row, scale, &mut mat)?; } } Ok(mat) } #[native_func(1)] fn rref(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); let mat = match value { Value::Matrix(m) => m, Value::List(l) => Matrix::from_list(l.to_vec()).into(), _ => return error!("rref must be given a matrix") }; Ok(Value::Matrix(mat_rref(mat.into_inner(), true)?.into())) } #[native_func(1)] fn mat_ref(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); let mat = match value { Value::Matrix(m) => m, Value::List(l) => Matrix::from_list(l.to_vec()).into(), _ => return error!("ref must be given a matrix") }; Ok(Value::Matrix(mat_rref(mat.into_inner(), false)?.into())) } fn mat_det(mat: Matrix) -> Result { if mat.domain == 1 { return Ok(mat.get(0,0)?) } if mat.domain == 2 { let a = mat.get(0,0)? * mat.get(1,1)?; let b = mat.get(0,1)? * mat.get(1,0)?; return Ok((a? - b?)?) } let mut res = Value::Int(0); for col in 0..mat.domain { let sub_values = mat.rows() .skip(1) .map(|r| r.into_iter() .enumerate() .filter(|(idx,_)| *idx != col) .map(|(_, v)| v.clone()) .collect::>() ) .reduce(|mut a, b| {a.extend(b); a}) .unwrap(); let sub = Matrix::new(mat.domain - 1, mat.domain - 1, sub_values); let val = mat.get(0, col)?; let part = (val * mat_det(sub)?)?; if col % 2 == 0 { res = (res + part)?; } else { res = (res - part)?; } } Ok(res) } #[native_func(1)] fn det(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); let mat = match value { Value::Matrix(m) if m.domain == m.codomain => m, Value::List(l) if l.len() == 1 => Matrix::from_list(l.to_vec()).into(), _ => return error!("det requires a square matrix") }; let mat = mat.into_inner(); Ok(mat_det(mat)?) } fn mat_ident(dim: usize) -> Matrix { let len = dim * dim; let mut values = vec![Value::Int(0); len]; let mut idx = 0; loop { if idx >= len { break }; values[idx] = Value::Int(1); idx += dim + 1; } Matrix::new(dim, dim, values) } #[native_func(1)] fn ident(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); let dim = match value { Value::Int(i) if i > 0 => i, Value::Ratio(r) if *r.denom() == 1 && *r.numer() > 0 => *r.numer(), _ => return error!("ident requries a positive [Int] dimension") }; Ok(Value::Matrix(mat_ident(dim as usize).into())) } fn mat_splith(mat: Matrix) -> (Matrix, Matrix) { let mut m1 = Vec::new(); let mut m2 = Vec::new(); mat.rows() .for_each(|r| { let split = r.len() / 2; r.into_iter().enumerate().for_each(|(i, v)| { if i < split { m1.push(v.clone()); } else { m2.push(v.clone()); } }) }); let m1 = Matrix::new(mat.domain/2, mat.codomain, m1); let m2 = Matrix::new(mat.domain/2, mat.codomain, m2); (m1, m2) } #[native_func(1)] fn inv(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); let mat = match value { Value::Matrix(m) if m.domain == m.codomain => m, Value::List(l) if l.len() == 1 => Matrix::from_list(l.to_vec()).into(), _ => return error!("det requires a square matrix") }; let mat = mat.into_inner(); let ident = mat_ident(mat.domain); let joined = mat.join_right(&ident)?; let refed = mat_rref(joined, true)?; let (new_ident, new_inv) = mat_splith(refed); if new_ident == ident { Ok(Value::Matrix(new_inv.into())) } else { error!("matrix does not have an inverse") } } macro_rules! mathr { ($type:ident) => { #[native_func(1)] fn $type(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Int(i)), V::Ratio(r) => Ok(V::Ratio(r.$type())), V::Float(f) => Ok(V::Float(f.$type())), v => error!("cannot compute {} on {v}", stringify!($type)) } } }; } macro_rules! trig { ($type:ident) => { #[native_func(1)] fn $type(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value.floaty() { V::Float(f) => Ok(V::Float(f.$type())), V::Complex(c) => Ok(V::Complex(c.$type())), v => error!("cannot compute {} on {v}", stringify!($type)) } } }; } macro_rules! trigf { ($type:ident, $str:ident) => { #[native_func(1)] fn $str(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value.floaty() { V::Float(f) => Ok(V::Float(f.$type())), v => error!("cannot compute {} on {v}", stringify!($str)) } } }; } #[native_func(2)] fn log(_: VmArgs, args: Vec) -> Result { use Value as V; let [base, value] = unpack_args!(args); match (base.floaty(), value.floaty()) { (V::Float(base), V::Float(arg)) => Ok(V::Float(arg.log(base))), (V::Float(base), V::Complex(arg)) => Ok(V::Complex(arg.log(base))), (V::Complex(base), V::Float(arg)) => Ok(V::Complex(arg.ln() / base.ln())), (V::Complex(base), V::Complex(arg)) => Ok(V::Complex(arg.ln() / base.ln())), (base, arg) => error!("cannot compute log base {base} argument {arg}") } } #[native_func(1)] fn abs(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Int(i.abs())), V::Float(f) => Ok(V::Float(f.abs())), V::Ratio(r) => Ok(V::Ratio(Rational64::new(r.numer().abs(), r.denom().abs()))), V::Complex(c) => Ok(V::Float(c.norm())), arg => error!("cannot compute abs for {arg}") } } #[native_func(1)] fn fract(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) => Ok(V::Int(0)), V::Float(f) => Ok(V::Float(f.fract())), V::Ratio(r) => Ok(V::Ratio(r.fract())), arg => error!("cannot compute fract for {arg}") } } #[native_func(1)] fn sign(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Int(i.signum())), V::Ratio(r) => Ok(V::Int(r.numer().signum())), V::Float(f) => Ok(V::Float(f.signum())), arg => error!("cannot compute sign for {arg}") } } #[native_func(1)] fn int(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Int(i)), V::Ratio(r) => Ok(V::Int(r.numer() / r.denom())), V::Float(f) => Ok(V::Int(f as i64)), V::Complex(c) => Ok(V::Int(c.re as i64)), arg => error!("cannot cast {arg} to int") } } #[native_func(1)] fn ratio(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Ratio(Rational64::new(i, 1))), V::Ratio(r) => Ok(V::Ratio(r)), V::Float(f) => Ok(V::Ratio(Rational64::approximate_float(f).unwrap_or(Rational64::new(0, 1)))), V::Complex(c) => Ok(V::Ratio(Rational64::approximate_float(c.re).unwrap_or(Rational64::new(0, 1)))), arg => error!("cannot cast {arg} to ratio") } } #[native_func(1)] fn float(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Float(i as f64)), V::Ratio(r) => Ok(V::Float((*r.numer() as f64) / (*r.denom() as f64))), V::Float(f) => Ok(V::Float(f)), V::Complex(c) => Ok(V::Float(c.re)), arg => error!("cannot cast {arg} to float") } } #[native_func(1)] fn complex(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Complex(Complex64::new(i as f64, 0.0))), V::Ratio(r) => Ok(V::Complex(Complex64::new((*r.numer() as f64) / (*r.denom() as f64), 0.0))), V::Float(f) => Ok(V::Complex(Complex64::new(f, 0.0))), V::Complex(c) => Ok(V::Complex(c)), arg => error!("cannot cast {arg} to float") } } #[native_func(1)] fn mat(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::List(l) => Ok(V::Matrix(Matrix::from_list(l.to_vec()).into())), V::Matrix(m) => Ok(V::Matrix(m)), arg => error!("cannot cast {arg} to mat") } } #[native_func(1)] fn numer(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Int(i)), V::Ratio(r) => Ok(V::Int(*r.numer())), _ => error!("numer can only take a integer or ratio") } } #[native_func(1)] fn denom(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) => Ok(V::Int(1)), V::Ratio(r) => Ok(V::Int(*r.denom())), _ => error!("denom can only take a integer or ratio") } } #[native_func(1)] fn re(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) | V::Float(_) => Ok(value), V::Complex(c) => Ok(V::Float(c.re)), _ => error!("re can only take a valid number") } } #[native_func(1)] fn im(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) | V::Float(_ )=> Ok(V::Int(0)), V::Complex(c) => Ok(V::Float(c.im)), _ => error!("re can only take a valid number") } } #[native_func(1)] fn cis(_: VmArgs, args: Vec) -> Result { let [value] = unpack_args!(args); match value.floaty() { Value::Float(f) => Ok(Value::Complex(Complex64::cis(f))), Value::Complex(c) => Ok((Value::Complex(Complex64::cis(c.re)) * Value::Float((-c.im).exp()))?), _ => error!("cis can only take floats") } } #[native_func(1)] fn is_finite(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) => Ok(V::Bool(true)), V::Float(f) => Ok(V::Bool(f.is_finite())), V::Complex(c) => Ok(V::Bool(c.is_finite())), _ => error!("is_finite can only take a valid number") } } #[native_func(1)] fn is_infinite(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) => Ok(V::Bool(false)), V::Float(f) => Ok(V::Bool(f.is_infinite())), V::Complex(c) => Ok(V::Bool(c.is_infinite())), _ => error!("is_infinite can only take a valid number") } } #[native_func(1)] fn is_nan(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) => Ok(V::Bool(false)), V::Float(f) => Ok(V::Bool(f.is_nan())), V::Complex(c) => Ok(V::Bool(c.is_nan())), _ => error!("is_nan can only take a valid number") } } #[native_func(1)] fn is_normal(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) => Ok(V::Bool(true)), V::Float(f) => Ok(V::Bool(f.is_normal())), V::Complex(c) => Ok(V::Bool(c.is_normal())), _ => error!("is_normal can only take a valid number") } } #[native_func(1)] fn is_subnormal(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(_) | V::Ratio(_) => Ok(V::Bool(false)), V::Float(f) => Ok(V::Bool(f.is_subnormal())), _ => error!("is_subnormal can only take subnormal") } } #[native_func(1)] fn is_sign_positive(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Bool(i > 0)), V::Ratio(r) => Ok(V::Bool(*r.numer() > 0)), V::Float(f) => Ok(V::Bool(f.is_sign_positive())), _ => error!("is_sign_positive can only take a real number") } } #[native_func(1)] fn is_sign_negative(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Bool(i < 0)), V::Ratio(r) => Ok(V::Bool(*r.numer() < 0)), V::Float(f) => Ok(V::Bool(f.is_sign_negative())), _ => error!("is_sign_negative can only take a real number") } } #[native_func(1)] fn is_zero(_: VmArgs, args: Vec) -> Result { use Value as V; let [value] = unpack_args!(args); match value { V::Int(i) => Ok(V::Bool(i == 0)), V::Ratio(r) => Ok(V::Bool(*r.numer() == 0 && *r.denom() != 0)), V::Float(f) => Ok(V::Bool(f == 0.0)), V::Complex(c) => Ok(V::Bool(c.re == 0.0 && c.im == 0.0)), _ => error!("is_zero can only take a valid number") } } #[native_func(2)] fn mat_joinh(_: VmArgs, args: Vec) -> Result { let [l, r] = unpack_args!(args); let (l, r) = match (l, r) { (Value::List(l), Value::List(r)) => (Matrix::from_list(l.to_vec()), Matrix::from_list(r.to_vec())), (Value::List(l), Value::Matrix(r)) => (Matrix::from_list(l.to_vec()), r.into_inner()), (Value::Matrix(l), Value::List(r)) => (l.into_inner(), Matrix::from_list(r.to_vec())), (Value::Matrix(l), Value::Matrix(r)) => (l.into_inner(), r.into_inner()), _ => return error!("mat_joinh takes two matrices") }; let mat = l.join_right(&r)?; Ok(Value::Matrix(mat.into())) } #[native_func(2)] fn mat_joinv(_: VmArgs, args: Vec) -> Result { let [l, r] = unpack_args!(args); let (l, r) = match (l, r) { (Value::List(l), Value::List(r)) => (Matrix::from_list(l.to_vec()), Matrix::from_list(r.to_vec())), (Value::List(l), Value::Matrix(r)) => (Matrix::from_list(l.to_vec()), r.into_inner()), (Value::Matrix(l), Value::List(r)) => (l.into_inner(), Matrix::from_list(r.to_vec())), (Value::Matrix(l), Value::Matrix(r)) => (l.into_inner(), r.into_inner()), _ => return error!("mat_joinv takes two matrices") }; let mat = l.join_bottom(&r)?; Ok(Value::Matrix(mat.into())) } mathr!(floor); mathr!(ceil); mathr!(round); mathr!(trunc); trig!(sqrt); trig!(cbrt); trig!(ln); trig!(log2); trig!(log10); trig!(exp); trig!(exp2); trig!(sin); trig!(cos); trig!(tan); trig!(sinh); trig!(cosh); trig!(tanh); trig!(asin); trig!(acos); trig!(atan); trig!(asinh); trig!(acosh); trig!(atanh); trigf!(to_degrees, deg); trigf!(to_radians, rad); pub fn load(vm: &mut Vm) { vm.load_global_fn(trans(), "trans"); vm.load_global_fn(mat_ref(), "ref"); vm.load_global_fn(rref(), "rref"); vm.load_global_fn(det(), "det"); vm.load_global_fn(ident(), "ident"); vm.load_global_fn(inv(), "inv"); vm.load_global_fn(mat_joinh(), "mat_joinh"); vm.load_global_fn(mat_joinv(), "mat_joinv"); vm.load_global(Value::Float(PI), "pi"); vm.load_global(Value::Float(TAU), "tau"); vm.load_global(Value::Float(E), "e"); vm.load_global(Value::Float(NAN), "nan"); vm.load_global(Value::Float(NAN), "NaN"); vm.load_global(Value::Float(INFINITY), "inf"); vm.load_global_fn(int(), "int"); vm.load_global_fn(ratio(), "ratio"); vm.load_global_fn(float(), "float"); vm.load_global_fn(complex(), "complex"); vm.load_global_fn(mat(), "mat"); vm.load_global_fn(abs(), "abs"); vm.load_global_fn(sign(), "sign"); vm.load_global_fn(floor(), "floor"); vm.load_global_fn(ceil(), "ceil"); vm.load_global_fn(round(), "round"); vm.load_global_fn(trunc(), "trunc"); vm.load_global_fn(fract(), "fract"); vm.load_global_fn(sqrt(), "sqrt"); vm.load_global_fn(cbrt(), "cbrt"); vm.load_global_fn(ln(), "ln"); vm.load_global_fn(log(), "log"); vm.load_global_fn(log2(), "log2"); vm.load_global_fn(log10(), "log10"); vm.load_global_fn(exp(), "exp"); vm.load_global_fn(exp2(), "exp2"); vm.load_global_fn(sin(), "sin"); vm.load_global_fn(cos(), "cos"); vm.load_global_fn(tan(), "tan"); vm.load_global_fn(sinh(), "sinh"); vm.load_global_fn(cosh(), "cosh"); vm.load_global_fn(tanh(), "tanh"); vm.load_global_fn(asin(), "asin"); vm.load_global_fn(acos(), "acos"); vm.load_global_fn(atan(), "atan"); vm.load_global_fn(asinh(), "asinh"); vm.load_global_fn(acosh(), "acosh"); vm.load_global_fn(atanh(), "atanh"); vm.load_global_fn(deg(), "deg"); vm.load_global_fn(rad(), "rad"); vm.load_global_fn(cis(), "cis"); vm.load_global_fn(denom(), "denom"); vm.load_global_fn(numer(), "numer"); vm.load_global_fn(re(), "re"); vm.load_global_fn(im(), "im"); vm.load_global_fn(is_finite(), "is_finite"); vm.load_global_fn(is_infinite(), "is_infinite"); vm.load_global_fn(is_nan(), "is_nan"); vm.load_global_fn(is_zero(), "is_zero"); vm.load_global_fn(is_normal(), "is_normal"); vm.load_global_fn(is_subnormal(), "is_subnormal"); vm.load_global_fn(is_sign_negative(), "is_sign_negative"); vm.load_global_fn(is_sign_positive(), "is_sign_positive"); }