Hello,
I am converting the C++ snippet below to Rust. My approach was to create a trait AccessorTrait and two structs, MultiDimensionalTensorAccessor and SingleDimensionalTensorAccessor, both of which implement AccessorTrait.
However, I can't figure out how to declare a method get(index: usize) in the trait, analogous to operator[], that works for both structs: indexing the multi-dimensional accessor returns a lower-dimensional accessor, whereas indexing the single-dimensional one returns an element.
/// State shared by every TensorAccessor rank: a data pointer and a pointer to
/// an array of strides. The derived accessors read the leading stride from
/// strides_[0] and pass strides_ + 1 down when they index.
///
/// Both pointers are stored as given. The class has no destructor and frees
/// nothing, so the caller must keep the pointed-to data and stride array alive
/// for as long as the accessor is used.
template <typename T,
          size_t N,
          template <typename U> class PtrTraits = DefaultPtrTraits,
          typename index_t = int64_t>
class TensorAccessorBase {
 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  TensorAccessorBase(PtrType data, const index_t* strides)
      : data_(data), strides_(strides) {}

 protected:
  PtrType data_;            // base address of the (sub)view
  const index_t* strides_;  // strides_[0] is the stride of the leading dimension
};
// In Rust this can be implemented as MultiDimensionalTensorAccessor<const N: usize>,
// because this primary template is the N > 1 case.
/// Accessor for a view of rank N, intended for N > 1 (N == 1 is handled by the
/// partial specialization below). Indexing peels off the leading dimension and
/// yields an accessor of rank N - 1 that points at the selected slice.
///
/// Rust note: operator[] returns a different type per rank (a smaller accessor
/// here, an element reference for N == 1). A Rust trait that models it
/// therefore needs an associated result type instead of one fixed return type.
/// std::ops::Index does not fit the N > 1 case, because it must return a
/// reference rather than a new accessor value.
template <typename T,
          size_t N,
          template <typename U> class PtrTraits = DefaultPtrTraits,
          typename index_t = int64_t>
class TensorAccessor : public TensorAccessorBase<T, N, PtrTraits, index_t> {
 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  TensorAccessor(PtrType data, const index_t* strides)
      : TensorAccessorBase<T, N, PtrTraits, index_t>(data, strides) {}

  /// Returns the rank-(N-1) sub-accessor for index i along the leading
  /// dimension. Its data pointer is advanced by strides_[0] * i, and its stride
  /// array is the remainder of this one (strides_ + 1). No bounds checking is
  /// performed.
  /// In Rust: MultiDimensionalTensorAccessor<{N - 1}> when N - 1 > 1, otherwise
  /// SingleDimensionalTensorAccessor.
  TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
        this->data_ + this->strides_[0] * i, this->strides_ + 1);
  }

  /// Read-only overload, so that a const accessor can be indexed. The
  /// sub-accessor is returned as const on purpose: further indexing then also
  /// selects the const overload, and at rank 1 it yields const T&.
  const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](
      index_t i) const {
    return TensorAccessor<T, N - 1, PtrTraits, index_t>(
        this->data_ + this->strides_[0] * i, this->strides_ + 1);
  }
};

/// Rank-1 specialization: indexing yields a reference to an element instead of
/// another accessor. In Rust this corresponds to SingleDimensionalTensorAccessor.
template <typename T,
          template <typename U> class PtrTraits,
          typename index_t>
class TensorAccessor<T, 1, PtrTraits, index_t>
    : public TensorAccessorBase<T, 1, PtrTraits, index_t> {
 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  TensorAccessor(PtrType data, const index_t* strides)
      : TensorAccessorBase<T, 1, PtrTraits, index_t>(data, strides) {}

  /// Returns the element at data_[strides_[0] * i]. No bounds checking is
  /// performed.
  T& operator[](index_t i) { return this->data_[this->strides_[0] * i]; }

  /// Read-only overload for const accessors.
  const T& operator[](index_t i) const {
    return this->data_[this->strides_[0] * i];
  }
};
Any help would save my day. Thank you!