问题背景:
我正在开发一个处理复杂类型张量的 PyTorch CUDA 扩展。我有以下用于启动 CUDA 内核的代码片段
AT_DISPATCH_COMPLEX_TYPES(x.scalar_type(), "my_kernel_function_cuda",
([&] {
my_kernel_function<scalar_t><<<gridDim, blockDim>>>(
x_.data_ptr<scalar_t>(), h_.data_ptr<scalar_t>(), o_.data_ptr<scalar_t>()
);
})
);
https://github.com/pytorch/pytorch/blob/8d08b103be936d78d5d4ed90c0547aeccb8ce166/aten/src/ATen/Dispatch.h#L422 几乎可以为 c10:complex<float>
和 c10:complex<double>
注册和调度函数。
使用模板的内核函数通常是这样的
template<typename scalar_t>
__global__ void my_kernel_function (
const scalar_t* __restrict__ x, const scalar_t* __restrict__ h,
scalar_t* __restrict__ o
) {
// Function Body
}
问题问题:
我需要 store 这些复杂的 values 在 __shared__
内存中。但是,当 scaler_t
不是基本类型(int
、float
、double
等)时,编译器不喜欢使用 __shared__ scalar_t As[32]
。
假设编译器最终生成了两个函数 scalar_t
-> c10::complex<float>
和 scalar_t
-> c10::complex<double>
。我正在寻找如何提取 c10::complex::value_type
,我们称它为 ctype
在我的内核代码中包含类似
...
__shared__ ctype As_real[32];
__shared__ ctype As_imag[32];
...
我失败的尝试
根据https://stackoverflow.com/a/2024173/2313889和https://stackoverflow.com/a/213811/2313889 但没有任何成功。这表明我没有完全理解模板模板参数是如何工作的。
这是我尝试为我的内核提供一个模板功能代码的尝试之一。
template < template < typename > class ComplexContainer, typename ComplexType>
__global__ void my_kernel_function (
const ComplexContainer<ComplexType> * __restrict__ x, const ComplexContainer<ComplexType>* __restrict__ h,
ComplexContainer<ComplexType>* __restrict__ o
) {
printf("Template of template with complex class\n");
}
它确实编译,但它绝对不是我需要的解决方案,因为编译器找不到它
error: no instance of function template "<unnamed>::my_kernel_function" matches the argument list
argument types are: (c10::complex<float> *, c10::complex<float> *, c10::complex<float> *)
error: no instance of function template "<unnamed>::my_kernel_function" matches the argument list
argument types are: (c10::complex<double> *, c10::complex<double> *, c10::complex<double> *)
当前替代方案:
作为最后一个资源,我可以使用模板专业化。但是,我会重复代码只是为了更改数据类型,如下所示。
template<typename scalar_t>
__global__ void my_kernel_function (
const scalar_t* __restrict__ x, const scalar_t* __restrict__ h,
scalar_t* __restrict__ o
) {
printf("General Case. It should not be invoked\n");
}
using cfloat = c10::complex<float>;
using cdouble = c10::complex<double>;
template<>
__global__ void my_kernel_function (
const cfloat* __restrict__ x, const cfloat* __restrict__ h,
cfloat* __restrict__ o
) {
__shared__ float As_real[32];
__shared__ float As_imag[32];
printf("Specialization 1\n");
}
template<>
__global__ void my_kernel_function (
const cdouble* __restrict__ x, const cdouble* __restrict__ h,
cdouble* __restrict__ o
) {
__shared__ double As_real[32];
__shared__ double As_imag[32];
printf("Specialization 2\n");
}
回答1
这个 https://docs.microsoft.com/en-us/cpp/cpp/typename?view=msvc-170#example 可能是您正在寻找的。应用于您的案例,您可以使用 https://fossies.org/dox/pytorch-1.11.0/structc10_1_1complex.html#a7ded384f1e0b9ae867aff68100ebf154,它会像:
template <class complex_t>
__global__ void my_kernel_function (
const complex_t * __restrict__ x, const complex_t* __restrict__ h,
complex_t* __restrict__ o
) {
__shared__ typename complex_t::value_type As_real[32];
__shared__ typename complex_t::value_type As_imag[32];
printf("Template of template with complex class: %s --- sizeof: %d\n", __PRETTY_FUNCTION__, sizeof(typename complex_t::value_type));
}
当使用 float
时,每个线程的输出为
Template of template with complex class: void <unnamed>::my_kernel(const scalar_t *) [with scalar_t = c10::complex<float>] --- sizeof: 4
当使用 double
时,每个线程的输出为
...
Template of template with complex class: void <unnamed>::my_kernel(const scalar_t *) [with scalar_t = c10::complex<double>] --- sizeof: 8
...