Thanks, guys. I ended up doing this; feedback is welcome.
#![allow(non_snake_case)]
use ndarray::Array2;
/// Bilinear interpolator over a rectangular grid of samples.
///
/// The grid is described by two knot vectors (one per axis) and a 2-D table of values
/// sampled at every knot combination.
#[derive(Debug, Clone)]
pub struct BilinearInterpolator {
    /// Sampled values, looked up as `table[(x_index, y_index)]`: the first axis follows
    /// `x_knots` and the second axis follows `y_knots`.
    table: Array2<f64>,
    /// Knot positions along the x axis (first axis of `table`).
    x_knots: Vec<f64>,
    /// Knot positions along the y axis (second axis of `table`).
    y_knots: Vec<f64>,
}
impl BilinearInterpolator {
/// Creates a new BilinearInterpolator.
pub fn new(table: Array2<f64>, x_knots: Vec<f64>, y_knots: Vec<f64>) -> Self {
Self {
table,
x_knots,
y_knots,
}
}
/// Evaluates the interpolation at given coordinates without extrapolation.
pub fn eval_no_extrapolation(&self, kx: f64, ky: f64) -> Option<f64> {
bilinear_interpolation(&self.table, &self.x_knots, &self.y_knots, kx, ky)
}
}
/// Bilinearly interpolates `table` at the point `(kx, ky)`.
///
/// `table` is looked up as `table[(i, j)]`, where `i` is a position in `x_knots` and `j` is a
/// position in `y_knots`. The knots are assumed to be sorted in strictly increasing order;
/// a repeated knot yields a zero-width interval and the result may then be NaN.
///
/// Returns `None` when:
/// * either knot slice is empty,
/// * `(kx, ky)` lies outside `[first, last]` of the corresponding knots (no extrapolation),
/// * no pair of adjacent knots brackets the query (for example a single knot or a NaN query),
/// * `table` is too small to contain the four grid points surrounding the query.
fn bilinear_interpolation(
    table: &Array2<f64>,
    x_knots: &[f64],
    y_knots: &[f64],
    kx: f64,
    ky: f64,
) -> Option<f64> {
    // `first()?` instead of `[0]`: an empty knot slice yields `None` rather than a panic.
    let (x_min, x_max) = (*x_knots.first()?, *x_knots.last()?);
    let (y_min, y_max) = (*y_knots.first()?, *y_knots.last()?);
    if kx < x_min || kx > x_max || ky < y_min || ky > y_max {
        return None;
    }

    // Indices of the knots bracketing the query along each axis.
    let (x0, x1) = find_knots(x_knots, kx)?;
    let (y0, y1) = find_knots(y_knots, ky)?;

    // Checked lookups: a table smaller than the knot grid yields `None` rather than a panic.
    let q11 = *table.get((x0, y0))?;
    let q12 = *table.get((x0, y1))?;
    let q21 = *table.get((x1, y0))?;
    let q22 = *table.get((x1, y1))?;

    // Linear weights along each axis. Both weights of a pair are computed directly from the
    // knot distances so that a query sitting exactly on a knot gets weights of exactly 0 and 1.
    let x_span = x_knots[x1] - x_knots[x0];
    let y_span = y_knots[y1] - y_knots[y0];
    let wx1 = (x_knots[x1] - kx) / x_span;
    let wx2 = (kx - x_knots[x0]) / x_span;
    let wy1 = (y_knots[y1] - ky) / y_span;
    let wy2 = (ky - y_knots[y0]) / y_span;

    // Blend along x at both y knots, then blend those two results along y.
    Some(wy1 * (wx1 * q11 + wx2 * q21) + wy2 * (wx1 * q12 + wx2 * q22))
}
/// Locates the pair of adjacent knots that brackets `k`.
///
/// Adjacent pairs are scanned from the front, and the indices `(i, i + 1)` of the first pair
/// satisfying `knots[i] <= k <= knots[i + 1]` are returned. A query that sits exactly on an
/// interior knot is therefore assigned to the interval on its left.
///
/// Returns `None` when no pair brackets `k`; this covers slices with fewer than two knots
/// and a NaN query.
fn find_knots(knots: &[f64], k: f64) -> Option<(usize, usize)> {
    knots
        .windows(2)
        .position(|pair| pair[0] <= k && k <= pair[1])
        .map(|lower| (lower, lower + 1))
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// 3x3 value table with unit-spaced knots on both axes, shared by several tests.
    fn unit_grid_3x3() -> (Array2<f64>, Vec<f64>, Vec<f64>) {
        (
            array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            vec![0.0, 1.0, 2.0],
            vec![0.0, 1.0, 2.0],
        )
    }

    /// Interior queries map to their bracketing pair; queries outside the knots give `None`.
    #[test]
    fn test_find_knots() {
        let knots = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let cases: [(f64, Option<(usize, usize)>); 5] = [
            (1.5, Some((1, 2))),
            (0.5, Some((0, 1))),
            (3.5, Some((3, 4))),
            (4.5, None),
            (-0.5, None),
        ];
        for (query, expected) in cases {
            assert_eq!(find_knots(&knots, query), expected, "query = {}", query);
        }
    }

    /// Cell-centre queries return the mean of the four corners; out-of-bounds gives `None`.
    #[test]
    fn test_bilinear_interpolation() {
        let (table, x_knots, y_knots) = unit_grid_3x3();
        let at = |kx: f64, ky: f64| bilinear_interpolation(&table, &x_knots, &y_knots, kx, ky);

        assert_eq!(at(0.5, 0.5), Some(3.0));
        assert_eq!(at(1.5, 1.5), Some(7.0));
        // Out of bounds
        assert_eq!(at(2.5, 2.5), None);
    }

    /// The public wrapper yields the same results as the free function.
    #[test]
    fn test_eval_no_extrapolation() {
        let (table, x_knots, y_knots) = unit_grid_3x3();
        let interpolator = BilinearInterpolator::new(table, x_knots, y_knots);

        let cases: [((f64, f64), Option<f64>); 3] = [
            ((0.5, 0.5), Some(3.0)),
            ((1.5, 1.5), Some(7.0)),
            // Out of bounds
            ((2.5, 2.5), None),
        ];
        for ((kx, ky), expected) in cases {
            assert_eq!(
                interpolator.eval_no_extrapolation(kx, ky),
                expected,
                "query = ({}, {})",
                kx,
                ky
            );
        }
    }

    /// Queries placed exactly on the grid corners reproduce the table values.
    #[test]
    fn test_bilinear_interpolation_edge_cases() {
        let table = array![[1.0, 2.0], [3.0, 4.0]];
        let x_knots = vec![0.0, 1.0];
        let y_knots = vec![0.0, 1.0];

        // Test corners
        let corners: [((f64, f64), f64); 4] = [
            ((0.0, 0.0), 1.0),
            ((1.0, 0.0), 3.0),
            ((0.0, 1.0), 2.0),
            ((1.0, 1.0), 4.0),
        ];
        for ((kx, ky), expected) in corners {
            assert_eq!(
                bilinear_interpolation(&table, &x_knots, &y_knots, kx, ky),
                Some(expected),
                "corner = ({}, {})",
                kx,
                ky
            );
        }
    }
}