c++ - Template 模板参数和容器 Value 类型

问题背景:

我正在开发一个处理复杂类型张量的 PyTorch CUDA 扩展。我有以下用于启动 CUDA 内核的代码片段

AT_DISPATCH_COMPLEX_TYPES(x.scalar_type(), "my_kernel_function_cuda",
    ([&] {
        my_kernel_function<scalar_t><<<gridDim, blockDim>>>(
            x_.data_ptr<scalar_t>(), h_.data_ptr<scalar_t>(), o_.data_ptr<scalar_t>()
        );
    })
);

https://github.com/pytorch/pytorch/blob/8d08b103be936d78d5d4ed90c0547aeccb8ce166/aten/src/ATen/Dispatch.h#L422 几乎可以为 c10:complex<float>c10:complex<double> 注册和调度函数。

使用模板的内核函数通常是这样的

template<typename scalar_t>
__global__ void my_kernel_function (
    const scalar_t* __restrict__ x, const scalar_t* __restrict__ h,
    scalar_t* __restrict__ o
) {
  // Function Body
}

问题问题:

我需要 store 这些复杂的 values 在 __shared__ 内存中。但是,当 scaler_t 不是基本类型(intfloatdouble 等)时,编译器不喜欢使用 __shared__ scalar_t As[32]

假设编译器最终生成了两个函数 scalar_t -> c10::complex<float>scalar_t -> c10::complex<double>。我正在寻找如何提取 c10::complex::value_type,我们称它为 ctype 在我的内核代码中包含类似

...
__shared__ ctype As_real[32];
__shared__ ctype As_imag[32];
...

我失败的尝试

根据https://stackoverflow.com/a/2024173/2313889https://stackoverflow.com/a/213811/2313889 但没有任何成功。这表明我没有完全理解模板模板参数是如何工作的。

这是我尝试为我的内核提供一个模板功能代码的尝试之一。

template < template < typename > class ComplexContainer, typename ComplexType>
__global__ void my_kernel_function (
    const ComplexContainer<ComplexType> * __restrict__ x, const ComplexContainer<ComplexType>* __restrict__ h,
    ComplexContainer<ComplexType>* __restrict__ o
) {
    printf("Template of template with complex class\n");
}

它确实编译,但它绝对不是我需要的解决方案,因为编译器找不到它

error: no instance of function template "<unnamed>::my_kernel_function" matches the argument list
            argument types are: (c10::complex<float> *, c10::complex<float> *, c10::complex<float> *)

error: no instance of function template "<unnamed>::my_kernel_function" matches the argument list
            argument types are: (c10::complex<double> *, c10::complex<double> *, c10::complex<double> *)

当前替代方案:

作为最后一个资源,我可以使用模板专业化。但是,我会重复代码只是为了更改数据类型,如下所示。

template<typename scalar_t>
__global__ void my_kernel_function (
    const scalar_t* __restrict__ x, const scalar_t* __restrict__ h,
    scalar_t* __restrict__ o
) {
  printf("General Case. It should not be invoked\n");
}

using cfloat = c10::complex<float>;
using cdouble = c10::complex<double>;

template<>
__global__ void my_kernel_function (
    const cfloat* __restrict__ x, const cfloat* __restrict__ h,
    cfloat* __restrict__ o
) {
    __shared__ float As_real[32];
    __shared__ float As_imag[32];
    printf("Specialization 1\n");
}

template<>
__global__ void my_kernel_function (
    const cdouble* __restrict__ x, const cdouble* __restrict__ h,
    cdouble* __restrict__ o
) {
    __shared__ double As_real[32];
    __shared__ double As_imag[32];
    printf("Specialization 2\n");
}

回答1

这个 https://docs.microsoft.com/en-us/cpp/cpp/typename?view=msvc-170#example 可能是您正在寻找的。应用于您的案例,您可以使用 https://fossies.org/dox/pytorch-1.11.0/structc10_1_1complex.html#a7ded384f1e0b9ae867aff68100ebf154,它会像:

template <class complex_t>
__global__ void my_kernel_function (
    const complex_t * __restrict__ x, const complex_t* __restrict__ h,
    complex_t* __restrict__ o
) {
    __shared__ typename complex_t::value_type As_real[32];
    __shared__ typename complex_t::value_type As_imag[32];
    printf("Template of template with complex class: %s --- sizeof: %d\n", __PRETTY_FUNCTION__, sizeof(typename complex_t::value_type));
}

当使用 float 时,每个线程的输出为

Template of template with complex class: void <unnamed>::my_kernel(const scalar_t *) [with scalar_t = c10::complex<float>] --- sizeof: 4

当使用 double 时,每个线程的输出为

...
Template of template with complex class: void <unnamed>::my_kernel(const scalar_t *) [with scalar_t = c10::complex<double>] --- sizeof: 8
...

相似文章

随机推荐

最新文章