Hello everybody. I wrote a simple ray tracer for solid modeling in Rust (~200 lines of code) as part of a university course. I would be very happy to get general feedback on the code, since I am very new to Rust.
use std::time::Instant;
use image::ImageBuffer;
use itertools_num::linspace;
use nalgebra::{vector, Vector3};
use f64 as Real;
/// A ray described by a starting point and a direction.
#[derive(Debug)]
struct Ray {
/// Point the ray starts from.
origin: Vector3<Real>,
/// Direction the ray travels in. The callers in this file always pass a normalized
/// vector, and `Sphere`'s intersection test relies on that (it omits the quadratic's
/// leading coefficient).
direction: Vector3<Real>,
}
impl Ray {
fn new(origin: Vector3<Real>, direction: Vector3<Real>) -> Self {
Self { origin, direction }
}
}
/// Material coefficients consumed by `phong_shading`.
#[derive(Debug, Clone)]
struct SurfaceParameters {
/// Ambient coefficient: scales the ambient-light contribution.
a: Real,
/// Diffuse coefficient: scales the diffuse term of every light.
d: Real,
/// Specular coefficient: scales the specular highlight of every light.
s: Real,
/// Specular Phong exponent: the power the reflection/view dot product is raised to.
sp: Real,
/// Specular metalness: blend weight between the object's own color (`sm`) and
/// white (`1 - sm`) for the highlight color.
sm: Real,
}
/// Interface every object in the scene provides for intersection testing and shading.
trait SceneObject {
/// Returns the ray parameter `t` of the closest intersection, i.e. the hit point is
/// `ray.origin + t * ray.direction`, or `None` if there is no intersection.
///
/// Implementations are supposed to return `Some` only if both roots are positive,
/// since otherwise the ray origin would be inside the object.
fn smallest_positive_intersect(&self, ray: &Ray) -> Option<Real>;
/// Returns the object's color. `phong_shading` multiplies it component-wise with the
/// light colors.
fn color(&self) -> Vector3<Real>;
/// Returns the surface normal at the point `coord`. `render` calls this with the
/// intersection point of a ray and this object.
fn normal(&self, coord: Vector3<Real>) -> Vector3<Real>;
/// Returns the material coefficients used when shading this object.
fn surface_parameters(&self) -> SurfaceParameters;
}
/// A sphere that can be placed in the scene.
struct Sphere {
/// Color returned by `SceneObject::color`.
color: Vector3<Real>,
/// Material coefficients returned by `SceneObject::surface_parameters`.
surface_parameters: SurfaceParameters,
/// Center point of the sphere.
center: Vector3<Real>,
/// Radius of the sphere.
radius: Real,
}
impl Sphere {
    /// Creates a sphere of the given `color`, `center` and `radius`.
    ///
    /// Every sphere gets the same fixed material: ambient, diffuse and specular
    /// coefficients of 1, a Phong exponent of 40 and a metalness of 0.2.
    fn new(color: Vector3<Real>, center: Vector3<Real>, radius: Real) -> Self {
        let surface_parameters = SurfaceParameters {
            a: 1.,
            d: 1.,
            s: 1.,
            sp: 40.,
            sm: 0.2,
        };
        Self {
            color,
            surface_parameters,
            center,
            radius,
        }
    }
}
impl SceneObject for Sphere {
    /// Solves `t^2 + b*t + c = 0` for the ray/sphere intersection and returns the
    /// smaller root when both roots are positive.
    ///
    /// The quadratic's leading coefficient is taken to be 1, so `ray.direction` is
    /// treated as a unit vector. Returns `None` when the ray misses the sphere or when
    /// at least one root is not positive.
    fn smallest_positive_intersect(&self, ray: &Ray) -> Option<Real> {
        let center_to_origin = ray.origin - self.center;
        let b = 2. * center_to_origin.dot(&ray.direction);
        let c = center_to_origin.dot(&center_to_origin) - self.radius.powi(2);
        let discriminant = b.powi(2) - 4. * c;

        // No real roots: the ray misses the sphere.
        if discriminant < 0. || discriminant.is_nan() {
            return None;
        }

        // Numerically stable form of the quadratic formula: the roots are `q` and `c / q`.
        let q = -0.5 * (b + b.signum() * discriminant.sqrt());
        let (first_root, second_root) = (q, c / q);

        if first_root > 0. && second_root > 0. {
            Some(first_root.min(second_root))
        } else {
            None
        }
    }

    /// Returns a copy of the sphere's color.
    fn color(&self) -> Vector3<Real> {
        self.color
    }

    /// Returns the unit vector pointing from the sphere's center towards `intersection`.
    fn normal(&self, intersection: Vector3<Real>) -> Vector3<Real> {
        let outward = intersection - self.center;
        outward.normalize()
    }

    /// Returns a clone of the sphere's material coefficients.
    fn surface_parameters(&self) -> SurfaceParameters {
        self.surface_parameters.clone()
    }
}
/// A light source in the scene.
struct Light {
/// Position of the light; `render` aims shadow rays at this point.
position: Vector3<Real>,
/// Color of the light; `phong_shading` multiplies it component-wise with surface colors.
color: Vector3<Real>,
}
/// Finds the object whose intersection with `ray` has the smallest ray parameter.
///
/// Returns the object (or `None` if nothing is hit) together with that parameter `t`.
/// When nothing is hit, `t` is `Real::INFINITY`. If several objects report the same
/// `t`, the one that comes first in `objects` wins.
///
/// `objects` is a slice so that any contiguous collection can be passed; a
/// `&Vec<Box<dyn SceneObject>>` still coerces to it.
fn find_closest_intersecting_object<'a>(
    objects: &'a [Box<dyn SceneObject>],
    ray: &Ray,
) -> (Option<&'a Box<dyn SceneObject>>, Real) {
    let no_hit: (Option<&'a Box<dyn SceneObject>>, Real) = (None, Real::INFINITY);
    objects
        .iter()
        .filter_map(|object| {
            object
                .smallest_positive_intersect(ray)
                .map(|t| (object, t))
        })
        .fold(no_hit, |closest, (object, t)| {
            // Strict comparison keeps the earlier object on ties.
            if t < closest.1 {
                (Some(object), t)
            } else {
                closest
            }
        })
}
/// Computes the Phong-shaded color of a surface point.
///
/// The mathematical model follows the lecture notes on Phong shading.
///
/// * `light_rays` - rays from the surface point towards each unshadowed light; only
///   their `direction` is used.
/// * `lights` - the lights belonging to `light_rays`, in the same order. The two are
///   zipped, so surplus entries on either side are ignored.
/// * `normal` - surface normal at the shaded point.
/// * `view_direction` - direction from the surface point towards the viewer.
/// * `object` - the object being shaded; supplies the color and material coefficients.
///   It stays a `&Box<..>` so that existing callers compile unchanged.
///
/// Returns the sum of the diffuse and specular terms of every light plus the ambient
/// term. The result is not clamped.
fn phong_shading(
    light_rays: &[Ray],
    lights: Vec<&Light>,
    normal: Vector3<Real>,
    view_direction: Vector3<Real>,
    object: &Box<dyn SceneObject>,
) -> Vector3<Real> {
    let SurfaceParameters { a: k_a, d: k_d, s: k_s, sp: k_sp, sm: k_sm } =
        object.surface_parameters();
    let object_color = object.color();

    // The highlight color does not depend on the light: a blend between the object's
    // own color and white, weighted by the metalness.
    let specular_highlight_color: Vector3<Real> =
        k_sm * object_color + (1. - k_sm) * vector![1., 1., 1.];

    let mut sum: Vector3<Real> = Vector3::zeros();
    for (light_ray, light) in light_rays.iter().zip(lights) {
        let light_direction = light_ray.direction;
        let cos_incidence = light_direction.dot(&normal);
        // Light direction mirrored about the surface normal.
        let reflection = 2. * cos_incidence * normal - light_direction;

        // Diffuse reflection.
        let diffuse_color =
            k_d * light.color.component_mul(&object_color) * Real::max(cos_incidence, 0.);
        sum += diffuse_color;

        // Specular reflection.
        let specular_color = k_s
            * specular_highlight_color.component_mul(&light.color)
            * Real::max(reflection.dot(&view_direction), 0.).powf(k_sp);
        sum += specular_color;
    }

    // Ambient light.
    let ambient_light = vector![0.3, 0.3, 0.3]; // Todo: Remove hard coding of ambient lighting.
    let ambient_color = k_a * object_color.component_mul(&ambient_light);
    sum += ambient_color;
    sum
}
/// Renders the hard-coded scene and writes it to `rt_image.png`.
///
/// The camera sits at `(0, 0, -1)` and shoots one primary ray per pixel through an image
/// plane at `z = 0` that spans `x` in `[-1, 1]` and `y` in `[-1/ratio, 1/ratio]`. For
/// every hit, the lights that are not blocked by another object are collected and the
/// pixel is shaded with `phong_shading`. Pixels whose primary ray hits nothing keep the
/// image buffer's initial value.
///
/// # Panics
///
/// Panics if the image cannot be written to disk.
fn render() {
    let multiplier = 4;
    let width = multiplier * 600;
    let height = multiplier * 400;
    let ratio = width as Real / height as Real;
    let camera_pos = vector![0_f64, 0., -1.];
    let lights = vec![
        Light { position: vector![4., 4., -3.], color: vector![1., 1., 1.] },
    ];
    let scene_objects: Vec<Box<dyn SceneObject>> = vec![
        Box::new(Sphere::new(vector![1., 0., 0.], vector![0.0, 0.0, 10.0], 5.)),
        Box::new(Sphere::new(vector![0., 1., 0.], vector![0.5, 1.1, 3.5], 0.4)),
        Box::new(Sphere::new(vector![0., 1., 0.7], vector![-0.5, 0.4, 4.5], 0.4)),
        Box::new(Sphere::new(vector![0., 1., 1.], vector![0.7, 0.7, 2.5], 0.1)),
    ];
    let mut pixels: ImageBuffer<image::Rgb<u8>, _> = ImageBuffer::new(width, height);

    for (i, x) in linspace::<Real>(-1., 1., width as usize).enumerate() {
        for (j, y) in linspace::<Real>(-1. / ratio, 1. / ratio, height as usize).enumerate() {
            let pixel_pos = vector![x, y, 0.];
            let direction = (pixel_pos - camera_pos).normalize();
            let primary_ray = Ray::new(camera_pos, direction);

            let (nearest_object, t_hit) =
                match find_closest_intersecting_object(&scene_objects, &primary_ray) {
                    (Some(object), t) => (object, t),
                    // Nothing was hit: leave the pixel untouched.
                    (None, _) => continue,
                };

            let intersection = primary_ray.direction * t_hit + camera_pos;
            let surface_normal = nearest_object.normal(intersection);
            // Move out intersection slightly to avoid self intersection problem.
            let shadow_ray_origin = intersection + 1e-5 * surface_normal;

            // Find light sources which are not shadowed by an object.
            let mut light_rays: Vec<Ray> = Vec::with_capacity(lights.len());
            let mut active_lights: Vec<&Light> = Vec::with_capacity(lights.len());
            for light in &lights {
                let to_light = light.position - shadow_ray_origin;
                let light_ray = Ray::new(shadow_ray_origin, to_light.normalize());
                let (_shadowing_object, t_blocker) =
                    find_closest_intersecting_object(&scene_objects, &light_ray);
                // Both distances are measured from the shadow ray's origin, so the
                // comparison is consistent. Objects behind the light do not shadow it.
                let is_shadowed = t_blocker < to_light.norm();
                if !is_shadowed {
                    light_rays.push(light_ray);
                    active_lights.push(light);
                }
            }

            let color: Vector3<Real> = phong_shading(
                &light_rays,
                active_lights,
                surface_normal,
                -primary_ray.direction,
                nearest_object,
            );

            let scaled = color * 255.;
            // Saturating cast: values outside 0..=255 are clamped.
            let rgb_value: [u8; 3] = [scaled.x, scaled.y, scaled.z].map(|channel| channel as u8);

            // Flip vertically: image row 0 is the top, while `y` grows upwards.
            pixels.put_pixel(i as u32, height - 1 - j as u32, image::Rgb(rgb_value));
        }
    }

    pixels
        .save("rt_image.png")
        .expect("failed to write rt_image.png");
}
/// Renders the scene and prints how long rendering took.
fn main() {
    let timer = Instant::now();
    render();
    println!("{:?}", timer.elapsed());
}