diff options
Diffstat (limited to 'matrix-stdlib/src/math.rs')
-rw-r--r-- | matrix-stdlib/src/math.rs | 181 |
1 files changed, 176 insertions, 5 deletions
diff --git a/matrix-stdlib/src/math.rs b/matrix-stdlib/src/math.rs index 3226af5..3f33951 100644 --- a/matrix-stdlib/src/math.rs +++ b/matrix-stdlib/src/math.rs @@ -100,7 +100,7 @@ fn mat_get_non_zero_pivot_row( Ok(()) } -fn mat_rref(mat: Matrix) -> Result<Matrix> { +fn mat_rref(mat: Matrix, full_rref: bool) -> Result<Matrix> { let mut mat = mat; let Some(start) = mat_find_non_zero_col(&mat) else { return Ok(mat) @@ -112,7 +112,8 @@ fn mat_rref(mat: Matrix) -> Result<Matrix> { if mat.get(pivot_row, col)?.is_zero() { break } - for row in 0..mat.codomain { + let min = if full_rref { 0 } else { col }; + for row in min..mat.codomain { if row == pivot_row { continue; }; let scale = mat.get(row, col)?; mat_gauss_row_operation(row, pivot_row, scale, &mut mat)?; @@ -127,9 +128,20 @@ fn rref(_: VmArgs, args: Vec<Value>) -> Result<Value> { let mat = match value { Value::Matrix(m) => m, Value::List(l) => Matrix::from_list(l.to_vec()).into(), - _ => return error!("trans must be given a matrix") + _ => return error!("rref must be given a matrix") }; - Ok(Value::Matrix(mat_rref(mat.into_inner())?.into())) + Ok(Value::Matrix(mat_rref(mat.into_inner(), true)?.into())) +} + +#[native_func(1)] +fn mat_ref(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + let mat = match value { + Value::Matrix(m) => m, + Value::List(l) => Matrix::from_list(l.to_vec()).into(), + _ => return error!("ref must be given a matrix") + }; + Ok(Value::Matrix(mat_rref(mat.into_inner(), false)?.into())) } fn mat_det(mat: Matrix) -> Result<Value> { @@ -236,7 +248,7 @@ fn inv(_: VmArgs, args: Vec<Value>) -> Result<Value> { let mat = mat.into_inner(); let ident = mat_ident(mat.domain); let joined = mat.join_right(&ident)?; - let refed = mat_rref(joined)?; + let refed = mat_rref(joined, true)?; let (new_ident, new_inv) = mat_splith(refed); if new_ident == ident { @@ -394,6 +406,17 @@ fn complex(_: VmArgs, args: Vec<Value>) -> Result<Value> { } #[native_func(1)] +fn mat(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::List(l) => Ok(V::Matrix(Matrix::from_list(l.to_vec()).into())), + V::Matrix(m) => Ok(V::Matrix(m)), + arg => error!("cannot cast {arg} to mat") + } +} + +#[native_func(1)] fn numer(_: VmArgs, args: Vec<Value>) -> Result<Value> { use Value as V; let [value] = unpack_args!(args); @@ -437,6 +460,140 @@ fn im(_: VmArgs, args: Vec<Value>) -> Result<Value> { } } +#[native_func(1)] +fn cis(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [value] = unpack_args!(args); + match value.promote_trig() { + Value::Float(f) => Ok(Value::Complex(Complex64::cis(f))), + Value::Complex(c) => Ok((Value::Complex(Complex64::cis(c.re)) * Value::Float((-c.im).exp()))?), + _ => error!("cis can only take floats") + } +} + +#[native_func(1)] +fn is_finite(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) => Ok(V::Bool(true)), + V::Float(f) => Ok(V::Bool(f.is_finite())), + V::Complex(c) => Ok(V::Bool(c.is_finite())), + _ => error!("is_finite can only take a valid number") + } +} + +#[native_func(1)] +fn is_infinite(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) => Ok(V::Bool(false)), + V::Float(f) => Ok(V::Bool(f.is_infinite())), + V::Complex(c) => Ok(V::Bool(c.is_infinite())), + _ => error!("is_infinite can only take a valid number") + } +} + +#[native_func(1)] +fn is_nan(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) => Ok(V::Bool(false)), + V::Float(f) => Ok(V::Bool(f.is_nan())), + V::Complex(c) => Ok(V::Bool(c.is_nan())), + _ => error!("is_nan can only take a valid number") + } +} + +#[native_func(1)] +fn is_normal(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) => Ok(V::Bool(true)), + V::Float(f) => Ok(V::Bool(f.is_normal())), + V::Complex(c) => Ok(V::Bool(c.is_normal())), + _ => error!("is_normal can only take a valid number") + } +} + +#[native_func(1)] +fn is_subnormal(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(_) | V::Ratio(_) => Ok(V::Bool(false)), + V::Float(f) => Ok(V::Bool(f.is_subnormal())), + _ => error!("is_subnormal can only take subnormal") + } +} + +#[native_func(1)] +fn is_sign_positive(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Bool(i > 0)), + V::Ratio(r) => Ok(V::Bool(*r.numer() > 0)), + V::Float(f) => Ok(V::Bool(f.is_sign_positive())), + _ => error!("is_sign_positive can only take a real number") + } +} + +#[native_func(1)] +fn is_sign_negative(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Bool(i < 0)), + V::Ratio(r) => Ok(V::Bool(*r.numer() < 0)), + V::Float(f) => Ok(V::Bool(f.is_sign_negative())), + _ => error!("is_sign_negative can only take a real number") + } +} + +#[native_func(1)] +fn is_zero(_: VmArgs, args: Vec<Value>) -> Result<Value> { + use Value as V; + let [value] = unpack_args!(args); + match value { + V::Int(i) => Ok(V::Bool(i == 0)), + V::Ratio(r) => Ok(V::Bool(*r.numer() == 0 && *r.denom() != 0)), + V::Float(f) => Ok(V::Bool(f == 0.0)), + V::Complex(c) => Ok(V::Bool(c.re == 0.0 && c.im == 0.0)), + _ => error!("is_zero can only take a valid number") + } +} + +#[native_func(2)] +fn mat_joinh(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [l, r] = unpack_args!(args); + let (l, r) = match (l, r) { + (Value::List(l), Value::List(r)) => (Matrix::from_list(l.to_vec()), Matrix::from_list(r.to_vec())), + (Value::List(l), Value::Matrix(r)) => (Matrix::from_list(l.to_vec()), r.into_inner()), + (Value::Matrix(l), Value::List(r)) => (l.into_inner(), Matrix::from_list(r.to_vec())), + (Value::Matrix(l), Value::Matrix(r)) => (l.into_inner(), r.into_inner()), + _ => return error!("mat_joinh takes two matrices") + }; + let mat = l.join_right(&r)?; + Ok(Value::Matrix(mat.into())) +} + +#[native_func(2)] +fn mat_joinv(_: VmArgs, args: Vec<Value>) -> Result<Value> { + let [l, r] = unpack_args!(args); + let (l, r) = match (l, r) { + (Value::List(l), Value::List(r)) => (Matrix::from_list(l.to_vec()), Matrix::from_list(r.to_vec())), + (Value::List(l), Value::Matrix(r)) => (Matrix::from_list(l.to_vec()), r.into_inner()), + (Value::Matrix(l), Value::List(r)) => (l.into_inner(), Matrix::from_list(r.to_vec())), + (Value::Matrix(l), Value::Matrix(r)) => (l.into_inner(), r.into_inner()), + _ => return error!("mat_joinv takes two matrices") + }; + let mat = l.join_bottom(&r)?; + Ok(Value::Matrix(mat.into())) +} + mathr!(floor); mathr!(ceil); mathr!(round); @@ -465,10 +622,13 @@ trigf!(to_radians, rad); pub fn load(vm: &mut Vm) { vm.load_global_fn(trans(), "trans"); + vm.load_global_fn(mat_ref(), "ref"); vm.load_global_fn(rref(), "rref"); vm.load_global_fn(det(), "det"); vm.load_global_fn(ident(), "ident"); vm.load_global_fn(inv(), "inv"); + vm.load_global_fn(mat_joinh(), "mat_joinh"); + vm.load_global_fn(mat_joinv(), "mat_joinv"); vm.load_global(Value::Float(PI), "pi"); vm.load_global(Value::Float(TAU), "tau"); @@ -481,6 +641,7 @@ pub fn load(vm: &mut Vm) { vm.load_global_fn(ratio(), "ratio"); vm.load_global_fn(float(), "float"); vm.load_global_fn(complex(), "complex"); + vm.load_global_fn(mat(), "mat"); vm.load_global_fn(abs(), "abs"); vm.load_global_fn(sign(), "sign"); vm.load_global_fn(floor(), "floor"); @@ -510,9 +671,19 @@ pub fn load(vm: &mut Vm) { vm.load_global_fn(atanh(), "atanh"); vm.load_global_fn(deg(), "deg"); vm.load_global_fn(rad(), "rad"); + vm.load_global_fn(cis(), "cis"); vm.load_global_fn(denom(), "denom"); vm.load_global_fn(numer(), "numer"); vm.load_global_fn(re(), "re"); vm.load_global_fn(im(), "im"); + + vm.load_global_fn(is_finite(), "is_finite"); + vm.load_global_fn(is_infinite(), "is_infinite"); + vm.load_global_fn(is_nan(), "is_nan"); + vm.load_global_fn(is_zero(), "is_zero"); + vm.load_global_fn(is_normal(), "is_normal"); + vm.load_global_fn(is_subnormal(), "is_subnormal"); + vm.load_global_fn(is_sign_negative(), "is_sign_negative"); + vm.load_global_fn(is_sign_positive(), "is_sign_positive"); } |