Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/RustQuant_autodiff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ RustQuant = { path = "../RustQuant" }
ndarray = { workspace = true }
errorfunctions = { workspace = true }
RustQuant_utils = { workspace = true }
num-traits = "0.2.19"

## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
## RUSTDOC CONFIGURATION
Expand Down
14 changes: 10 additions & 4 deletions crates/RustQuant_autodiff/src/accumulate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
//! - `DVector<Variable<'v>>` <- Currently not possible due to lifetimes
//! - `Array<Variable<'v>, Ix2>` <- Work in progress

use std::ops::AddAssign;

use num_traits::{One, Zero};

use crate::DiffOps;

use super::variable::Variable;

/// Trait to reverse accumulate the gradient for different types.
Expand All @@ -25,18 +31,18 @@ pub trait Accumulate<OUT> {
fn accumulate(&self) -> OUT;
}

impl Accumulate<Vec<f64>> for Variable<'_> {
impl<T> Accumulate<Vec<T>> for Variable<'_, T> where T: Zero + One + Copy + AddAssign {
/// Function to reverse accumulate the gradient for a `Variable`.
/// 1. Allocate the array of adjoints.
/// 2. Set the seed (dx/dx = 1).
/// 3. Traverse the graph backwards, updating the adjoints for the parent vertices.
#[inline]
fn accumulate(&self) -> Vec<f64> {
fn accumulate(&self) -> Vec<T> {
// Set the seed.
// The seed is the derivative of the output with respect to itself.
// dy/dy = 1
let mut adjoints = vec![0.0; self.graph.len()];
adjoints[self.index] = 1.0; // SEED
let mut adjoints = vec![T::zero(); self.graph.len()];
adjoints[self.index] = T::one(); // SEED

// Traverse the graph backwards and update the adjoints for the parent vertices.
// This is simply the generalised chain rule.
Expand Down
88 changes: 88 additions & 0 deletions crates/RustQuant_autodiff/src/diff_traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// RustQuant: A Rust library for quantitative finance tools.
// Copyright (C) 2023 https://github.com/avhz
// Dual licensed under Apache 2.0 and MIT.
// See:
// - LICENSE-APACHE.md
// - LICENSE-MIT.md
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

use std::ops::{AddAssign, Neg};

use num_traits::{NumOps, One, Zero};

/// A nonadditive zero trait
pub trait Initial: Sized {
/// return zero
fn zero() -> Self;
}

impl<T> Initial for T where T: Zero {
fn zero() -> Self {
Zero::zero()
}
}

/// A variable which can be evaluated upon by elementary operations.
pub trait DiffOps:
NumOps + One + Zero + AddAssign + Neg<Output = Self> + Sized + Clone + Copy + NumOps<f64>
{
/// Compute the derivative asin of the variable.
fn asin_diff(self) -> Self;
/// Compute the error function of the variable.
fn erf(self) -> Self;
/// Compute the complementary error function of the variable.
fn erfc(self) -> Self;
/// Compute the absolute value of the variable.
fn abs(self) -> Self;
/// Compute the signum of the variable.
fn signum(self) -> Self;
/// Compute the reciprocal of the variable.
fn recip(self) -> Self;
/// Compute the square root of the variable.
fn sqrt(self) -> Self;
/// Compute the cube root of the variable.
fn cbrt(self) -> Self;
/// Compute the integer power of the variable.
fn powi(self, n: i32) -> Self;
/// Compute the real power of the variable.
fn powf(self, other: Self) -> Self;
/// Compute the exponential of the variable.
fn exp(self) -> Self;
/// Compute the natural logarithm of the variable.
fn ln(self) -> Self;
/// Compute the natural logarithm of 1 + the variable.
fn ln_1p(self) -> Self;
/// Compute the exponential of the variable minus 1.
fn exp_m1(self) -> Self;
/// Compute the base 2 exponential of the variable.
fn exp2(self) -> Self;
/// Compute the binary logarithm of the variable.
fn log2(self) -> Self;
/// Compute the decimal logarithm of the variable.
fn log10(self) -> Self;
/// Compute the sine of the variable.
fn sin(self) -> Self;
/// Compute the cosine of the variable.
fn cos(self) -> Self;
/// Compute the tangent of the variable.
fn tan(self) -> Self;
/// Compute the arcsine of the variable.
fn asin(self) -> Self;
/// Compute the acosine of the variable.
fn acos(self) -> Self;
/// Compute the arctangent of the variable.
fn atan(self) -> Self;
/// Compute the sinh of the variable.
fn sinh(self) -> Self;
/// Compute the cosh of the variable.
fn cosh(self) -> Self;
/// Compute the tanh of the variable.
fn tanh(self) -> Self;
/// Compute the arcsinh of the variable.
fn asinh(self) -> Self;
/// Compute the arccosh of the variable.
fn acosh(self) -> Self;
/// Compute the atanh of the variable.
fn atanh(self) -> Self;
}
22 changes: 12 additions & 10 deletions crates/RustQuant_autodiff/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
// IMPORTS
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

use crate::{variable::Variable, Arity, Vertex};
use num_traits::Zero;

use crate::{variable::Variable, Arity, DiffOps, Initial, Vertex};
use std::cell::RefCell;

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -26,9 +28,9 @@ use std::cell::RefCell;

/// Struct to contain the graph (Wengert list), as a vector of `Vertex`s.
#[derive(Debug, Clone)]
pub struct Graph {
pub struct Graph<T = f64> {
/// Vector containing the vertices in the Wengert List.
pub vertices: RefCell<Vec<Vertex>>,
pub vertices: RefCell<Vec<Vertex<T>>>,
}
// pub struct Graph(RefCell<Rc<[Vertex]>>);

Expand All @@ -40,7 +42,7 @@ impl Default for Graph {
}

/// Implementation for the `Graph` struct.
impl Graph {
impl<T> Graph<T> where T: Initial + Clone + Copy {
/// Instantiate a new graph.
#[must_use]
#[inline]
Expand Down Expand Up @@ -74,7 +76,7 @@ impl Graph {
/// Add a new variable to the graph.
/// Returns a new `Variable` instance (the contents of a vertex).
#[inline]
pub fn var(&self, value: f64) -> Variable {
pub fn var(&self, value: T) -> Variable<T> {
Variable {
graph: self,
value,
Expand All @@ -85,7 +87,7 @@ impl Graph {
/// Add multiple variables (a slice) to the graph.
/// Useful for larger functions with many inputs.
#[inline]
pub fn vars<'v>(&'v self, values: &[f64]) -> Vec<Variable<'v>> {
pub fn vars<'v>(&'v self, values: &[T]) -> Vec<Variable<'v, T>> {
values.iter().map(|&val| self.var(val)).collect()
}

Expand Down Expand Up @@ -113,7 +115,7 @@ impl Graph {
self.vertices
.borrow_mut()
.iter_mut()
.for_each(|vertex| vertex.partials = [0.0; 2]);
.for_each(|vertex| vertex.partials = [T::zero(), T::zero()]);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -122,7 +124,7 @@ impl Graph {

/// Pushes a vertex to the graph.
#[inline]
pub fn push(&self, arity: Arity, parents: &[usize], partials: &[f64]) -> usize {
pub fn push(&self, arity: Arity, parents: &[usize], partials: &[T]) -> usize {
let mut vertices = self.vertices.borrow_mut();
let len = vertices.len();

Expand All @@ -140,7 +142,7 @@ impl Graph {
assert!(parents.is_empty());

Vertex {
partials: [0.0, 0.0],
partials: [T::zero(), T::zero()],
parents: [len, len],
}
}
Expand All @@ -157,7 +159,7 @@ impl Graph {
assert!(parents.len() == 1);

Vertex {
partials: [partials[0], 0.0],
partials: [partials[0], T::zero()],
parents: [parents[0], len],
}
}
Expand Down
8 changes: 8 additions & 0 deletions crates/RustQuant_autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ pub use vertex::*;
pub mod overload;
pub use overload::*;

/// `DiffOps`s for `autodiff`.
pub mod diff_traits;
pub use diff_traits::*;

/// `Symbol`s for `autodiff`.
pub mod symbol;
pub use symbol::*;

/// `Variable`s for `autodiff`.
pub mod variable;
pub use variable::*;
Loading