泛型#
Warp 支持编写泛型内核和函数,它们充当可以使用不同具体类型实例化的模板。这允许您编写一次代码,并将其与多种数据类型一起重用。此页面上讨论的概念也适用于 运行时内核创建。
泛型内核#
泛型内核定义语法与常规内核相同,但您可以使用 typing.Any
代替具体类型
from typing import Any
# generic kernel definition using Any as a placeholder for concrete types
@wp.kernel
def scale(x: wp.array(dtype=Any), s: Any):
i = wp.tid()
x[i] = s * x[i]
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
n = len(data)
x16 = wp.array(data, dtype=wp.float16)
x32 = wp.array(data, dtype=wp.float32)
x64 = wp.array(data, dtype=wp.float64)
# run the generic kernel with different data types
wp.launch(scale, dim=n, inputs=[x16, wp.float16(3)])
wp.launch(scale, dim=n, inputs=[x32, wp.float32(3)])
wp.launch(scale, dim=n, inputs=[x64, wp.float64(3)])
print(x16)
print(x32)
print(x64)
[ 3. 6. 9. 12. 15. 18. 21. 24. 27.]
[ 3. 6. 9. 12. 15. 18. 21. 24. 27.]
[ 3. 6. 9. 12. 15. 18. 21. 24. 27.]
在底层,Warp 将自动生成泛型内核的新实例以匹配给定的参数类型。
类型推断#
当启动泛型内核时,Warp 会从参数推断具体类型。wp.launch()
处理泛型内核时没有任何特殊语法,但我们应该注意作为参数传递的数据类型,以确保推断出正确的类型
标量可以作为常规 Python 数值传递(例如,
42
或0.5
)。Python 整数被解释为wp.int32
,Python 浮点数值被解释为wp.float32
。要指定不同的数据类型并避免歧义,应使用 Warp 数据类型(例如,wp.int64(42)
或wp.float16(0.5)
)。向量和矩阵应作为 Warp 类型传递,而不是元组或列表(例如,
wp.vec3f(1.0, 2.0, 3.0)
或wp.mat22h([[1.0, 0.0], [0.0, 1.0]])
)。Warp 数组和结构体可以正常传递。
隐式实例化#
当您使用一组新的数据类型启动泛型内核时,Warp 会自动使用给定的类型创建此内核的新实例。这很方便,但这种隐式实例化存在一些缺点。
考虑以下三个泛型内核启动
wp.launch(scale, dim=n, inputs=[x16, wp.float16(3)])
wp.launch(scale, dim=n, inputs=[x32, wp.float32(3)])
wp.launch(scale, dim=n, inputs=[x64, wp.float64(3)])
在每次启动期间,都会生成一个新的内核实例,这会强制重新加载模块。您可能会在输出中看到类似这样的内容
Module __main__ load on device 'cuda:0' took 170.37 ms
Module __main__ load on device 'cuda:0' took 171.43 ms
Module __main__ load on device 'cuda:0' took 179.49 ms
这会导致几个潜在问题
重复重建模块的开销会影响程序的整体性能。
在旧的 CUDA 驱动程序上不允许在图形捕获期间重新加载模块,这将导致捕获失败。
显式实例化可用于克服这些问题。
显式实例化#
Warp 允许使用不同的类型显式声明泛型内核的实例。一种方法是使用 @wp.overload
装饰器
@wp.overload
def scale(x: wp.array(dtype=wp.float16), s: wp.float16):
...
@wp.overload
def scale(x: wp.array(dtype=wp.float32), s: wp.float32):
...
@wp.overload
def scale(x: wp.array(dtype=wp.float64), s: wp.float64):
...
wp.launch(scale, dim=n, inputs=[x16, wp.float16(3)])
wp.launch(scale, dim=n, inputs=[x32, wp.float32(3)])
wp.launch(scale, dim=n, inputs=[x64, wp.float64(3)])
@wp.overload
装饰器允许重新声明泛型内核,而无需重复内核代码。内核主体只是被省略号 (...
) 替换。Warp 会跟踪每个内核的已知重载,因此如果重载存在,它将不会再次实例化。如果在内核启动之前声明了所有重载,则模块将仅加载一次,并将所有内核实例放置到位。
我们还可以使用 wp.overload()
作为函数,以获得稍微更简洁的语法。我们只需要指定泛型内核和具体参数类型的列表
wp.overload(scale, [wp.array(dtype=wp.float16), wp.float16])
wp.overload(scale, [wp.array(dtype=wp.float32), wp.float32])
wp.overload(scale, [wp.array(dtype=wp.float64), wp.float64])
也可以提供字典来代替参数列表
wp.overload(scale, {"x": wp.array(dtype=wp.float16), "s": wp.float16})
wp.overload(scale, {"x": wp.array(dtype=wp.float32), "s": wp.float32})
wp.overload(scale, {"x": wp.array(dtype=wp.float64), "s": wp.float64})
字典可能更易于阅读。使用字典时,只需要指定泛型参数,当重载某些参数不是泛型的内核时,这可能会更简洁。
我们可以像这样在一个循环中轻松创建重载
for T in [wp.float16, wp.float32, wp.float64]:
wp.overload(scale, [wp.array(dtype=T), T])
最后,wp.overload()
函数返回具体的内核实例,可以将其保存在变量中
scale_f16 = wp.overload(scale, [wp.array(dtype=wp.float16), wp.float16])
scale_f32 = wp.overload(scale, [wp.array(dtype=wp.float32), wp.float32])
scale_f64 = wp.overload(scale, [wp.array(dtype=wp.float64), wp.float64])
这些实例被视为常规内核,而不是泛型内核。这意味着启动应该更快,因为 Warp 不需要像启动泛型内核时那样从参数推断数据类型。内核参数的类型要求也比泛型内核更宽松,因为 Warp 可以将标量、向量和矩阵转换为已知的所需类型。
# launch concrete kernel instances
wp.launch(scale_f16, dim=n, inputs=[x16, 3])
wp.launch(scale_f32, dim=n, inputs=[x32, 3])
wp.launch(scale_f64, dim=n, inputs=[x64, 3])
泛型函数#
与 Warp 内核一样,我们也可以定义泛型 Warp 函数
# generic function
@wp.func
def f(x: Any):
return x * x
# use generic function in a regular kernel
@wp.kernel
def square_float(a: wp.array(dtype=float)):
i = wp.tid()
a[i] = f(a[i])
# use generic function in a generic kernel
@wp.kernel
def square_any(a: wp.array(dtype=Any)):
i = wp.tid()
a[i] = f(a[i])
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
n = len(data)
af = wp.array(data, dtype=float)
ai = wp.array(data, dtype=int)
# launch regular kernel
wp.launch(square_float, dim=n, inputs=[af])
print(af)
# launch generic kernel
wp.launch(square_any, dim=n, inputs=[af])
print(af)
wp.launch(square_any, dim=n, inputs=[ai])
print(ai)
[ 1. 4. 9. 16. 25. 36. 49. 64. 81.]
[1.000e+00 1.600e+01 8.100e+01 2.560e+02 6.250e+02 1.296e+03 2.401e+03
4.096e+03 6.561e+03]
[ 1 4 9 16 25 36 49 64 81]
泛型函数可以在常规内核和泛型内核中使用。没有必要显式重载泛型函数。当这些函数在内核中使用时,会自动生成所有必需的函数重载。
type() 运算符#
考虑以下泛型函数
@wp.func
def triple(x: Any):
return 3 * x
由于 Warp 严格的类型规则,在泛型表达式中使用数字字面量(如 3
)是有问题的。算术表达式中的操作数必须具有相同的数据类型,但整数文字总是被视为 wp.int32
。如果 x
具有除 wp.int32
之外的数据类型,则此函数将无法编译,这意味着它根本不是泛型的。
type()
运算符可以解决这个问题。type()
运算符返回其参数的类型,这在泛型函数或内核中非常方便,因为这些函数或内核中的数据类型是预先未知的。我们可以像这样重写函数,使其适用于更广泛的类型
@wp.func
def triple(x: Any):
return type(x)(3) * x
type()
运算符对于 Warp 内核和函数中的类型转换非常有用。例如,这是一个简单的泛型 arange()
内核
@wp.kernel
def arange(a: wp.array(dtype=Any)):
i = wp.tid()
a[i] = type(a[0])(i)
n = 10
ai = wp.empty(n, dtype=wp.int32)
af = wp.empty(n, dtype=wp.float32)
wp.launch(arange, dim=n, inputs=[ai])
wp.launch(arange, dim=n, inputs=[af])
wp.tid()
返回一个整数,但该值在存储到数组中之前会转换为数组的数据类型。或者,我们可以像这样编写我们的 arange()
内核
@wp.kernel
def arange(a: wp.array(dtype=Any)):
i = wp.tid()
a[i] = a.dtype(i)
此变体使用 array.dtype()
运算符,该运算符返回数组内容的类型。
局限性和粗糙之处#
Warp 泛型仍在开发中,并且存在一些限制。
模块重新加载行为#
如 隐式实例化 部分所述,启动新的内核重载会触发内核模块的重新编译。这增加了开销,并且与 Warp 当前的内核缓存策略不太兼容。内核缓存依赖于散列模块的内容,其中包括到目前为止在 Python 程序中遇到的所有具体内核和函数。每当添加新的内核或泛型内核的新实例时,都需要重新加载模块。重新运行 Python 程序会导致相同的内核序列被添加到模块中,这意味着泛型内核的隐式实例化将在每次运行时触发相同的模块重新加载。这显然不是理想的,我们计划将来改进此行为。
只要重载在任何内核启动之前以相同的顺序添加,使用 显式实例化 通常是解决此问题的好方法。
请注意,此问题并非特定于泛型内核。如果在内核定义与内核启动混合在一起,则将新的常规内核添加到模块中也会触发重复的模块重新加载。例如
@wp.kernel
def foo(x: float):
wp.print(x)
wp.launch(foo, dim=1, inputs=[17])
@wp.kernel
def bar(x: float):
wp.print(x)
wp.launch(bar, dim=1, inputs=[42])
此代码还将在每次内核启动期间触发模块重新加载,即使它根本不使用泛型
Module __main__ load on device 'cuda:0' took 155.73 ms
17
Module __main__ load on device 'cuda:0' took 164.83 ms
42
图捕获#
在 CUDA 12.2 或更早版本中,图捕获期间不允许模块重载。内核实例化可能会触发模块重载,这会导致在不支持较新 CUDA 版本的驱动程序上图捕获失败。 解决方法是再次在捕获开始之前显式声明所需的重载。
类型变量#
Warp 的 type()
操作符在原理上类似于 Python 的 type()
函数,但目前无法在 Warp 内核和函数中使用类型作为变量。 例如,目前 不 允许以下操作
@wp.func
def triple(x: Any):
# TODO:
T = type(x)
return T(3) * x
内核重载限制#
目前无法定义具有相同名称但参数数量不同的多个内核,但此限制将来可能会解除。