diff --git a/crates/RustQuant_autodiff/Cargo.toml b/crates/RustQuant_autodiff/Cargo.toml index 7931bc6..ff4e5fe 100644 --- a/crates/RustQuant_autodiff/Cargo.toml +++ b/crates/RustQuant_autodiff/Cargo.toml @@ -20,6 +20,7 @@ RustQuant = { path = "../RustQuant" } ndarray = { workspace = true } errorfunctions = { workspace = true } RustQuant_utils = { workspace = true } +num-traits = "0.2.19" ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ## RUSTDOC CONFIGURATION diff --git a/crates/RustQuant_autodiff/src/accumulate.rs b/crates/RustQuant_autodiff/src/accumulate.rs index 2cf1a1c..2a4e208 100644 --- a/crates/RustQuant_autodiff/src/accumulate.rs +++ b/crates/RustQuant_autodiff/src/accumulate.rs @@ -17,6 +17,12 @@ //! - `DVector>` <- Currently not possible due to lifetimes //! - `Array, Ix2>` <- Work in progress +use std::ops::AddAssign; + +use num_traits::{One, Zero}; + +use crate::DiffOps; + use super::variable::Variable; /// Trait to reverse accumulate the gradient for different types. @@ -25,18 +31,18 @@ pub trait Accumulate { fn accumulate(&self) -> OUT; } -impl Accumulate> for Variable<'_> { +impl Accumulate> for Variable<'_, T> where T: Zero + One + Copy + AddAssign { /// Function to reverse accumulate the gradient for a `Variable`. /// 1. Allocate the array of adjoints. /// 2. Set the seed (dx/dx = 1). /// 3. Traverse the graph backwards, updating the adjoints for the parent vertices. #[inline] - fn accumulate(&self) -> Vec { + fn accumulate(&self) -> Vec { // Set the seed. // The seed is the derivative of the output with respect to itself. // dy/dy = 1 - let mut adjoints = vec![0.0; self.graph.len()]; - adjoints[self.index] = 1.0; // SEED + let mut adjoints = vec![T::zero(); self.graph.len()]; + adjoints[self.index] = T::one(); // SEED // Traverse the graph backwards and update the adjoints for the parent vertices. // This is simply the generalised chain rule. diff --git a/crates/RustQuant_autodiff/src/diff_traits.rs b/crates/RustQuant_autodiff/src/diff_traits.rs new file mode 100644 index 0000000..11bb045 --- /dev/null +++ b/crates/RustQuant_autodiff/src/diff_traits.rs @@ -0,0 +1,88 @@ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// RustQuant: A Rust library for quantitative finance tools. +// Copyright (C) 2023 https://github.com/avhz +// Dual licensed under Apache 2.0 and MIT. +// See: +// - LICENSE-APACHE.md +// - LICENSE-MIT.md +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +use std::ops::{AddAssign, Neg}; + +use num_traits::{NumOps, One, Zero}; + +/// A nonadditive zero trait +pub trait Initial: Sized { + /// return zero + fn zero() -> Self; +} + +impl Initial for T where T: Zero { + fn zero() -> Self { + Zero::zero() + } +} + +/// A variable which can be evaluated upon by elementary operations. +pub trait DiffOps: + NumOps + One + Zero + AddAssign + Neg + Sized + Clone + Copy + NumOps +{ + /// Compute the derivative asin of the variable. + fn asin_diff(self) -> Self; + /// Compute the error function of the variable. + fn erf(self) -> Self; + /// Compute the complementary error function of the variable. + fn erfc(self) -> Self; + /// Compute the absolute value of the variable. + fn abs(self) -> Self; + /// Compute the signum of the variable. + fn signum(self) -> Self; + /// Compute the reciprocal of the variable. + fn recip(self) -> Self; + /// Compute the square root of the variable. + fn sqrt(self) -> Self; + /// Compute the cube root of the variable. + fn cbrt(self) -> Self; + /// Compute the integer power of the variable. + fn powi(self, n: i32) -> Self; + /// Compute the real power of the variable. + fn powf(self, other: Self) -> Self; + /// Compute the exponential of the variable. + fn exp(self) -> Self; + /// Compute the natural logarithm of the variable. + fn ln(self) -> Self; + /// Compute the natural logarithm of 1 + the variable. + fn ln_1p(self) -> Self; + /// Compute the exponential of the variable minus 1. + fn exp_m1(self) -> Self; + /// Compute the base 2 exponential of the variable. + fn exp2(self) -> Self; + /// Compute the binary logarithm of the variable. + fn log2(self) -> Self; + /// Compute the decimal logarithm of the variable. + fn log10(self) -> Self; + /// Compute the sine of the variable. + fn sin(self) -> Self; + /// Compute the cosine of the variable. + fn cos(self) -> Self; + /// Compute the tangent of the variable. + fn tan(self) -> Self; + /// Compute the arcsine of the variable. + fn asin(self) -> Self; + /// Compute the acosine of the variable. + fn acos(self) -> Self; + /// Compute the arctangent of the variable. + fn atan(self) -> Self; + /// Compute the sinh of the variable. + fn sinh(self) -> Self; + /// Compute the cosh of the variable. + fn cosh(self) -> Self; + /// Compute the tanh of the variable. + fn tanh(self) -> Self; + /// Compute the arcsinh of the variable. + fn asinh(self) -> Self; + /// Compute the arccosh of the variable. + fn acosh(self) -> Self; + /// Compute the atanh of the variable. + fn atanh(self) -> Self; +} diff --git a/crates/RustQuant_autodiff/src/graph.rs b/crates/RustQuant_autodiff/src/graph.rs index b161f09..d7c8cac 100644 --- a/crates/RustQuant_autodiff/src/graph.rs +++ b/crates/RustQuant_autodiff/src/graph.rs @@ -17,7 +17,9 @@ // IMPORTS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -use crate::{variable::Variable, Arity, Vertex}; +use num_traits::Zero; + +use crate::{variable::Variable, Arity, DiffOps, Initial, Vertex}; use std::cell::RefCell; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -26,9 +28,9 @@ use std::cell::RefCell; /// Struct to contain the graph (Wengert list), as a vector of `Vertex`s. #[derive(Debug, Clone)] -pub struct Graph { +pub struct Graph { /// Vector containing the vertices in the Wengert List. - pub vertices: RefCell>, + pub vertices: RefCell>>, } // pub struct Graph(RefCell>); @@ -40,7 +42,7 @@ impl Default for Graph { } /// Implementation for the `Graph` struct. -impl Graph { +impl Graph where T: Initial + Clone + Copy { /// Instantiate a new graph. #[must_use] #[inline] @@ -74,7 +76,7 @@ impl Graph { /// Add a new variable to the graph. /// Returns a new `Variable` instance (the contents of a vertex). #[inline] - pub fn var(&self, value: f64) -> Variable { + pub fn var(&self, value: T) -> Variable { Variable { graph: self, value, @@ -85,7 +87,7 @@ impl Graph { /// Add multiple variables (a slice) to the graph. /// Useful for larger functions with many inputs. #[inline] - pub fn vars<'v>(&'v self, values: &[f64]) -> Vec> { + pub fn vars<'v>(&'v self, values: &[T]) -> Vec> { values.iter().map(|&val| self.var(val)).collect() } @@ -113,7 +115,7 @@ impl Graph { self.vertices .borrow_mut() .iter_mut() - .for_each(|vertex| vertex.partials = [0.0; 2]); + .for_each(|vertex| vertex.partials = [T::zero(), T::zero()]); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -122,7 +124,7 @@ impl Graph { /// Pushes a vertex to the graph. #[inline] - pub fn push(&self, arity: Arity, parents: &[usize], partials: &[f64]) -> usize { + pub fn push(&self, arity: Arity, parents: &[usize], partials: &[T]) -> usize { let mut vertices = self.vertices.borrow_mut(); let len = vertices.len(); @@ -140,7 +142,7 @@ impl Graph { assert!(parents.is_empty()); Vertex { - partials: [0.0, 0.0], + partials: [T::zero(), T::zero()], parents: [len, len], } } @@ -157,7 +159,7 @@ impl Graph { assert!(parents.len() == 1); Vertex { - partials: [partials[0], 0.0], + partials: [partials[0], T::zero()], parents: [parents[0], len], } } diff --git a/crates/RustQuant_autodiff/src/lib.rs b/crates/RustQuant_autodiff/src/lib.rs index b58e440..20e5d64 100644 --- a/crates/RustQuant_autodiff/src/lib.rs +++ b/crates/RustQuant_autodiff/src/lib.rs @@ -91,6 +91,14 @@ pub use vertex::*; pub mod overload; pub use overload::*; +/// `DiffOps`s for `autodiff`. +pub mod diff_traits; +pub use diff_traits::*; + +/// `Symbol`s for `autodiff`. +pub mod symbol; +pub use symbol::*; + /// `Variable`s for `autodiff`. pub mod variable; pub use variable::*; diff --git a/crates/RustQuant_autodiff/src/overload.rs b/crates/RustQuant_autodiff/src/overload.rs index 24f36c6..6334817 100644 --- a/crates/RustQuant_autodiff/src/overload.rs +++ b/crates/RustQuant_autodiff/src/overload.rs @@ -7,6 +7,7 @@ // - LICENSE-MIT.md // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +use crate::DiffOps; use crate::{variable::Variable, vertex::Arity}; use std::iter::{Product, Sum}; use std::ops::Neg; @@ -47,9 +48,11 @@ impl<'v> AddAssign> for f64 { } /// Variable<'v> + Variable<'v> -impl<'v> Add> for Variable<'v> { - type Output = Variable<'v>; - +impl<'v, T> Add> for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; /// @@ -66,22 +69,27 @@ impl<'v> Add> for Variable<'v> { /// assert_eq!(grad.wrt(&y), 1.0); /// ``` #[inline] - fn add(self, other: Variable<'v>) -> Self::Output { + fn add(self, other: Variable<'v, T>) -> Self::Output { assert!(std::ptr::eq(self.graph, other.graph)); Variable { graph: self.graph, value: self.value + other.value, - index: self - .graph - .push(Arity::Binary, &[self.index, other.index], &[1.0, 1.0]), + index: self.graph.push( + Arity::Binary, + &[self.index, other.index], + &[T::one(), T::one()], + ), } } } /// Variable<'v> + f64 -impl<'v> Add for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Add for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -102,16 +110,21 @@ impl<'v> Add for Variable<'v> { Variable { graph: self.graph, value: self.value + other, - index: self - .graph - .push(Arity::Binary, &[self.index, self.index], &[1.0, 0.0]), + index: self.graph.push( + Arity::Binary, + &[self.index, self.index], + &[T::one(), T::zero()], + ), } } } /// f64 + Variable<'v> -impl<'v> Add> for f64 { - type Output = Variable<'v>; +impl<'v, T> Add> for f64 +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -128,7 +141,7 @@ impl<'v> Add> for f64 { /// assert_eq!(grad.wrt(&x), 1.0); /// ``` #[inline] - fn add(self, other: Variable<'v>) -> Self::Output { + fn add(self, other: Variable<'v, T>) -> Self::Output { other + self } } @@ -165,8 +178,11 @@ impl<'v> DivAssign> for f64 { } /// Variable<'v> / Variable<'v> -impl<'v> Div> for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Div> for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -184,7 +200,7 @@ impl<'v> Div> for Variable<'v> { /// assert_eq!(grad.wrt(&y), - 5.0 / (2.0 * 2.0)); /// ``` #[inline] - fn div(self, other: Variable<'v>) -> Self::Output { + fn div(self, other: Variable<'v, T>) -> Self::Output { assert!(std::ptr::eq(self.graph, other.graph)); self * other.recip() @@ -192,8 +208,11 @@ impl<'v> Div> for Variable<'v> { } /// Variable<'v> / f64 -impl<'v> Div for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Div for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -252,7 +271,10 @@ impl<'v> Div> for f64 { // OVERLOADING: STANDARD MATH OPERATORS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -impl<'v> std::ops::Neg for Variable<'v> { +impl<'v, T> std::ops::Neg for Variable<'v, T> +where + T: DiffOps, +{ type Output = Self; #[inline] @@ -265,7 +287,10 @@ impl<'v> std::ops::Neg for Variable<'v> { // OVERLOADING: PRIMITIVE FUNCTIONS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -impl<'v> Variable<'v> { +impl<'v, T> Variable<'v, T> +where + T: DiffOps, +{ /// Absolute value function. /// d/dx abs(x) = sign(x) /// @@ -321,7 +346,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[((1.0 - self.value.powi(2)).sqrt()).recip().neg()], + &[((T::one() - self.value.powi(2)).sqrt()).recip().neg()], ), } } @@ -349,7 +374,10 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[((self.value - 1.0).sqrt() * (self.value + 1.0).sqrt()).recip()], + &[ + ((self.value - T::one()).sqrt() * (self.value + T::one()).sqrt()) + .recip(), + ], ), } } @@ -376,15 +404,9 @@ impl<'v> Variable<'v> { Variable { graph: self.graph, value: self.value.asin(), - index: self.graph.push( - Arity::Unary, - &[self.index], - &[if (self.value > -1.0) && (self.value < 1.0) { - ((1.0 - self.value.powi(2)).sqrt()).recip() - } else { - f64::NAN - }], - ), + index: self + .graph + .push(Arity::Unary, &[self.index], &[self.value.asin_diff()]), } } @@ -412,7 +434,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[((1.0 + self.value.powi(2)).sqrt()).recip()], + &[((self.value.powi(2) + 1.0).sqrt()).recip()], ), } } @@ -441,7 +463,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[((1.0 + self.value.powi(2)).recip())], + &[((self.value.powi(2) + 1.0).recip())], ), } } @@ -471,7 +493,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[((1.0 - self.value.powi(2)).recip())], + &[-(self.value.powi(2) - 1.0).recip()], ), } } @@ -501,7 +523,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[((3.0 * self.value.powf(2.0 / 3.0)).recip())], + &[self.value.cbrt() / (self.value * 3.0)], ), } } @@ -615,7 +637,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[2_f64.powf(self.value) * 2_f64.ln()], + &[self.value.exp2() * 2_f64.ln()], ), } } @@ -698,9 +720,11 @@ impl<'v> Variable<'v> { Variable { graph: self.graph, value: self.value.ln_1p(), - index: self - .graph - .push(Arity::Unary, &[self.index], &[(1.0 + self.value).recip()]), + index: self.graph.push( + Arity::Unary, + &[self.index], + &[(self.value + 1.0).recip()], + ), } } @@ -881,7 +905,7 @@ impl<'v> Variable<'v> { index: self.graph.push( Arity::Unary, &[self.index], - &[(2.0 * self.value.sqrt()).recip()], + &[(self.value.sqrt() * 2.0).recip()], ), } } @@ -1260,8 +1284,11 @@ impl<'v> MulAssign> for f64 { } /// Variable<'v> * Variable<'v> -impl<'v> Mul> for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Mul> for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -1279,7 +1306,7 @@ impl<'v> Mul> for Variable<'v> { /// assert_eq!(grad.wrt(&y), 5.0); /// ``` #[inline] - fn mul(self, other: Variable<'v>) -> Self::Output { + fn mul(self, other: Variable<'v, T>) -> Self::Output { assert!(std::ptr::eq(self.graph, other.graph)); Variable { @@ -1295,8 +1322,11 @@ impl<'v> Mul> for Variable<'v> { } /// Variable<'v> * f64 -impl<'v> Mul for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Mul for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -1317,9 +1347,11 @@ impl<'v> Mul for Variable<'v> { Variable { graph: self.graph, value: self.value * other, - index: self - .graph - .push(Arity::Binary, &[self.index, self.index], &[other, 0.0]), + index: self.graph.push( + Arity::Binary, + &[self.index, self.index], + &[T::one() * other, T::zero()], + ), } } } @@ -1429,11 +1461,11 @@ pub trait Powi { } // Variable<'v> ^ Variable<'v> -impl<'v> Powi> for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Powi> for Variable<'v, T> where T: DiffOps { + type Output = Variable<'v, T>; #[inline] - fn powi(&self, other: Variable<'v>) -> Self::Output { + fn powi(&self, other: Variable<'v, T>) -> Self::Output { assert!(std::ptr::eq(self.graph, other.graph)); Self::Output { @@ -1443,8 +1475,8 @@ impl<'v> Powi> for Variable<'v> { Arity::Binary, &[self.index, other.index], &[ - other.value * f64::powf(self.value, other.value - 1.), - f64::powf(self.value, other.value) * f64::ln(self.value), + other.value * self.value.powf(other.value - 1.0), + self.value.powf(other.value) * self.value.ln(), ], ), } @@ -1489,7 +1521,10 @@ impl<'v> Powi> for f64 { use std::f64::consts::PI; -impl<'v> Variable<'v> { +impl<'v, T> Variable<'v, T> +where + T: DiffOps, +{ /// Error function. /// d/dx erf(x) = 2e^(-x^2) / sqrt(PI) /// @@ -1513,11 +1548,11 @@ impl<'v> Variable<'v> { Variable { graph: self.graph, - value: errorfunctions::RealErrorFunctions::erf(self.value), + value: self.value.erf(), index: self.graph.push( Arity::Unary, &[self.index], - &[2.0 * self.value.powi(2).neg().exp() / PI.sqrt()], + &[self.value.powi(2).neg().exp() * 2.0 / PI.sqrt()], ), } } @@ -1545,11 +1580,11 @@ impl<'v> Variable<'v> { Variable { graph: self.graph, - value: errorfunctions::RealErrorFunctions::erfc(self.value), + value: self.value.erfc(), index: self.graph.push( Arity::Unary, &[self.index], - &[((2.0 * self.value.powi(2).neg().exp()).neg() / PI.sqrt())], + &[((self.value.powi(2).neg().exp()).neg() * 2.0 / PI.sqrt())], ), } } @@ -1587,8 +1622,11 @@ impl<'v> SubAssign> for f64 { } /// Variable<'v> - Variable<'v> -impl<'v> Sub> for Variable<'v> { - type Output = Variable<'v>; +impl<'v, T> Sub> for Variable<'v, T> +where + T: DiffOps, +{ + type Output = Variable<'v, T>; /// ``` /// # use RustQuant_autodiff::*; @@ -1606,7 +1644,7 @@ impl<'v> Sub> for Variable<'v> { /// assert_eq!(grad.wrt(&y), -1.0); /// ``` #[inline] - fn sub(self, other: Variable<'v>) -> Self::Output { + fn sub(self, other: Variable<'v, T>) -> Self::Output { assert!(std::ptr::eq(self.graph, other.graph)); self.add(other.neg()) diff --git a/crates/RustQuant_autodiff/src/symbol.rs b/crates/RustQuant_autodiff/src/symbol.rs new file mode 100644 index 0000000..bf0f946 --- /dev/null +++ b/crates/RustQuant_autodiff/src/symbol.rs @@ -0,0 +1,28 @@ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// RustQuant: A Rust library for quantitative finance tools. +// Copyright (C) 2023 https://github.com/avhz +// Dual licensed under Apache 2.0 and MIT. +// See: +// - LICENSE-APACHE.md +// - LICENSE-MIT.md +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +use crate::Initial; + +/// A symbolic expression. +#[derive(Debug, Clone, Copy)] +pub enum Expression { + /// A constant real value. + Constant(f64), + /// A variable. + Variable(&'static str), + /// A function applied to an expression, where the positive integer is the index of the expression on the graph. + Unary(&'static str, usize), + /// A function applied to two expressions. + Binary(&'static str, usize, usize), +} + +impl Initial for Expression { + fn zero() -> Self { + Self::Constant(0.0) + } +} diff --git a/crates/RustQuant_autodiff/src/variable.rs b/crates/RustQuant_autodiff/src/variable.rs index 5772174..17697c0 100644 --- a/crates/RustQuant_autodiff/src/variable.rs +++ b/crates/RustQuant_autodiff/src/variable.rs @@ -18,29 +18,76 @@ // IMPORTS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -use crate::graph::Graph; + +use RustQuant_utils::forward; + +use crate::{graph::Graph, DiffOps}; use std::fmt::Display; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // STRUCT AND IMPLEMENTATION // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +impl DiffOps for f64 { + fn asin_diff(self) -> f64 { + if (self > -1.0) && (self < 1.0) { + ((1.0 - self.powi(2)).sqrt()).recip() + } else { + f64::NAN + } + } + fn erfc(self) -> f64 { + 1.0 - self.erf() + } + fn erf(self) -> f64 { + errorfunctions::RealErrorFunctions::erf(self) + } + forward! { + Self::abs(self) -> Self; + Self::signum(self) -> Self; + Self::recip(self) -> Self; + Self::powi(self, n: i32) -> Self; + Self::sqrt(self) -> Self; + Self::cbrt(self) -> Self; + Self::powf(self, n: Self) -> Self; + Self::ln_1p(self) -> Self; + Self::exp_m1(self) -> Self; + Self::exp2(self) -> Self; + Self::log2(self) -> Self; + Self::log10(self) -> Self; + Self::exp(self) -> Self; + Self::ln(self) -> Self; + Self::sin(self) -> Self; + Self::cos(self) -> Self; + Self::tan(self) -> Self; + Self::asin(self) -> Self; + Self::acos(self) -> Self; + Self::atan(self) -> Self; + Self::sinh(self) -> Self; + Self::cosh(self) -> Self; + Self::tanh(self) -> Self; + Self::asinh(self) -> Self; + Self::acosh(self) -> Self; + Self::atanh(self) -> Self; + } +} + /// Struct to contain the initial variables. #[derive(Clone, Copy, Debug)] -pub struct Variable<'v> { +pub struct Variable<'v, T = f64> { /// Pointer to the graph. - pub graph: &'v Graph, + pub graph: &'v Graph, /// Index to the vertex. pub index: usize, /// Value associated to the vertex. - pub value: f64, // Value, + pub value: T, // Value, } -impl<'v> Variable<'v> { +impl<'v, T: Copy> Variable<'v, T> { /// Instantiate a new variable. #[must_use] #[inline] - pub const fn new(graph: &'v Graph, index: usize, value: f64) -> Self { + pub const fn new(graph: &'v Graph, index: usize, value: T) -> Self { Variable { graph, index, @@ -51,7 +98,7 @@ impl<'v> Variable<'v> { /// Function to return the value contained in a vertex. #[must_use] #[inline] - pub fn value(&self) -> f64 { + pub fn value(&self) -> T { self.value } @@ -65,10 +112,12 @@ impl<'v> Variable<'v> { /// Function to return the graph. #[must_use] #[inline] - pub fn graph(&self) -> &'v Graph { + pub fn graph(&self) -> &'v Graph { self.graph } +} +impl<'v> Variable<'v, f64> { /// Check if variable is finite. #[must_use] #[inline] diff --git a/crates/RustQuant_autodiff/src/vertex.rs b/crates/RustQuant_autodiff/src/vertex.rs index 59296d2..be76602 100644 --- a/crates/RustQuant_autodiff/src/vertex.rs +++ b/crates/RustQuant_autodiff/src/vertex.rs @@ -9,6 +9,8 @@ use std::fmt; +use crate::DiffOps; + /// Struct defining the vertex of the computational graph. /// /// Operations are assumed to be binary (e.g. x + y), @@ -16,9 +18,9 @@ use std::fmt; /// To deal with unary or nullary operations, we just adjust the weights /// (partials) and the dependencies (parents). #[derive(Clone, Copy, Debug)] -pub struct Vertex { +pub struct Vertex { /// Array that contains the partial derivatives wrt to x and y. - pub partials: [f64; 2], + pub partials: [T; 2], /// Array that contains the indices of the parent vertices. pub parents: [usize; 2], diff --git a/crates/RustQuant_utils/src/lib.rs b/crates/RustQuant_utils/src/lib.rs index fdbeb77..d7115fa 100644 --- a/crates/RustQuant_utils/src/lib.rs +++ b/crates/RustQuant_utils/src/lib.rs @@ -36,6 +36,19 @@ macro_rules! assert_approx_equal { }; } +/// Forward a method call to the struct. +#[macro_export] +macro_rules! forward { + ($( Self :: $method:ident ( self $( , $arg:ident : $ty:ty )* ) -> $ret:ty ; )*) => { + $( + #[inline] + fn $method(self $( , $arg : $ty )* ) -> $ret { + Self::$method(self $( , $arg )* ) + } + )* + }; +} + /// Plot a vector of values. #[macro_export] macro_rules! plot_vector {