Help with PyO3 bindings: the trait `PyClass` is not implemented for `VarBuilderArgs<

I'm new to Rust and experimenting with PyO3. I'm trying to implement a high-level API for the Hugging Face Candle library and generate Python bindings for it, using the maturin tool.

Here is what my code looks like:

```rust
use pyo3::prelude::*;

use candle_core::{Tensor, Result};
use candle_nn::{Module, Linear, VarBuilder};

const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;


pub trait Layer{
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

#[pyclass]
struct Sequential{
    linear1: Linear,
    linear2: Linear
}

#[pymethods]
impl Sequential{
    #[new]
    fn new(vb: VarBuilder) -> Result<Self> {
        let linear1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vb.pp("linear1"))?;
        let linear2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vb.pp("linear2"))?;
        Ok(Self {linear1, linear2})
    }
}


impl Layer for Sequential{
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.linear1.forward(xs)?;
        self.linear2.forward(&xs)
    }
}

#[pymodule]
fn candle_keras(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Sequential>()?;
    Ok(())
}
```

When I build this code with `maturin build`, I get the following error:

```text
🔗 Found pyo3 bindings
🐍 Found CPython 3.10 at /home/codespace/.python/current/bin/python3
📡 Using build options features from pyproject.toml
   Compiling candle_keras v0.1.0 (/workspaces/candle-keras)
error[E0277]: the trait bound `VarBuilderArgs<'_, Box<dyn SimpleBackend>>: PyClass` is not satisfied
  --> src/lib.rs:27:16
   |
27 |     fn new(vb: VarBuilder) -> Result<Self> {
   |                ^^^^^^^^^^ the trait `PyClass` is not implemented for `VarBuilderArgs<'_, Box<dyn SimpleBackend>>`
   |
   = help: the trait `PyClass` is implemented for `Sequential`
   = note: required for `VarBuilderArgs<'_, Box<dyn SimpleBackend>>` to implement `FromPyObject<'_>`
   = note: required for `VarBuilderArgs<'_, Box<dyn SimpleBackend>>` to implement `PyFunctionArgument<'_, '_>`
note: required by a bound in `extract_argument`
  --> /home/codespace/.cargo/registry/src/index.crates.io-6f17d22bba15001f/pyo3-0.19.2/src/impl_/extract_argument.rs:86:8
   |
80 | pub fn extract_argument<'a, 'py, T>(
   |        ---------------- required by a bound in this function
...
86 |     T: PyFunctionArgument<'a, 'py>,
   |        ^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `extract_argument`

error[E0599]: no method named `convert` found for enum `Result` in the current scope
  --> src/lib.rs:24:1
   |
24 | #[pymethods]
   | ^^^^^^^^^^^^ method not found in `Result<Sequential, Error>`
   |
note: the method `convert` exists on the type `Sequential`
  --> /home/codespace/.cargo/registry/src/index.crates.io-6f17d22bba15001f/pyo3-0.19.2/src/callback.rs:31:5
   |
31 |     fn convert(self, py: Python<'_>) -> PyResult<Target>;
   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   = note: this error originates in the attribute macro `pymethods` (in Nightly builds, run with -Z macro-backtrace for more info)
help: use the `?` operator to extract the `Sequential` value, propagating a `Result::Err` value to the caller
   |
24 | #[pymethods]?
   |             +

Some errors have detailed explanations: E0277, E0599.
For more information about an error, try `rustc --explain E0277`.
error: could not compile `candle_keras` (lib) due to 2 previous errors
💥 maturin failed
  Caused by: Failed to build a native library through cargo
  Caused by: Cargo build finished with "exit status: 101": `PYO3_ENVIRONMENT_SIGNATURE="cpython-3.10-64bit" PYO3_PYTHON="/home/codespace/.python/current/bin/python3" PYTHON_SYS_EXECUTABLE="/home/codespace/.python/current/bin/python3" "cargo" "rustc" "--features" "pyo3/extension-module" "--message-format" "json-render-diagnostics" "--manifest-path" "/workspaces/candle-keras/Cargo.toml" "--lib"`
```

Could someone help me out?

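What it boils down to is that `candle_nn::var_builder::VarBuilder` doesn't implement `PyClass`. Since neither `VarBuilder` nor `PyClass` lives in your own crate, Rust's orphan rule prevents you from implementing `PyClass` for it yourself. I guess your best option is to write a newtype wrapping `VarBuilder<'static>` (a `#[pyclass]` can't have lifetime parameters), mark that wrapper with `#[pyclass]` and pass it to `Sequential`. Other than that, `candle_core::error::Error` also doesn't implement `Into<PyErr>`, so returning candle's `Result<Self>` from your constructor won't work either; that is what the second error (`no method named convert`) is about. Return `PyResult<Self>` instead and map candle's errors into a `PyErr`. Here is a minimal snippet that compiles on my machine: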

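```rust
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;

use candle_nn::Linear;

const VOTE_DIM: usize = 2;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;

// Newtype defined in this crate, so #[pyclass] is allowed on it.
// 'static because a #[pyclass] cannot carry lifetime parameters.
#[pyclass]
struct VarBuilder(candle_nn::var_builder::VarBuilder<'static>);

#[pyclass]
struct Sequential {
    linear1: Linear,
    linear2: Linear,
}

#[pymethods]
impl Sequential {
    // Takes the wrapper by reference and returns PyResult, so a failure
    // surfaces in Python as an exception.
    #[new]
    fn new(vb: &VarBuilder) -> PyResult<Self> {
        // candle's error doesn't convert into PyErr, so map it explicitly
        // (this fixed message drops candle's original error text).
        let linear1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vb.0.pp("linear1"))
            .map_err(|_| PyErr::new::<PyTypeError, _>("Error message"))?;

        let linear2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vb.0.pp("linear2"))
            .map_err(|_| PyErr::new::<PyTypeError, _>("Error message"))?;

        Ok(Self { linear1, linear2 })
    }
}
```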
