New crate: tch, PyTorch bindings for Rust

We recently released a new crate, tch (GitHub repo: tch-rs), providing Rust bindings for PyTorch via its C++ API (libtorch). These bindings provide a NumPy-like tensor library with GPU acceleration and support for automatic differentiation.
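
For readers new to automatic differentiation, here is a toy sketch of the idea in plain Rust (std only, no tch). Values are recorded on a tape as they are computed, and gradients are then propagated backwards through that tape. This is reverse mode, the approach libtorch's autograd takes. The `Tape` type and its methods are made up for illustration and are not part of tch's API.

```rust
// Toy reverse-mode automatic differentiation over scalars.
// Hypothetical illustration only: NOT tch's API, just the idea behind autograd.

#[derive(Clone, Copy)]
enum Op {
    Leaf,              // an input value, no parents
    Add(usize, usize), // indices of the two operands on the tape
    Mul(usize, usize),
}

struct Tape {
    values: Vec<f64>,
    ops: Vec<Op>,
}

impl Tape {
    fn new() -> Self {
        Tape { values: Vec::new(), ops: Vec::new() }
    }

    fn push(&mut self, v: f64, op: Op) -> usize {
        self.values.push(v);
        self.ops.push(op);
        self.values.len() - 1
    }

    fn leaf(&mut self, v: f64) -> usize {
        self.push(v, Op::Leaf)
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, Op::Add(a, b))
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] * self.values[b];
        self.push(v, Op::Mul(a, b))
    }

    /// Returns d(output)/d(node) for every node. Operands always sit earlier
    /// on the tape, so walking it backwards visits nodes in reverse topological order.
    fn backward(&self, output: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output] = 1.0;
        for i in (0..=output).rev() {
            let g = grads[i];
            match self.ops[i] {
                Op::Leaf => {}
                // d(a + b)/da = 1, d(a + b)/db = 1
                Op::Add(a, b) => {
                    grads[a] += g;
                    grads[b] += g;
                }
                // d(a * b)/da = b, d(a * b)/db = a
                Op::Mul(a, b) => {
                    grads[a] += g * self.values[b];
                    grads[b] += g * self.values[a];
                }
            }
        }
        grads
    }
}

fn main() {
    // f(x, y) = x * y + x, so df/dx = y + 1 and df/dy = x.
    let mut tape = Tape::new();
    let x = tape.leaf(3.0);
    let y = tape.leaf(4.0);
    let xy = tape.mul(x, y);
    let f = tape.add(xy, x);
    let grads = tape.backward(f);
    println!("f = {}, df/dx = {}, df/dy = {}", tape.values[f], grads[x], grads[y]);
}
```

In tch the same bookkeeping happens inside libtorch on whole tensors (possibly on the GPU), which is why a single `backward` call on the loss is enough to get gradients for every trainable variable.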

It’s still a bit experimental and evolving quickly, but the current version can already be used to train convnet models on the CIFAR-10 dataset on a GPU, or recurrent neural networks on text data. These examples can be found in the GitHub repo. There is also a small tutorial using the MNIST dataset.

To give an idea of what this looks like, here is the training loop for the convnet example on CIFAR-10:

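    // Train for 149 epochs (the range 1..150 excludes 150), updating the learning rate each epoch.
    for epoch in 1..150 {
        opt.set_lr(learning_rate(epoch));
        // Shuffled mini-batches of 64 images, moved to the device holding the weights.
        for (bimages, blabels) in m.train_iter(64).shuffle().to_device(vs.device()) {
            // Data augmentation: random flips, random crops (padding 4), cutout of size 8.
            let bimages = tch::vision::dataset::augmentation(&bimages, true, 4, 8);
            // `true` runs the model in training mode (matters for dropout/batch norm).
            let loss = net
                .forward_t(&bimages, true)
                .cross_entropy_for_logits(&blabels);
            // Zero the gradients, backpropagate, and apply one optimizer step.
            opt.backward_step(&loss);
        }
        // Evaluate on the test set, 512 images at a time.
        let test_accuracy =
            net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 512);
        println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy);
    }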
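
The loop relies on a few pieces the snippet doesn't show. Below is a std-only sketch of what they compute, not how tch implements them. `learning_rate` is hypothetical: the post doesn't give its definition, so the 50/100 breakpoints and the rates are assumptions. `cross_entropy_for_logits` is the usual softmax cross-entropy averaged over the batch. `batch_accuracy` (a made-up name) mirrors the batched evaluation: it runs a forward function over chunks of the test set so that only `batch_size` inputs are processed at once.

```rust
// Std-only sketches of what the training loop relies on.
// NOT tch's implementation; plain Vecs stand in for Tensors.

/// HYPOTHETICAL step schedule: breakpoints and rates are assumptions.
fn learning_rate(epoch: i64) -> f64 {
    if epoch < 50 {
        0.1
    } else if epoch < 100 {
        0.01
    } else {
        0.001
    }
}

/// Mean softmax cross-entropy: log(sum_j exp(z_j)) - z_label per example,
/// using the max-subtraction trick so exp() cannot overflow.
fn cross_entropy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = logits
        .iter()
        .zip(labels)
        .map(|(row, &label)| {
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max + row.iter().map(|z| (z - max).exp()).sum::<f64>().ln();
            log_sum_exp - row[label]
        })
        .sum();
    total / logits.len() as f64
}

/// Index of the largest logit, i.e. the predicted class (first one on ties).
fn argmax(row: &[f64]) -> usize {
    let mut best = 0;
    for (i, &z) in row.iter().enumerate() {
        if z > row[best] {
            best = i;
        }
    }
    best
}

/// HYPOTHETICAL helper: fraction of correct predictions, running `forward`
/// on at most `batch_size` inputs at a time to bound memory use.
fn batch_accuracy<F>(forward: F, inputs: &[Vec<f64>], labels: &[usize], batch_size: usize) -> f64
where
    F: Fn(&[Vec<f64>]) -> Vec<Vec<f64>>,
{
    let mut correct = 0;
    for (xs, ys) in inputs.chunks(batch_size).zip(labels.chunks(batch_size)) {
        let logits = forward(xs);
        correct += logits.iter().zip(ys).filter(|pair| argmax(pair.0) == *pair.1).count();
    }
    correct as f64 / labels.len() as f64
}

fn main() {
    for epoch in [1, 75, 149] {
        println!("epoch {:3}: lr = {}", epoch, learning_rate(epoch));
    }
    // Treat the inputs as precomputed logits by using the identity as "model".
    let logits = vec![vec![2.0, 0.5, -1.0], vec![0.1, 0.3, 0.2]];
    let labels = vec![0, 2];
    println!("loss = {:.4}", cross_entropy_for_logits(&logits, &labels));
    let acc = batch_accuracy(|xs: &[Vec<f64>]| xs.to_vec(), &logits, &labels, 512);
    println!("accuracy = {:.2}%", 100. * acc);
}
```

In tch these operate on `Tensor`s, possibly on the GPU; plain `Vec`s are used here only to keep the sketch self-contained.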

Any feedback is very welcome. Obviously there is still a lot of polishing to be done; more documentation, examples, and tutorials will be added over the next few weeks.
