代码生成#

概述#

Warp 核函数按 Python 模块分组。在设备上运行之前,必须将其翻译并编译成适用于该设备架构的代码。模块中的所有核函数会一起编译,这比单独编译每个核函数要快。启动核函数时,Warp 会检查模块是否最新,如果需要,会对其进行编译。在运行时向模块添加新的核函数会修改模块,这意味着在下次启动时需要重新加载该模块。

@wp.kernel
def kernel_foo():
    print("foo")

wp.launch(kernel_foo, dim=1)

@wp.kernel
def kernel_bar():
    print("bar")

wp.launch(kernel_bar, dim=1)

在上面的代码片段中,核函数定义与核函数启动穿插进行。为了执行 kernel_foo,模块会在第一次启动时进行编译。定义 kernel_bar 修改了模块,因此在第二次启动时需要重新编译。

Module __main__ 6cd1d53 load on device 'cuda:0' took 168.19 ms  (compiled)
foo
Module __main__ c7c0e9a load on device 'cuda:0' took 160.35 ms  (compiled)
bar

对于包含大量复杂核函数的模块,编译可能需要很长时间,因此 Warp 会缓存已编译的模块,并在程序下次运行时重用它们。

Module __main__ 6cd1d53 load on device 'cuda:0' took 4.97 ms  (cached)
foo
Module __main__ c7c0e9a load on device 'cuda:0' took 0.40 ms  (cached)
bar

加载缓存模块要快得多,但并非没有开销。此外,模块重新加载可能会在 CUDA graph capture 期间引发问题,因此有充分的理由尽量避免重新加载。

避免模块重新加载的最佳方法是在启动任何核函数之前定义所有核函数。这样,模块就只会编译一次。

@wp.kernel
def kernel_foo():
    print("foo")

@wp.kernel
def kernel_bar():
    print("bar")

wp.launch(kernel_foo, dim=1)
wp.launch(kernel_bar, dim=1)
Module __main__ c7c0e9a load on device 'cuda:0' took 174.57 ms  (compiled)
foo
bar

在后续运行中,它只会从核函数缓存中加载一次。

Module __main__ c7c0e9a load on device 'cuda:0' took 4.96 ms  (cached)
foo
bar

为了避免不必要的模块重新加载,Warp 会尝试识别重复的核函数。例如,此程序在循环中创建核函数,但它们始终是相同的,因此模块无需在每次启动时重新编译。

for i in range(3):

    @wp.kernel
    def kernel_hello():
        print("hello")

    wp.launch(kernel_hello, dim=1)

Warp 会过滤掉重复的核函数,因此模块只会加载一次。

Module __main__ 8194f57 load on device 'cuda:0' took 178.24 ms  (compiled)
hello
hello
hello

Warp 为 CPU/GPU 生成 C++/CUDA 源代码,并将 .cpp/.cu 源文件存储在核函数缓存的模块目录下。Warp 初始化 期间会打印核函数缓存文件夹路径,并且在 Warp 初始化后可以通过 warp.config.kernel_cache_dir 配置设置 获取。

考虑以下示例

@wp.func
def my_func(a: float, b: float):
    c = wp.sin(b) * a
    return c

生成的 CUDA 代码看起来与此类似

// example.py:5
static CUDA_CALLABLE wp::float32 my_func_0(
    wp::float32 var_a,
    wp::float32 var_b)
{
    //---------
    // primal vars
    wp::float32 var_0;
    wp::float32 var_1;
    //---------
    // forward
    // def my_func(a: float, b: float):                                                       <L 6>
    // c = wp.sin(b) * a                                                                      <L 7>
    var_0 = wp::sin(var_b);
    var_1 = wp::mul(var_0, var_a);
    // return c                                                                               <L 8>
    return var_1;
}

生成代码遵循静态单赋值 (SSA) 形式。为了提高可读性,代码中插入了引用原始 Python 源代码行的注释。除了前向传播之外,还会生成梯度函数,如果提供了自定义回放函数,则也会生成回放函数。

Warp 将生成的源代码传递给本地编译器(例如,用于 CPU 的 LLVM 和用于 CUDA 的 NVRTC),以生成在启动核函数时调用的可执行代码。

外部引用和常量#

Warp 核函数可以访问在其自身外部定义的常规 Python 变量,前提是这些变量是支持的类型。此类外部引用在核函数中被视为编译时常量。在不同设备上运行的代码无法访问 Python 解释器的状态,因此这些变量会按值折叠到核函数中。

C = 42

@wp.kernel
def k():
    print(C)

wp.launch(k, dim=1)

在代码生成期间,外部变量 C 变为常量。

{
    //---------
    // primal vars
    const wp::int32 var_0 = 42;
    //---------
    // forward
    // def k():
    // print(C)
    wp::print(var_0);
}

支持的常量类型#

只有值类型才能在 Warp 核函数中用作常量。这包括整数、浮点数、向量 (wp.vec*)、矩阵 (wp.mat*) 和其他内置数学类型。尝试捕获其他变量类型将导致异常。

global_array = wp.zeros(5, dtype=int)

@wp.kernel
def k():
    tid = wp.tid()
    global_array[tid] = 42  # referencing external arrays is not allowed!

wp.launch(k, dim=global_array.shape, inputs=[])

输出

TypeError: Invalid external reference type: <class 'warp.types.array'>

无法捕获数组的原因是它们存在于特定设备上,并包含指向设备内存的指针,这将导致核函数无法在不同设备之间移植。数组应始终作为核函数输入传递。

wp.constant() 的用法#

在旧版本的 Warp 中,需要在核函数中使用 wp.constant() 来声明可以使用的常量。现在不再需要这样做,但为了向后兼容,仍支持旧语法。wp.constant() 仍可用于检查某个值是否可以在核函数中引用。

x = wp.constant(17.0)  # ok
v = wp.constant(wp.vec3(1.0, 2.0, 3.0))  # ok
a = wp.constant(wp.zeros(n=5, dtype=int))  # error, invalid constant type

@wp.kernel
def k():
    tid = wp.tid()
    a[tid] = x * v

在此代码片段中,使用 wp.constant() 声明数组时将引发 TypeError。如果省略 wp.constant(),则错误将在代码生成期间稍后引发,这可能更难调试。

更新常量#

在 Warp 核函数中使用外部变量的一个限制是 Warp 不知道何时修改了值。

C = 17

@wp.kernel
def k():
    print(C)

wp.launch(k, dim=1)

# redefine constant
C = 42

wp.launch(k, dim=1)

这会打印

Module __main__ 4494df2 load on device 'cuda:0' took 163.54 ms  (compiled)
17
17

在第一次启动核函数 k 时,核函数使用 C 的现有值 (17) 进行编译。由于 C 只是一个普通的 Python 变量,Warp 无法检测何时修改了它。因此,在第二次启动时,再次打印旧值。

解决此限制的一种方法是告知 Warp 模块已修改。

C = 17

@wp.kernel
def k():
    print(C)

wp.launch(k, dim=1)

# redefine constant
C = 42

# tell Warp that the module was modified
k.module.mark_modified()

wp.launch(k, dim=1)

这会产生更新的输出

Module __main__ 4494df2 load on device 'cuda:0' took 167.92 ms  (compiled)
17
Module __main__ 9a0664f load on device 'cuda:0' took 164.83 ms  (compiled)
42

请注意,调用 module.mark_modified() 导致模块在第二次启动时使用 C 的最新值重新编译。

注意

Module 类和 mark_modified() 方法被认为是内部的。计划提供用于处理模块的公共 API,但目前它们可能会更改,恕不另行通知。程序不应过度依赖 mark_modified() 方法,但在紧急情况下可以使用它。

静态表达式#

我们经常遇到这样的情况:需要针对给定输入对核函数进行特殊化,或者在代码执行时代码的某些部分是静态的。使用静态表达式,我们可以编写在声明 Warp 函数或核函数时评估的 Python 表达式。

wp.static(...) 表达式允许用户在定义包含该表达式的 Warp 函数或核函数时运行任意 Python 代码。wp.static(expr) 接受一个 Python 表达式,并将其替换为结果。请注意,表达式只能访问在声明表达式时可以评估的变量。这包括全局变量以及在定义 Warp 函数或核函数的闭包中捕获的变量。此外,可以访问核函数或函数内部的 Warp 常量,例如用于静态 for 循环的常量迭代变量(即,在代码生成时范围已知)。

wp.static() 的结果必须是以下类型之一的非空值

  • Warp 函数

  • 字符串

  • 核函数内 Warp 支持的任何类型(例如标量、结构体、矩阵、向量等),但不包括 Warp 数组或包含 Warp 数组的结构体

示例:静态数学表达式#

import warp as wp
import scipy.linalg

@wp.kernel
def my_kernel():
    static_var = wp.static(3 + 2)
    # we can call arbitrary Python code inside wp.static()
    static_norm = wp.static(wp.float64(scipy.linalg.norm([3, 4])))
    wp.printf("static_var = %i\n", static_var)
    wp.printf("static_norm = %f\n", static_norm)

wp.launch(my_kernel, 1)

静态表达式在评估 @wp.kernel 装饰器时进行评估,并在代码中被各自的常量结果值替换。因此,生成的代码将包含硬编码在源文件中的表达式结果(显示了缩写版本)。

const wp::int32 var_0 = 5;
const wp::float64 var_1 = 5.0;
const wp::str var_2 = "static_var = %i\n";
const wp::str var_3 = "static_norm = %f\n";

// wp.printf("static_var = %i\n", static_var)                                             <L 10>
printf(var_2, var_0);
// wp.printf("static_norm = %f\n", static_norm)                                           <L 11>
printf(var_3, var_1);

示例:静态条件判断#

通过在分支条件中使用 wp.static() 产生常量布尔值,可以将常量 if/else/elif 条件从生成的代码中移除。这可以通过避免分支来提高性能,并且对于生成专门的核函数很有用。

import warp as wp

available_colors = {"red", "green", "blue"}

@wp.kernel
def my_kernel():
    if wp.static("red" in available_colors):
        print("red is available")
    else:
        print("red is not available")

全局变量 available_colors 在声明核函数时已知,生成的代码将只包含执行的分支。

const wp::str var_1 = "red is available";
wp::print(var_1);

示例:静态循环展开#

静态表达式可用于在代码生成期间展开 for 循环。我们在循环的 range 内部放置 wp.static() 表达式,以产生可以展开的静态 for 循环。迭代变量成为常量,因此可以从循环体内的静态表达式中访问。

import warp as wp

def loop_limit():
    return 3

@wp.kernel
def my_kernel():
    for i in range(wp.static(loop_limit())):
        static_i = wp.static(i)
        wp.printf("i = %i\n", static_i)

wp.launch(my_kernel, 1)

生成的代码将不包含 for 循环,而是循环体将重复三次。

const wp::int32 var_0 = 3;
const wp::int32 var_1 = 0;
const wp::int32 var_2 = 0;
const wp::str var_3 = "i = %i\n";
const wp::int32 var_4 = 1;
const wp::int32 var_5 = 1;
const wp::str var_6 = "i = %i\n";
const wp::int32 var_7 = 2;
const wp::int32 var_8 = 2;
const wp::str var_9 = "i = %i\n";
printf(var_3, var_2);
printf(var_6, var_5);
printf(var_9, var_8);

示例:函数指针#

wp.static(...) 也可以返回一个 Warp 函数。这对于根据声明 Warp 函数或核函数时可用的信息来特殊化核函数或函数,或者自动为不同类型生成重载非常有用。

import warp as wp

@wp.func
def do_add(a: float, b: float):
    return a + b

@wp.func
def do_sub(a: float, b: float):
    return a - b

@wp.func
def do_mul(a: float, b: float):
    return a * b

op_handlers = {
    "add": do_add,
    "sub": do_sub,
    "mul": do_mul,
}

inputs = wp.array([[1, 2], [3, 0]], dtype=wp.float32)
outputs = wp.empty(2, dtype=wp.float32)

for op in op_handlers.keys():

    @wp.kernel
    def operate(input: wp.array(dtype=inputs.dtype, ndim=2), output: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        a, b = input[tid, 0], input[tid, 1]
        # retrieve the right function to use for the captured dtype variable
        output[tid] = wp.static(op_handlers[op])(a, b)

    wp.launch(operate, dim=2, inputs=[inputs], outputs=[outputs])
    print(outputs.numpy())

上述程序使用静态表达式根据捕获的 op 变量选择正确的函数,并在编译包含 operate 核函数的模块三次时打印以下输出。

[3. 3.]
[-1.  3.]
[2. 0.]

示例:静态长度查询#

Python 的内置函数 len() 也可以静态评估固定长度的类型,例如向量、四元数和矩阵,并且可以包装到 wp.static() 调用中以初始化其他构造。

import warp as wp

@wp.kernel
def my_kernel(v: wp.vec2):
    m = wp.identity(n=wp.static(len(v) + 1), dtype=v.dtype)
    wp.expect_eq(wp.ddot(m, m), 3.0)

v = wp.vec2(1, 2)
wp.launch(my_kernel, 1, inputs=(v,))

高级示例:使用静态循环展开消除分支#

在计算模拟中,根据运行时变量应用不同的操作或边界条件是很常见的。然而,使用运行时变量进行条件分支通常会导致寄存器压力,从而引发性能问题,因为即使某些分支永远不会被执行,GPU 也可能为所有分支分配资源。为了解决这个问题,我们可以利用通过 wp.static(...) 进行的静态循环展开,这有助于在编译时消除不必要的分支并提高并行执行效率。

场景

假设我们有三个不同的函数 apply_func_aapply_func_bapply_func_c,它们执行不同的数学运算。

我们当前只对给定数据集应用其中两个函数(apply_func_aapply_func_b)。我们将哪个函数应用于每个数据点由运行时变量 func_id 决定,该变量作为数组提供给名为 func_field 的核函数。

实际上,func_field 表示应将哪个操作应用于每个数据点的映射,在处理边界条件或物理模拟的不同区域时特别有用。例如,在流体模拟中,流体的不同区域可能需要根据预定义的边界条件进行不同的更新。

朴素方法实现

首先,让我们考虑一个实现此功能的朴素方法,该方法涉及基于 func_id 值的直接运行时分支。这种方法将突出显示为什么我们需要进一步优化。

import warp as wp
import numpy as np

# Define three functions that perform different operations
@wp.func
def apply_func_a(x: float) -> float:
    return x + 10.0

@wp.func
def apply_func_b(x: float) -> float:
    return x * 2.0

@wp.func
def apply_func_c(x: float) -> float:
    return x - 5.0

# Assign static IDs to represent each function
func_id_a = 0
func_id_b = 1
func_id_c = 2  # Not used in this kernel

# Kernel that applies the correct function to each element of the input array
@wp.kernel
def apply_func_conditions_naive(x: wp.array(dtype=wp.float32), func_field: wp.array(dtype=wp.int8)):
    tid = wp.tid()
    value = x[tid]
    result = value
    func_id = func_field[tid]  # Get the function ID for this element

    # Apply the corresponding function based on func_id
    if func_id == func_id_a:
        result = apply_func_a(value)
    elif func_id == func_id_b:
        result = apply_func_b(value)
    elif func_id == func_id_c:
        result = apply_func_c(value)

    x[tid] = result

# Example usage
data = wp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=wp.float32)

# Create an array that specifies which function to apply to each element
func_field = wp.array([func_id_a, func_id_b, func_id_b, func_id_a, func_id_b], dtype=wp.int8)

# Launch the kernel
wp.launch(apply_func_conditions_naive, inputs=[data, func_field], dim=data.size)

print(data.numpy())

输出

[11.  4.  6. 14. 10.]

由于 func_id 不是静态的,编译器无法在编译时消除未使用的函数。查看生成的 CUDA 代码,我们可以看到核函数包含了对未使用函数 apply_func_c 的额外分支。

//...
var_11 = wp::where(var_9, var_10, var_4);
if (!var_9) {
    var_13 = (var_7 == var_12);
    if (var_13) {
        var_14 = apply_func_b_0(var_3);
    }
    var_15 = wp::where(var_13, var_14, var_11);
    if (!var_13) {
        var_17 = (var_7 == var_16);
        if (var_17) {
            var_18 = apply_func_c_0(var_3);
        }
        var_19 = wp::where(var_17, var_18, var_15);
    }
    var_20 = wp::where(var_13, var_15, var_19);
}
//...

优化

为了避免额外的分支,我们可以通过 wp.static(...) 使用静态循环展开,以有效“编译掉”不必要的分支,只保留相关的操作。

实现

funcs = [apply_func_a, apply_func_b, apply_func_c]

# Assign static IDs to represent each function
func_id_a = 0
func_id_b = 1
func_id_c = 2  # Not used in this kernel

# Define which function IDs are actually used in this kernel
used_func_ids = (func_id_a, func_id_b)

@wp.kernel
def apply_func_conditions(x: wp.array(dtype=wp.float32), func_field: wp.array(dtype=wp.int8)):
    tid = wp.tid()
    value = x[tid]
    result = value
    func_id = func_field[tid]  # Get the function ID for this element

    # Unroll the loop over the used function IDs
    for i in range(wp.static(len(used_func_ids))):
        func_static_id = wp.static(used_func_ids[i])
        if func_id == func_static_id:
            result = wp.static(funcs[i])(value)

    x[tid] = result

在生成的 CUDA 代码中,我们可以看到优化后的代码不对未使用的函数进行分支。

//...
var_10 = (var_7 == var_9);
if (var_10) {
    var_11 = apply_func_a_1(var_3);
}
var_12 = wp::where(var_10, var_11, var_4);
var_15 = (var_7 == var_14);
if (var_15) {
    var_16 = apply_func_b_1(var_3);
}
//...

动态核函数创建#

通常希望使用不同的常量、类型或函数动态定制核函数。我们可以通过使用 Python 闭包进行运行时核函数特殊化来实现这一点。

核函数闭包#

常量#

Warp 允许在核函数中引用外部常量。

def create_kernel_with_constant(constant):
    @wp.kernel
    def k(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] += constant
    return k

k1 = create_kernel_with_constant(17.0)
k2 = create_kernel_with_constant(42.0)

a = wp.zeros(5, dtype=float)

wp.launch(k1, dim=a.shape, inputs=[a])
wp.launch(k2, dim=a.shape, inputs=[a])

print(a)

输出

[59. 59. 59. 59. 59.]

数据类型#

Warp 数据类型也可以在闭包中捕获。这里有一个创建处理不同向量维度的核函数的示例。

def create_kernel_with_dtype(vec_type):
    @wp.kernel
    def k(a: wp.array(dtype=vec_type)):
        tid = wp.tid()
        a[tid] += float(tid) * vec_type(1.0)
    return k

k2 = create_kernel_with_dtype(wp.vec2)
k4 = create_kernel_with_dtype(wp.vec4)

a2 = wp.ones(3, dtype=wp.vec2)
a4 = wp.ones(3, dtype=wp.vec4)

wp.launch(k2, dim=a2.shape, inputs=[a2])
wp.launch(k4, dim=a4.shape, inputs=[a4])

print(a2)
print(a4)

输出

[[1. 1.]
 [2. 2.]
 [3. 3.]]
[[1. 1. 1. 1.]
 [2. 2. 2. 2.]
 [3. 3. 3. 3.]]

函数#

这是一个使用不同函数参数化的核函数生成器。

def create_kernel_with_function(f):
    @wp.kernel
    def k(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] = f(a[tid])
    return k

@wp.func
def square(x: float):
    return x * x

@wp.func
def cube(x: float):
    return x * x * x

k1 = create_kernel_with_function(square)
k2 = create_kernel_with_function(cube)

a1 = wp.array([1, 2, 3, 4, 5], dtype=float)
a2 = wp.array([1, 2, 3, 4, 5], dtype=float)

wp.launch(k1, dim=a1.shape, inputs=[a1])
wp.launch(k2, dim=a2.shape, inputs=[a2])

print(a1)
print(a2)

输出

[ 1.  4.  9.  16.  25.]
[ 1.  8.  27.  64.  125.]

函数闭包#

Warp 函数 (@wp.func) 也支持闭包,就像核函数一样。

def create_function_with_constant(constant):
    @wp.func
    def f(x: float):
        return constant * x
    return f

f1 = create_function_with_constant(2.0)
f2 = create_function_with_constant(3.0)

@wp.kernel
def k(a: wp.array(dtype=float)):
    tid = wp.tid()
    x = float(tid)
    a[tid] = f1(x) + f2(x)

a = wp.ones(5, dtype=float)

wp.launch(k, dim=a.shape, inputs=[a])

print(a)

输出

[ 0.  5. 10. 15. 20.]

我们也可以像这样一起创建相关的函数和核函数闭包。

def create_fk(a, b):
    @wp.func
    def f(x: float):
        return a * x

    @wp.kernel
    def k(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] = f(a[tid]) + b

    return f, k

# create related function and kernel closures
f1, k1 = create_fk(2.0, 3.0)
f2, k2 = create_fk(4.0, 5.0)

# use the functions separately in a new kernel
@wp.kernel
def kk(a: wp.array(dtype=float)):
    tid = wp.tid()
    a[tid] = f1(a[tid]) + f2(a[tid])

a1 = wp.array([1, 2, 3, 4, 5], dtype=float)
a2 = wp.array([1, 2, 3, 4, 5], dtype=float)
ak = wp.array([1, 2, 3, 4, 5], dtype=float)

wp.launch(k1, dim=a1.shape, inputs=[a1])
wp.launch(k2, dim=a2.shape, inputs=[a2])
wp.launch(kk, dim=ak.shape, inputs=[ak])

print(a1)
print(a2)
print(ak)

输出

[ 5.  7.  9. 11. 13.]
[ 9. 13. 17. 21. 25.]
[ 6. 12. 18. 24. 30.]

动态结构体#

有时,使用不同的数据类型自定义 Warp 结构体很有用。

自定义精度#

例如,我们可以创建具有不同浮点精度的结构体。

def create_struct_with_precision(dtype):
    @wp.struct
    class S:
        a: dtype
        b: dtype
    return S

# create structs with different floating point precision
S16 = create_struct_with_precision(wp.float16)
S32 = create_struct_with_precision(wp.float32)
S64 = create_struct_with_precision(wp.float64)

s16 = S16()
s32 = S32()
s64 = S64()

s16.a, s16.b = 2.0001, 3.0000002
s32.a, s32.b = 2.0001, 3.0000002
s64.a, s64.b = 2.0001, 3.0000002

# create a generic kernel that works with the different types
@wp.kernel
def k(s: Any, output: wp.array(dtype=Any)):
    tid = wp.tid()
    x = output.dtype(tid)
    output[tid] = x * s.a + s.b

a16 = wp.empty(5, dtype=wp.float16)
a32 = wp.empty(5, dtype=wp.float32)
a64 = wp.empty(5, dtype=wp.float64)

wp.launch(k, dim=a16.shape, inputs=[s16, a16])
wp.launch(k, dim=a32.shape, inputs=[s32, a32])
wp.launch(k, dim=a64.shape, inputs=[s64, a64])

print(a16)
print(a32)
print(a64)

我们可以在输出中看到使用不同浮点精度的效果。

[ 3.  5.  7.  9. 11.]
[ 3.0000002  5.0001     7.0002003  9.000299  11.0004   ]
[ 3.0000002  5.0001002  7.0002002  9.0003002 11.0004002]

自定义维度#

动态结构体的另一个有用应用是自定义维度的能力。在这里,我们创建处理 2D 和 3D 数据的结构体。

# create struct with different vectors and matrix dimensions
def create_struct_nd(dim):
    @wp.struct
    class S:
        v: wp.types.vector(dim, float)
        m: wp.types.matrix((dim, dim), float)
    return S

S2 = create_struct_nd(2)
S3 = create_struct_nd(3)

s2 = S2()
s2.v = (1.0, 2.0)
s2.m = ((2.0, 0.0),
        (0.0, 0.5))

s3 = S3()
s3.v = (1.0, 2.0, 3.0)
s3.m = ((2.0, 0.0, 0.0),
        (0.0, 0.5, 0.0),
        (0.0, 0.0, 1.0))

# create a generic kernel that works with the different types
@wp.kernel
def k(s: Any, output: wp.array(dtype=Any)):
    tid = wp.tid()
    x = float(tid)
    output[tid] = x * s.v * s.m

a2 = wp.empty(5, dtype=wp.vec2)
a3 = wp.empty(5, dtype=wp.vec3)

wp.launch(k, dim=a2.shape, inputs=[s2, a2])
wp.launch(k, dim=a3.shape, inputs=[s3, a3])

print(a2)
print(a3)

输出

[[0. 0.]
 [2. 1.]
 [4. 2.]
 [6. 3.]
 [8. 4.]]
[[ 0.  0.  0.]
 [ 2.  1.  3.]
 [ 4.  2.  6.]
 [ 6.  3.  9.]
 [ 8.  4. 12.]]

模块重新加载#

频繁重新编译会增加程序的开销,特别是如果程序在运行时创建核函数。考虑此程序。

def create_kernel_with_constant(constant):
    @wp.kernel
    def k(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] += constant
    return k

a = wp.zeros(5, dtype=float)

k1 = create_kernel_with_constant(17.0)
wp.launch(k1, dim=a.shape, inputs=[a])
print(a)

k2 = create_kernel_with_constant(42.0)
wp.launch(k2, dim=a.shape, inputs=[a])
print(a)

k3 = create_kernel_with_constant(-9.0)
wp.launch(k3, dim=a.shape, inputs=[a])
print(a)

核函数创建与核函数启动穿插进行,这迫使在每次核函数启动时进行重新加载。

Module __main__ 96db544 load on device 'cuda:0' took 165.46 ms  (compiled)
[17. 17. 17. 17. 17.]
Module __main__ 9f609a4 load on device 'cuda:0' took 151.69 ms  (compiled)
[59. 59. 59. 59. 59.]
Module __main__ e93fbb9 load on device 'cuda:0' took 167.84 ms  (compiled)
[50. 50. 50. 50. 50.]

为了避免重新加载,应在启动所有核函数之前创建它们。

def create_kernel_with_constant(constant):
    @wp.kernel
    def k(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] += constant
    return k

k1 = create_kernel_with_constant(17.0)
k2 = create_kernel_with_constant(42.0)
k3 = create_kernel_with_constant(-9.0)

a = wp.zeros(5, dtype=float)

wp.launch(k1, dim=a.shape, inputs=[a])
print(a)

wp.launch(k2, dim=a.shape, inputs=[a])
print(a)

wp.launch(k3, dim=a.shape, inputs=[a])
print(a)
Module __main__ e93fbb9 load on device 'cuda:0' took 164.87 ms  (compiled)
[17. 17. 17. 17. 17.]
[59. 59. 59. 59. 59.]
[50. 50. 50. 50. 50.]

重新定义相同的核函数、函数和结构体不应导致模块重新加载,因为 Warp 能够检测到重复项。

def create_struct(dtype):
    @wp.struct
    class S:
        a: dtype
        b: dtype
    return S

def create_function(dtype, S):
    @wp.func
    def f(s: S):
        return s.a * s.b
    return f

def create_kernel(dtype, S, f, C):
    @wp.kernel
    def k(a: wp.array(dtype=dtype)):
        tid = wp.tid()
        s = S(a[tid], C)
        a[tid] = f(s)
    return k

# create identical struct, function, and kernel in a loop
for i in range(3):
    S = create_struct(float)
    f = create_function(float, S)
    k = create_kernel(float, S, f, 3.0)

    a = wp.array([1, 2, 3, 4, 5], dtype=float)

    wp.launch(k, dim=a.shape, inputs=[a])
    print(a)

即使在循环的每次迭代中重新创建了结构体 S、函数 f 和核函数 k,它们都是重复项,因此模块只加载一次。

Module __main__ 4af2d60 load on device 'cuda:0' took 181.34 ms  (compiled)
[ 3.  6.  9. 12. 15.]
[ 3.  6.  9. 12. 15.]
[ 3.  6.  9. 12. 15.]

延迟绑定和静态表达式#

Python 使用延迟绑定,这意味着可以在变量定义之前在函数中引用变量。

def k():
    # Function f() and constant C are not defined yet.
    # They will be resolved when k() is called.
    print(f() + C)

def f():
    return 42

C = 17

# late binding occurs in this call
k()

Warp 默认遵循此约定,因为这是 Python 的方式。这是一个用 Warp 编写的类似程序。

@wp.kernel
def k():
    # Function f() and constant C are not defined yet.
    # They will be resolved when k() is called.
    print(f() + C)

@wp.func
def f():
    return 42

C = 17

# late binding occurs in this launch, when the module is compiled
wp.launch(k, dim=1)

# wait for the output
wp.synchronize_device()

延迟绑定通常很方便,但有时会导致令人惊讶的结果。考虑此代码片段,它在循环中创建核函数。核函数将循环变量作为常量引用。

# create a list of kernels that use the loop variable
kernels = []
for i in range(3):
    @wp.kernel
    def k():
        print(i)
    kernels.append(k)

# launch the kernels
for k in kernels:
    wp.launch(k, dim=1)

wp.synchronize_device()

这会打印

2
2
2

这可能令人惊讶,但在纯 Python 中创建类似的程序会得到相同的结果。由于延迟绑定,捕获的循环变量 i 直到核函数启动时才被评估。那时,i 的值为 2,并且我们从每个核函数中看到相同的输出。

在 Warp 中,可以使用 wp.static() 来解决此问题。

# create a list of kernels that use the loop variable
kernels = []
for i in range(3):
    @wp.kernel
    def k():
        print(wp.static(i))  # wp.static() for the win
    kernels.append(k)

# launch the kernels
for k in kernels:
    wp.launch(k, dim=1)

wp.synchronize_device()

Warp 将 wp.static() 的调用替换为作为其参数传递的表达式的值。该表达式在核函数定义时立即评估。这类似于 C++ 等语言使用的静态绑定,这意味着静态表达式引用的所有变量必须已定义。

为了进一步说明默认延迟绑定行为与静态表达式之间的区别,考虑此程序。

C = 17

@wp.kernel
def k1():
    print(C)

@wp.kernel
def k2():
    print(wp.static(C))

# redefine constant
C = 42

wp.launch(k1, dim=1)
wp.launch(k2, dim=1)

wp.synchronize_device()

输出

42
17

核函数 k1 使用 C 的延迟绑定。这意味着它捕获 C 的最新值,该值是在启动期间构建模块时确定的。核函数 k2 在静态表达式中使用 C,因此它在定义核函数时捕获 C 的值。

相同的规则适用于解析 Warp 函数。

@wp.func
def f():
    return 17

@wp.kernel
def k1():
    print(f())

@wp.kernel
def k2():
    print(wp.static(f)())

# redefine function
@wp.func
def f():
    return 42

wp.launch(k1, dim=1)
wp.launch(k2, dim=1)

wp.synchronize_device()

输出

42
17

核函数 k1 使用函数 f 的最新定义,而核函数 k2 使用声明核函数时 f 的定义。