Skip to content

Commit 2cbe56d

Browse files
committed
Lagrange-interpolator: Create Interpolation class for the Lagrange interpolation method
1 parent 8ae4aa7 commit 2cbe56d

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2+
// RustQuant: A Rust library for quantitative finance tools.
3+
// Copyright (C) 2023 https://github.com/avhz
4+
// Dual licensed under Apache 2.0 and MIT.
5+
// See:
6+
// - LICENSE-APACHE.md
7+
// - LICENSE-MIT.md
8+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9+
10+
//! Module containing functionality for interpolation.
11+
12+
use crate::interpolation::{InterpolationIndex, InterpolationValue, Interpolator};
13+
use RustQuant_error::RustQuantError;
14+
15+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
16+
// STRUCTS & ENUMS
17+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18+
19+
/// Linear Interpolator.
20+
pub struct LagrangeInterpolator<IndexType, ValueType>
21+
where
22+
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
23+
ValueType: InterpolationValue,
24+
{
25+
/// X-axis values for the interpolator.
26+
pub xs: Vec<IndexType>,
27+
28+
/// Y-axis values for the interpolator.
29+
pub ys: Vec<ValueType>,
30+
31+
/// Whether the interpolator has been fitted.
32+
pub fitted: bool,
33+
}
34+
35+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
36+
// IMPLEMENTATIONS, FUNCTIONS, AND MACROS
37+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
38+
39+
impl<IndexType, ValueType> LagrangeInterpolator<IndexType, ValueType>
40+
where
41+
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
42+
ValueType: InterpolationValue,
43+
{
44+
/// Create a new LagrangeInterpolator.
45+
///
46+
/// # Errors
47+
/// - `RustQuantError::UnequalLength` if ```xs.length() != ys.length()```.
48+
///
49+
/// # Panics
50+
/// Panics if NaN is in the index.
51+
pub fn new(
52+
xs: Vec<IndexType>,
53+
ys: Vec<ValueType>,
54+
) -> Result<LagrangeInterpolator<IndexType, ValueType>, RustQuantError> {
55+
if xs.len() != ys.len() {
56+
return Err(RustQuantError::UnequalLength);
57+
}
58+
59+
let mut tmp: Vec<_> = xs.into_iter().zip(ys).collect();
60+
61+
tmp.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
62+
63+
let (xs, ys): (Vec<IndexType>, Vec<ValueType>) = tmp.into_iter().unzip();
64+
65+
Ok(Self {
66+
xs,
67+
ys,
68+
fitted: false,
69+
})
70+
}
71+
72+
fn cardinal_function(&self, point: IndexType, pivot: IndexType, index: usize) -> ValueType {
73+
let mut lagrange_basis: ValueType = ValueType::one();
74+
for (i, x) in self.xs.iter().enumerate() {
75+
if i != index {
76+
lagrange_basis *= (point - *x) / (pivot - *x);
77+
}
78+
}
79+
lagrange_basis
80+
}
81+
82+
fn lagrange_polynomial(&self, point: IndexType) -> ValueType {
83+
let mut polynomial: ValueType = ValueType::zero();
84+
for (i, (x, y)) in self.xs.iter().zip(&self.ys).enumerate() {
85+
polynomial += *y * self.cardinal_function(point, *x, i);
86+
87+
}
88+
polynomial
89+
}
90+
}
91+
92+
impl<IndexType, ValueType> Interpolator<IndexType, ValueType>
93+
for LagrangeInterpolator<IndexType, ValueType>
94+
where
95+
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
96+
ValueType: InterpolationValue,
97+
{
98+
fn fit(&mut self) -> Result<(), RustQuantError> {
99+
self.fitted = true;
100+
Ok(())
101+
}
102+
103+
fn range(&self) -> (IndexType, IndexType) {
104+
(*self.xs.first().unwrap(), *self.xs.last().unwrap())
105+
}
106+
107+
fn add_point(&mut self, point: (IndexType, ValueType)) {
108+
let idx = self.xs.partition_point(|&x| x < point.0);
109+
self.xs.insert(idx, point.0);
110+
self.ys.insert(idx, point.1);
111+
}
112+
113+
fn interpolate(&self, point: IndexType) -> Result<ValueType, RustQuantError> {
114+
let range = self.range();
115+
if point.partial_cmp(&range.0).unwrap() == std::cmp::Ordering::Less
116+
|| point.partial_cmp(&range.1).unwrap() == std::cmp::Ordering::Greater
117+
{
118+
return Err(RustQuantError::OutsideOfRange);
119+
}
120+
if let Ok(idx) = self
121+
.xs
122+
.binary_search_by(|p| p.partial_cmp(&point).expect("Cannot compare values."))
123+
{
124+
return Ok(self.ys[idx]);
125+
}
126+
127+
Ok(self.lagrange_polynomial(point))
128+
}
129+
}

0 commit comments

Comments
 (0)