We recently released a new crate, tch (tch-rs github repo), providing Rust bindings for PyTorch via the C++ API (libtorch). The bindings provide a NumPy-like tensor library with GPU acceleration and support for automatic differentiation.
It's still a bit experimental and quickly evolving, but the current version can already be used to train convnet models on the CIFAR-10 dataset on a GPU, or recurrent neural networks on text data. These examples can be found in the github repo. There is also a small tutorial using the MNIST dataset.
To give an idea of what this looks like, here is the training loop for the convnet example on CIFAR-10:
for epoch in 1..150 {
    opt.set_lr(learning_rate(epoch));
    for (bimages, blabels) in m.train_iter(64).shuffle().to_device(vs.device()) {
        let bimages = tch::vision::dataset::augmentation(&bimages, true, 4, 8);
        let loss = net
            .forward_t(&bimages, true)
            .cross_entropy_for_logits(&blabels);
        opt.backward_step(&loss);
    }
    let test_accuracy =
        net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 512);
    println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy);
}
Any feedback is very welcome. Obviously there is a lot of polishing to be done; more documentation, examples, and tutorials will be added in the next few weeks.
I'm commenting on this for the selfish reason of bumping this post so more people see this library, start to use it, and hopefully make it even better.
TLDR: This library is amazing. If you're doing ML/DL/Optimization in Rust, you need to check out this library.
Full utilization of the GPU (credit here may really belong to libtorch): with hand-written CUDA I'm only able to achieve roughly 75% GPU compute usage (meaning I'm using the memory bandwidth poorly), while with libtorch I'm easily hitting 90%+ compute usage.
There are more nice things to say, but I'll save them for later to bump the post again.
If you're looking to do AI / ML / DL / Optimization in Rust, I highly recommend you give this library a try.
Glad that you like it; the GPU optimization credit certainly goes to libtorch and not to tch-rs.
Matching the PyTorch API as closely as possible is indeed a design goal, and it should hopefully make it easy to port PyTorch tutorials to Rust. For example, here is a Rust version of the translation tutorial from the PyTorch website, and there is also a version of YOLO-v3 for object detection. We will try adding some small write-ups for these examples so that they are easier to re-use.
Also, just to mention that there are a couple of known issues, e.g. i64 is often used instead of usize for indexes (github issue), and there is possible UB around the use of &mut (github issue), but there are plans to improve this in the near future.