互操作性#

Warp 可以通过标准接口协议与其他基于 Python 的框架(如 NumPy)进行互操作。

Warp 支持直接将外部数组传递给内核,只要它们实现了 __array____array_interface____cuda_array_interface__ 协议。 这适用于许多常见框架,如 NumPy、CuPy 或 PyTorch。

例如,在 CPU 上启动 Warp 内核时,我们可以直接使用 NumPy 数组

import numpy as np
import warp as wp

@wp.kernel
def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    i = wp.tid()
    y[i] = a * x[i] + y[i]

x = np.arange(n, dtype=np.float32)
y = np.ones(n, dtype=np.float32)

wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device="cpu")

同样,我们可以在 CUDA 设备上使用 CuPy 数组

import cupy as cp

with cp.cuda.Device(0):
    x = cp.arange(n, dtype=cp.float32)
    y = cp.ones(n, dtype=cp.float32)

wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device="cuda:0")

请注意,对于 CUDA 数组,重要的是确保数组所在的设备与内核启动的设备相同。

PyTorch 支持 CPU 和 GPU 张量,这两种类型都可以传递给相应设备上的 Warp 内核。

import random
import torch

if random.choice([False, True]):
    device = "cpu"
else:
    device = "cuda:0"

x = torch.arange(n, dtype=torch.float32, device=device)
y = torch.ones(n, dtype=torch.float32, device=device)

wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)

NumPy#

Warp 数组可以通过 array.numpy() 方法转换为 NumPy 数组。 当 Warp 数组位于 cpu 设备上时,这将返回对底层 Warp 分配的零拷贝视图。 如果数组位于 cuda 设备上,那么它将首先被复制回临时缓冲区并复制到 NumPy。

Warp CPU 数组也实现了 __array_interface__ 协议,因此可以直接用于构造 NumPy 数组

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")
a = np.array(w)
print(a)
> [1. 2. 3.]

为了方便起见,还提供了数据类型转换实用程序

warp_type = wp.float32
...
numpy_type = wp.dtype_to_numpy(warp_type)
...
a = wp.zeros(n, dtype=warp_type)
b = np.zeros(n, dtype=numpy_type)

要从 NumPy 数组创建 Warp 数组,请使用 warp.from_numpy() 或将 NumPy 数组作为 warp.array 构造函数的 data 参数直接传递。

warp.from_numpy(
arr,
dtype=None,
shape=None,
device=None,
requires_grad=False,
)[source]#

返回从 NumPy 数组创建的 Warp 数组。

参数:
  • arr (ndarray) – 提供数据以构造 Warp 数组的 NumPy 数组。

  • dtype (type | None) – 新 Warp 数组的数据类型。 如果未提供,则将推断数据类型。

  • shape (Sequence[int] | None) – Warp 数组的形状。

  • device (Device | str | None) – 将在其上构造 Warp 数组的设备。

  • requires_grad (bool) – 是否将跟踪此数组的梯度。

Raises:

RuntimeError – 不支持 NumPy 数组的数据类型。

返回类型:

array

warp.dtype_from_numpy(numpy_dtype)[source]#

返回与 NumPy dtype 对应的 Warp dtype。

warp.dtype_to_numpy(warp_dtype)[source]#

返回与 Warp dtype 对应的 NumPy dtype。

PyTorch#

Warp 提供了辅助函数来在 PyTorch 之间转换数组

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")

# convert to Torch tensor
t = wp.to_torch(w)

# convert from Torch tensor
w = wp.from_torch(t)

这些辅助函数允许在 Warp 数组和 PyTorch 张量之间转换,而无需复制底层数据。 同时,如果可用,梯度数组和张量会转换为/从 PyTorch autograd 张量转换,从而允许在 PyTorch autograd 计算中使用 Warp 数组。

warp.from_torch(
t,
dtype=None,
requires_grad=None,
grad=None,
return_ctype=False,
)[source]#

将 Torch 张量转换为 Warp 数组,而无需复制数据。

参数:
  • t (torch.Tensor) – 要包装的 torch 张量。

  • dtype (warp.dtype, optional) – 生成的 Warp 数组的目标数据类型。 默认为映射到 Warp 数组值类型的张量值类型。

  • requires_grad (bool, optional) – 结果数组是否应包装张量的梯度(如果存在)(否则将分配 grad 张量)。 默认为张量的 requires_grad 值。

  • return_ctype (bool, optional) – 是否返回低级数组描述符而不是 wp.array 对象(更快)。 描述符可以传递给 Warp 内核。

Returns:

包装的数组或数组描述符。

返回类型:

warp.array

warp.to_torch(a, requires_grad=None)[source]#

将 Warp 数组转换为 Torch 张量,而无需复制数据。

参数:
  • a (warp.array) – 要转换的 Warp 数组。

  • requires_grad (bool, optional) – 结果张量是否应将数组的梯度(如果存在)转换为 grad 张量。 默认为数组的 requires_grad 值。

Returns:

转换后的张量。

返回类型:

torch.Tensor

warp.device_from_torch(torch_device)[source]#

返回与 Torch 设备对应的 Warp 设备。

参数:

torch_device (torch.devicestr) – Torch 设备标识符

Raises:

RuntimeError – Torch 设备没有对应的 Warp 设备

返回类型:

Device

warp.device_to_torch(warp_device)[源代码]#

返回与 Warp 设备对应的 Torch 设备字符串。

参数:

warp_device (Device | str | None) – 可以解析为 warp.context.Device 的标识符。

Raises:

RuntimeError – Warp 设备与 PyTorch 不兼容。

返回类型:

str

warp.dtype_from_torch(torch_dtype)[源代码]#

返回与 Torch dtype 对应的 Warp dtype。

参数:

torch_dtype – 具有相应 Warp 数据类型的 torch.dtype。当前不支持 torch.bfloat16torch.complex64torch.complex128

Raises:

TypeError – 无法找到相应的 Warp 数据类型。

warp.dtype_to_torch(warp_dtype)[源代码]#

返回与 Warp dtype 对应的 Torch dtype。

参数:

warp_dtype – 具有相应 torch.dtype 的 Warp 数据类型。warp.uint16warp.uint32warp.uint64 映射到相同宽度的有符号整数 torch.dtype

Raises:

TypeError – 无法找到相应的 PyTorch 数据类型。

为了将 PyTorch CUDA 流转换为 Warp CUDA 流,反之亦然,Warp 提供了以下函数

warp.stream_from_torch(stream_or_device=None)[源代码]#

从 Torch CUDA 流转换为 Warp CUDA 流。

warp.stream_to_torch(stream_or_device=None)[源代码]#

从 Warp CUDA 流转换为 Torch CUDA 流。

示例:使用 warp.from_torch() 进行优化#

以下是使用 PyTorch 的 Adam 优化器,通过 warp.from_torch(),最小化 Warp 中编写的 2D 点数组上的损失函数的示例用法

import warp as wp
import torch


@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# indicate requires_grad so that Warp can accumulate gradients in the grad buffers
xs = torch.randn(100, 2, requires_grad=True)
l = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([xs], lr=0.1)

wp_xs = wp.from_torch(xs)
wp_l = wp.from_torch(l)

tape = wp.Tape()
with tape:
    # record the loss function kernel launch on the tape
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=wp_l)  # compute gradients
    # now xs.grad will be populated with the gradients computed by Warp
    opt.step()  # update xs (and thereby wp_xs)

    # these lines are only needed for evaluating the loss
    # (the optimization just needs the gradient, not the loss value)
    wp_l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
    print(f"{i}\tloss: {l.item()}")

示例:使用 warp.to_torch 进行优化#

当我们在 Warp 中直接声明优化变量,并使用 warp.to_torch() 将它们转换为 PyTorch 张量时,所需的代码更少。 在这里,我们重新审视了上面的相同示例,现在只需要一次转换为 PyTorch 张量,就可以为 Adam 提供优化变量

import warp as wp
import numpy as np
import torch


@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# initialize the optimization variables in Warp
xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True)
l = wp.zeros(1, dtype=wp.float32, requires_grad=True)
# just a single wp.to_torch call is needed, Adam optimizes using the Warp array gradients
opt = torch.optim.Adam([wp.to_torch(xs)], lr=0.1)

tape = wp.Tape()
with tape:
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=l)
    opt.step()

    l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
    print(f"{i}\tloss: {l.numpy()[0]}")

示例:使用 torch.autograd.function 进行优化 (PyTorch <= 2.3.1)#

可以通过定义一个 torch.autograd.Function 类,将 Warp 内核启动插入 PyTorch 图中,该类需要定义前向和后向函数。 将传入的 PyTorch 张量映射到 Warp 数组后,可以像往常一样启动 Warp 内核。 在后向传递中,可以通过在 wp.launch() 中设置 adjoint = True 来启动相同内核的伴随内核。 或者,用户可以选择依赖 Warp 的磁带。 在以下示例中,我们演示了如何在优化上下文中使用 Warp 来评估 Rosenbrock 函数。

import warp as wp
import numpy as np
import torch

# Define the Rosenbrock function
@wp.func
def rosenbrock(x: float, y: float):
    return (1.0 - x) ** 2.0 + 100.0 * (y - x**2.0) ** 2.0

@wp.kernel
def eval_rosenbrock(
    xs: wp.array(dtype=wp.vec2),
    # outputs
    z: wp.array(dtype=float),
):
    i = wp.tid()
    x = xs[i]
    z[i] = rosenbrock(x[0], x[1])


class Rosenbrock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xy, num_points):
        ctx.xy = wp.from_torch(xy, dtype=wp.vec2, requires_grad=True)
        ctx.num_points = num_points

        # allocate output
        ctx.z = wp.zeros(num_points, requires_grad=True)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_points,
            inputs=[ctx.xy],
            outputs=[ctx.z]
        )

        return wp.to_torch(ctx.z)

    @staticmethod
    def backward(ctx, adj_z):
        # map incoming Torch grads to our output variables
        ctx.z.grad = wp.from_torch(adj_z)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_points,
            inputs=[ctx.xy],
            outputs=[ctx.z],
            adj_inputs=[ctx.xy.grad],
            adj_outputs=[ctx.z.grad],
            adjoint=True
        )

        # return adjoint w.r.t. inputs
        return (wp.to_torch(ctx.xy.grad), None)


num_points = 1500
learning_rate = 5e-2

torch_device = wp.device_to_torch(wp.get_device())

rng = np.random.default_rng(42)
xy = torch.tensor(rng.normal(size=(num_points, 2)), dtype=torch.float32, requires_grad=True, device=torch_device)
opt = torch.optim.Adam([xy], lr=learning_rate)

for _ in range(10000):
    # step
    opt.zero_grad()
    z = Rosenbrock.apply(xy, num_points)
    z.backward(torch.ones_like(z))

    opt.step()

# minimum at (1, 1)
xy_np = xy.numpy(force=True)
print(np.mean(xy_np, axis=0))

请注意,如果 Warp 代码包含在 torch.autograd.Function 中,并且该函数在 torch.compile() 中被调用,那么它会自动将该函数从编译器优化中排除。 如果您的脚本使用 torch.compile(),我们建议使用 PyTorch 2.3.0+ 版本,该版本具有解决此情况的改进。

示例:使用 PyTorch 自定义运算符进行优化 (PyTorch >= 2.4.0)#

PyTorch 2.4+ 引入了 自定义运算符 来取代 PyTorch autograd 函数。 这些将任意 Python 函数(包括 Warp 调用)视为不透明的可调用对象,从而阻止 torch.compile() 跟踪到它们。 这意味着包含 Warp 内核启动的前向 PyTorch 图评估可以安全地使用 torch.compile() 加速。 我们可以使用自定义运算符重写前面的示例,如下所示

import warp as wp
import numpy as np
import torch

# Define the Rosenbrock function
@wp.func
def rosenbrock(x: float, y: float):
    return (1.0 - x) ** 2.0 + 100.0 * (y - x**2.0) ** 2.0


@wp.kernel
def eval_rosenbrock(
    xy: wp.array(dtype=wp.vec2),
    # outputs
    z: wp.array(dtype=float),
):
    i = wp.tid()
    v = xy[i]
    z[i] = rosenbrock(v[0], v[1])


@torch.library.custom_op("wp::warp_rosenbrock", mutates_args=())
def warp_rosenbrock(xy: torch.Tensor, num_points: int) -> torch.Tensor:
    wp_xy = wp.from_torch(xy, dtype=wp.vec2)
    wp_z = wp.zeros(num_points, dtype=wp.float32)

    wp.launch(kernel=eval_rosenbrock, dim=num_points, inputs=[wp_xy], outputs=[wp_z])

    return wp.to_torch(wp_z)


@warp_rosenbrock.register_fake
def _(xy, num_points):
    return torch.empty(num_points, dtype=torch.float32)


@torch.library.custom_op("wp::warp_rosenbrock_backward", mutates_args=())
def warp_rosenbrock_backward(
    xy: torch.Tensor, num_points: int, z: torch.Tensor, adj_z: torch.Tensor
) -> torch.Tensor:
    wp_xy = wp.from_torch(xy, dtype=wp.vec2)
    wp_z = wp.from_torch(z, requires_grad=False)
    wp_adj_z = wp.from_torch(adj_z, requires_grad=False)

    wp.launch(
        kernel=eval_rosenbrock,
        dim=num_points,
        inputs=[wp_xy],
        outputs=[wp_z],
        adj_inputs=[wp_xy.grad],
        adj_outputs=[wp_adj_z],
        adjoint=True,
    )

    return wp.to_torch(wp_xy.grad)


@warp_rosenbrock_backward.register_fake
def _(xy, num_points, z, adj_z):
    return torch.empty_like(xy)


def backward(ctx, adj_z):
    ctx.xy.grad = warp_rosenbrock_backward(ctx.xy, ctx.num_points, ctx.z, adj_z)
    return ctx.xy.grad, None


def setup_context(ctx, inputs, output):
    ctx.xy, ctx.num_points = inputs
    ctx.z = output


warp_rosenbrock.register_autograd(backward, setup_context=setup_context)

num_points = 1500
learning_rate = 5e-2

torch_device = wp.device_to_torch(wp.get_device())

rng = np.random.default_rng(42)
xy = torch.tensor(rng.normal(size=(num_points, 2)), dtype=torch.float32, requires_grad=True, device=torch_device)
opt = torch.optim.Adam([xy], lr=learning_rate)

@torch.compile(fullgraph=True)
def forward():
    global xy, num_points

    z = warp_rosenbrock(xy, num_points)
    return z

for _ in range(10000):
    # step
    opt.zero_grad()
    z = forward()
    z.backward(torch.ones_like(z))
    opt.step()

# minimum at (1, 1)
xy_np = xy.numpy(force=True)
print(np.mean(xy_np, axis=0))

性能说明#

wp.from_torch() 函数创建一个 Warp 数组对象,该对象与 PyTorch 张量共享数据。 尽管此函数不复制数据,但在转换过程中始终存在一些 CPU 开销。 如果这些转换频繁发生,则整体程序性能可能会受到影响。 作为一般规则,应避免重复转换相同的张量。 而不是

x_t = torch.arange(n, dtype=torch.float32, device=device)
y_t = torch.ones(n, dtype=torch.float32, device=device)

for i in range(10):
    x_w = wp.from_torch(x_t)
    y_w = wp.from_torch(y_t)
    wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

尝试仅转换一次数组并重复使用它们

x_t = torch.arange(n, dtype=torch.float32, device=device)
y_t = torch.ones(n, dtype=torch.float32, device=device)

x_w = wp.from_torch(x_t)
y_w = wp.from_torch(y_t)

for i in range(10):
    wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

如果无法重用数组(例如,每次迭代都构造一个新的 PyTorch 张量),则将 return_ctype=True 传递给 wp.from_torch() 应该会产生更好的性能。 将此参数设置为 True 可避免构造 wp.array 对象,而是返回一个低级数组描述符。 此描述符是一个简单的 C 结构,可以传递给 Warp 内核而不是 wp.array,但不能在需要 wp.array 的其他地方使用。

for n in range(1, 10):
    # get Torch tensors for this iteration
    x_t = torch.arange(n, dtype=torch.float32, device=device)
    y_t = torch.ones(n, dtype=torch.float32, device=device)

    # get Warp array descriptors
    x_ctype = wp.from_torch(x_t, return_ctype=True)
    y_ctype = wp.from_torch(y_t, return_ctype=True)

    wp.launch(saxpy, dim=n, inputs=[x_ctype, y_ctype, 1.0], device=device)

另一种方法是将 PyTorch 张量直接传递给 Warp 内核。 这避免了通过利用 PyTorch 和 Warp 都支持的标准数组接口(如 __cuda_array_interface__)来构造临时 Warp 数组。 这种方法的主要优点是方便,因为不需要调用任何转换函数。 主要限制是它不处理梯度,因为梯度信息不包含在标准数组接口中。 因此,此技术最适合不涉及微分的算法。

x = torch.arange(n, dtype=torch.float32, device=device)
y = torch.ones(n, dtype=torch.float32, device=device)

for i in range(10):
    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)
python -m warp.examples.benchmarks.benchmark_interop_torch

示例输出

5095 ms  from_torch(...)
2113 ms  from_torch(..., return_ctype=True)
2950 ms  direct from torch

默认的 wp.from_torch() 转换是最慢的。 传递 return_ctype=True 是最快的,因为它会跳过创建临时 Warp 数组对象。 将 PyTorch 张量直接传递给 Warp 内核的性能介于两者之间。 它跳过了创建临时 Warp 数组,但访问 PyTorch 张量的 __cuda_array_interface__ 属性会增加开销,因为它们是按需初始化的。

CuPy/Numba#

Warp GPU 数组支持 __cuda_array_interface__ 协议,用于与其他 Python GPU 框架共享数据。 这允许 CuPy 等框架直接使用 Warp GPU 数组。

同样,可以从任何公开 __cuda_array_interface__ 的对象创建 Warp 数组。 这样的对象也可以直接传递给 Warp 内核,而无需创建 Warp 数组对象。

JAX#

通过以下方法支持与 JAX 数组的互操作性。 在内部,这些方法使用 DLPack 协议以零复制的方式与 JAX 交换数据

warp_array = wp.from_jax(jax_array)
jax_array = wp.to_jax(warp_array)

最好直接使用 DLPack 协议,以获得更好的性能并控制流同步。

warp.from_jax(jax_array, dtype=None)[source]#

将 Jax 数组转换为 Warp 数组,无需复制数据。

参数:
  • jax_array (jax.Array) – 要转换的 Jax 数组。

  • dtype (可选) – 结果 Warp 数组的目标数据类型。默认为映射到 Warp 数据类型的 Jax 数组的数据类型。

Returns:

转换后的 Warp 数组。

返回类型:

warp.array

warp.to_jax(warp_array)[source]#

将 Warp 数组转换为 Jax 数组,无需复制数据。

参数:

warp_array (warp.array) – 要转换的 Warp 数组。

Returns:

转换后的 Jax 数组。

返回类型:

jax.Array

warp.device_from_jax(jax_device)[source]#

返回对应于 Jax 设备的 Warp 设备。

参数:

jax_device (jax.Device) – Jax 设备描述符。

Raises:

RuntimeError – Jax 设备既不是 CPU 也不是 GPU 设备。

返回类型:

Device

warp.device_to_jax(warp_device)[source]#

返回对应于 Warp 设备的 Jax 设备。

Returns:

jax.Device

Raises:

RuntimeError – 无法找到对应的 Jax 设备。

参数:

warp_device (Device | str | None)

warp.dtype_from_jax(jax_dtype)[source]#

返回对应于 Jax dtype 的 Warp dtype。

Raises:

TypeError – 无法找到相应的 Warp 数据类型。

warp.dtype_to_jax(warp_dtype)[source]#

返回对应于 Warp dtype 的 Jax dtype。

参数:

warp_dtype – 具有对应 Jax 数据类型的 Warp 数据类型。

Raises:

TypeError – 无法找到对应的 Jax 数据类型。

将 Warp 内核用作 JAX 原语#

注意

此版本的 jax_kernel() 基于现在已弃用的 JAX 功能。

对于 JAX 0.4.31 或更高版本,建议用户切换到基于新的 Foreign Function Interface (FFI) 的新版本的 jax_kernel()

Warp 内核可以用作 JAX 原语,这允许在 jitted JAX 函数内部调用它们

import warp as wp
import jax
import jax.numpy as jnp

from warp.jax_experimental import jax_kernel

@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]

# create a Jax primitive from a Warp kernel
jax_triple = jax_kernel(triple_kernel)

# use the Warp kernel in a Jax jitted function
@jax.jit
def f():
    x = jnp.arange(0, 64, dtype=jnp.float32)
    return jax_triple(x)

print(f())
warp.jax_experimental.jax_kernel(kernel, launch_dims=None)[source]#

从 Warp 内核创建一个 Jax 原语。

注意:这是一个正在开发中的实验性功能。

参数:
  • kernel – 要包装的 Warp 内核。

  • launch_dims – 可选。指定内核启动维度。如果为 None,则从第一个参数的形状推断维度。设置此选项时,将指定输出维度。

局限性
  • 所有内核参数必须是连续数组。

  • 输入参数之后是 Warp 内核定义中的输出参数。

  • 必须至少有一个输入参数和一个输出参数。

  • 仅支持 CUDA 后端。

输入和输出语义#

所有内核参数必须是连续数组。输入参数必须在内核定义中的输出参数之前。至少需要一个输入数组和一个输出数组。这是一个具有三个输入和两个输出的内核

import warp as wp
import jax
import jax.numpy as jnp

from warp.jax_experimental import jax_kernel

# kernel with multiple inputs and outputs
@wp.kernel
def multiarg_kernel(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    # outputs
    ab: wp.array(dtype=float),
    bc: wp.array(dtype=float),
):
    tid = wp.tid()
    ab[tid] = a[tid] + b[tid]
    bc[tid] = b[tid] + c[tid]

# create a Jax primitive from a Warp kernel
jax_multiarg = jax_kernel(multiarg_kernel)

# use the Warp kernel in a Jax jitted function with three inputs and two outputs
@jax.jit
def f():
    a = jnp.full(64, 1, dtype=jnp.float32)
    b = jnp.full(64, 2, dtype=jnp.float32)
    c = jnp.full(64, 3, dtype=jnp.float32)
    return jax_multiarg(a, b, c)

x, y = f()

print(x)
print(y)

内核启动和输出维度#

默认情况下,启动维度是从第一个输入数组的形状推断出来的。当这不合适时,可以使用 launch_dims 参数来覆盖此行为。启动维度还决定了输出数组的形状。这是一个简单的矩阵乘法内核,它将一个 NxK 矩阵乘以一个 KxM 矩阵。启动维度和输出形状必须是 (N, M),这与输入数组的形状不同

import warp as wp
import jax
import jax.numpy as jnp

import warp as wp
from warp.jax_experimental import jax_kernel

@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK input
    b: wp.array2d(dtype=float),  # KxM input
    c: wp.array2d(dtype=float),  # NxM output
):
    # launch dims should be (N, M)
    i, j = wp.tid()
    N = a.shape[0]
    K = a.shape[1]
    M = b.shape[1]
    if i < N and j < M:
        s = wp.float32(0)
        for k in range(K):
            s += a[i, k] * b[k, j]
        c[i, j] = s

N, M, K = 3, 4, 2

# specify custom launch dimensions
jax_matmul = jax_kernel(matmul_kernel, launch_dims=(N, M))

@jax.jit
def f():
    a = jnp.full((N, K), 2, dtype=jnp.float32)
    b = jnp.full((K, M), 3, dtype=jnp.float32)

    # use default launch dims
    return jax_matmul(a, b)

print(f())

JAX 外部函数接口 (FFI)#

在版本 1.7 中添加。

JAX v0.4.31 引入了一个新的 外部函数接口,它取代了旧的自定义调用机制。一个重要的好处是,它允许将外部函数与其它 JAX 操作一起捕获在 CUDA 图中。这可以带来显著的性能提升。

建议新版本的 JAX 用户切换到基于 FFI 的 jax_kernel() 的新实现。旧的实现仍然可用,以避免破坏现有代码,但未来的开发可能会集中在 FFI 版本上。

from warp.jax_experimental.ffi import jax_kernel  # new FFI-based jax_kernel()

新的实现可能会更快,而且也更灵活。

warp.jax_experimental.ffi.jax_kernel(
kernel,
num_outputs=1,
vmap_method='broadcast_all',
launch_dims=None,
output_dims=None,
)[source]#

从 Warp 内核创建一个 JAX 回调。

注意:这是一个正在开发中的实验性功能。

参数:
  • kernel – 要启动的 Warp 内核。

  • num_outputs – 可选。如果大于 1,则指定输出参数的数量。

  • vmap_method – 可选。指定回调在 vmap() 下如何转换的字符串。也可以为单个调用指定此参数。

  • launch_dims – 可选。指定默认内核启动维度。如果为 None,则从第一个数组参数的形状推断启动维度。也可以为单个调用指定此参数。

  • output_dims – 可选。指定输出数组的默认维度。如果为 None,则从启动维度推断输出维度。也可以为单个调用指定此参数。

局限性
  • 所有内核参数必须是连续数组或标量。

  • 标量必须是 JAX 中的静态参数。

  • 输入参数之后是 Warp 内核定义中的输出参数。

  • 必须至少有一个输出参数。

  • 仅支持 CUDA 后端。

输入和输出语义#

输入参数必须在内核定义中的输出参数之前。至少需要一个输出数组,但是可以有无输入的内核。新的 jax_kernel() 允许使用 num_outputs 参数指定输出的数量。它默认为 1,因此只有对于具有多个输出的内核才需要此参数。

这是一个具有两个输入和一个输出的内核

import jax
import jax.numpy as jnp

import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def add_kernel(a: wp.array(dtype=int),
               b: wp.array(dtype=int),
               output: wp.array(dtype=int)):
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]

jax_add = jax_kernel(add_kernel)

@jax.jit
def f():
    n = 10
    a = jnp.arange(n, dtype=jnp.int32)
    b = jnp.ones(n, dtype=jnp.int32)
    return jax_add(a, b)

print(f())

一个输入和两个输出

import math

import jax
import jax.numpy as jnp

import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def sincos_kernel(angle: wp.array(dtype=float),
                  # outputs
                  sin_out: wp.array(dtype=float),
                  cos_out: wp.array(dtype=float)):
    tid = wp.tid()
    sin_out[tid] = wp.sin(angle[tid])
    cos_out[tid] = wp.cos(angle[tid])

jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)  # specify multiple outputs

@jax.jit
def f():
    a = jnp.linspace(0, 2 * math.pi, 32)
    return jax_sincos(a)

s, c = f()
print(s)
print(c)

这是一个没有输入的内核,它使用对角线值 (1, 2, 3) 初始化一个 3x3 矩阵数组。如果没有输入,则需要指定启动维度以确定输出数组的形状

@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    tid = wp.tid()
    output[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0)

jax_diagonal = jax_kernel(diagonal_kernel)

@jax.jit
def f():
    # launch dimensions determine the output shape
    return jax_diagonal(launch_dims=4)

print(f())
标量输入#

支持标量输入参数,但存在一些限制。目前,传递给 Warp 内核的标量必须是 JAX 中的常量或静态值。

@wp.kernel
def scale_kernel(a: wp.array(dtype=float),
                 s: float,  # scalar input
                 output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s


jax_scale = jax_kernel(scale_kernel)

@jax.jit
def f():
    a = jnp.arange(10, dtype=jnp.float32)
    return jax_scale(a, 2.0)  # ok: constant scalar argument

print(f())

尝试使用跟踪的标量值将导致异常。

@jax.jit
def f(a, s):
    return jax_scale(a, s)  # ERROR: traced scalar argument

a = jnp.arange(10, dtype=jnp.float32)

print(f(a, 2.0))

JAX 静态参数来救援。

# make scalar arguments static
@partial(jax.jit, static_argnames=["s"])
def f(a, s):
    return jax_scale(a, s)  # ok: static scalar argument

a = jnp.arange(10, dtype=jnp.float32)

print(f(a, 2.0))

内核启动和输出维度#

默认情况下,启动维度是从第一个输入数组的形状推断出来的。如果不合适,可以使用 launch_dims 参数来覆盖此行为。启动维度也决定了输出数组的形状。

这是一个简单的矩阵乘法内核,它将 NxK 矩阵乘以 KxM 矩阵。启动维度和输出形状必须是 (N, M),这与输入数组的形状不同。

请注意,新的 jax_kernel() 允许在每次调用时指定自定义启动维度,这比旧的实现更灵活,但仍然支持旧方法。

@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK input
    b: wp.array2d(dtype=float),  # KxM input
    c: wp.array2d(dtype=float),  # NxM output
):
    # launch dimensions should be (N, M)
    i, j = wp.tid()
    N = a.shape[0]
    K = a.shape[1]
    M = b.shape[1]
    if i < N and j < M:
        s = wp.float32(0)
        for k in range(K):
            s += a[i, k] * b[k, j]
        c[i, j] = s

# no need to specify launch dims here
jax_matmul = jax_kernel(matmul_kernel)

@jax.jit
def f():
    N1, M1, K1 = 3, 4, 2
    a1 = jnp.full((N1, K1), 2, dtype=jnp.float32)
    b1 = jnp.full((K1, M1), 3, dtype=jnp.float32)

    # use custom launch dims
    result1 = jax_matmul(a1, b1, launch_dims=(N1, M1))

    N2, M2, K2 = 4, 3, 2
    a2 = jnp.full((N2, K2), 2, dtype=jnp.float32)
    b2 = jnp.full((K2, M2), 3, dtype=jnp.float32)

    # use custom launch dims
    result2 = jax_matmul(a2, b2, launch_dims=(N2, M2))

    return result1, result2

r1, r2 = f()
print(r1)
print(r2)

默认情况下,输出数组形状由启动维度确定,但可以使用 output_dims 参数指定自定义输出维度。考虑如下内核:

@wp.kernel
def funky_kernel(a: wp.array(dtype=float),
                 # outputs
                 b: wp.array(dtype=float),
                 c: wp.array(dtype=float)):
    ...

jax_funky = jax_kernel(funky_kernel, num_outputs=2)

指定用于所有输出的自定义输出形状

b, c = jax_funky(a, output_dims=n)

使用字典为每个输出指定不同的输出维度

b, c = jax_funky(a, output_dims={"b": n, "c": m})

一起指定自定义启动维度和输出维度

b, c = jax_funky(a, launch_dims=k, output_dims={"b": n, "c": m})

可以使用整数指定一维形状。可以使用整数元组或列表指定多维形状。

向量和矩阵数组#

支持 Warp 向量和矩阵类型的数组。由于 JAX 没有相应的数据类型,因此组件被打包到 JAX 数组的额外内部维度中。例如,wp.vec3 的 Warp 数组将具有 (…, 3) 的 JAX 数组形状,wp.mat22 的 Warp 数组将具有 (…, 2, 2) 的 JAX 数组形状。

@wp.kernel
def vecmat_kernel(a: wp.array(dtype=float),
                  b: wp.array(dtype=wp.vec3),
                  c: wp.array(dtype=wp.mat22),
                  # outputs
                  d: wp.array(dtype=float),
                  e: wp.array(dtype=wp.vec3),
                  f: wp.array(dtype=wp.mat22)):
    ...

jax_vecmat = jax_kernel(vecmat_kernel, num_outputs=3)

@jax.jit
def f():
    n = 10
    a = jnp.zeros(n, dtype=jnp.float32)          # scalar array
    b = jnp.zeros((n, 3), dtype=jnp.float32)     # vec3 array
    c = jnp.zeros((n, 2, 2), dtype=jnp.float32)  # mat22 array

    d, e, f = vecmat_kernel(a, b, c)

重要的是要认识到向量和矩阵类型的 Warp 和 JAX 数组形状是不同的。在上面的代码片段中,Warp 将 abc 视为 wp.float32wp.vec3wp.mat22 的一维数组。在 JAX 中,a 是长度为 n 的一维数组,b 是形状为 (n, 3) 的二维数组,c 是形状为 (n, 2, 2) 的三维数组。

指定自定义输出维度时,可以使用任一约定。以下调用是等效的:

d, e, f = vecmat_kernel(a, b, c, output_dims=n)
d, e, f = vecmat_kernel(a, b, c, output_dims={"d": n, "e": n, "f": n})
d, e, f = vecmat_kernel(a, b, c, output_dims={"d": n, "e": (n, 3), "f": (n, 2, 2)})

这是一个为了简化编写代码的便利功能。例如,当 Warp 期望数组具有相同的形状时,我们只需要指定一次形状,而无需担心 JAX 要求的额外向量和矩阵维度。

d, e, f = vecmat_kernel(a, b, c, output_dims=n)

另一方面,也接受 JAX 维度,以允许直接从 JAX 传递形状。

d, e, f = vecmat_kernel(a, b, c, output_dims={"d": a.shape, "e": b.shape, "f": c.shape})

有关示例,请参见 example_jax_kernel.py

JAX VMAP 支持#

可以使用 vmap_method 参数来指定回调在 jax.vmap() 下的转换方式。默认值为 "broadcast_all"。此参数可以传递给 jax_kernel(),也可以传递给每次调用。

# set default vmap behavior
jax_callback = jax_kernel(my_kernel, vmap_method="sequential")

@jax.jit
def f():
    ...
    b = jax_callback(a)  # uses "sequential"
    ...
    d = jax_callback(c, vmap_method="expand_dims")  # uses "expand_dims"
    ...

调用带注释的 Python 函数#

jax_kernel() 机制可用于从 JAX 启动单个 Warp 内核,但也可以调用启动多个内核的 Python 函数。目标 Python 函数应该具有参数类型注释,就像它是一个 Warp 内核一样。要从 JAX 调用此函数,请使用 jax_callable()

from warp.jax_experimental.ffi import jax_callable

@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s

@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    tid = wp.tid()
    output[tid] = a[tid] * s


# The Python function to call.
# Note the argument type annotations, just like Warp kernels.
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    # launch multiple kernels
    wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
    wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])


jax_func = jax_callable(example_func, num_outputs=2)

@jax.jit
def f():
    # inputs
    a = jnp.arange(10, dtype=jnp.float32)
    b = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
    s = 2.0

    # output shapes
    output_dims = {"c": a.shape, "d": b.shape}

    c, d = jax_func(a, b, s, output_dims=output_dims)

    return c, d

r1, r2 = f()
print(r1)
print(r2)

jax_callable() 的输入和输出语义与 jax_kernel() 类似,因此我们不会在此处重述所有内容,而只关注差异。

  • jax_callable() 不接受 launch_dims 参数,因为目标函数负责使用适当的维度启动内核。

  • jax_callable() 接受一个可选的布尔值 graph_compatible 参数,其默认值为 True。此参数确定 JAX 是否可以在 CUDA 图中捕获该函数。通常,这是可取的,因为 CUDA 图可以大大提高应用程序性能。但是,如果目标函数执行了图形捕获期间不允许的操作,则可能会导致错误。这包括任何需要与主机同步的操作。在这种情况下,传递 graph_compatible=False

有关示例,请参见 example_jax_callable.py

warp.jax_experimental.ffi.jax_callable(
func,
num_outputs=1,
graph_compatible=True,
vmap_method='broadcast_all',
output_dims=None,
)[source]#

从带注释的 Python 函数创建一个 JAX 回调。

Python 函数参数必须具有像 Warp 内核一样的类型注释。

注意:这是一个正在开发中的实验性功能。

参数:
  • func (Callable) – 要调用的 Python 函数。

  • num_outputs (int) – 可选。如果大于 1,则指定输出参数的数量。

  • graph_compatible (bool) – 可选。该函数是否可以在 CUDA 图形捕获期间调用。

  • vmap_method (str) – 可选。指定回调在 vmap() 下如何转换的字符串。也可以为单独的调用指定此参数。

  • output_dims – 可选。指定输出数组的默认维度。如果 None,则从启动维度推断输出维度。也可以为单独的调用指定此参数。

局限性
  • 所有内核参数必须是连续数组或标量。

  • 标量必须是 JAX 中的静态参数。

  • 输入参数之后是 Warp 内核定义中的输出参数。

  • 必须至少有一个输出参数。

  • 仅支持 CUDA 后端。

通用 JAX FFI 回调#

调用 Python 函数的另一种方法是使用 register_ffi_callback()

from warp.jax_experimental.ffi import register_ffi_callback

这允许调用没有 Warp 风格类型注释的函数,但必须具有以下形式:

func(inputs, outputs, attrs, ctx)

其中

  • inputs 是输入缓冲区列表。

  • outputs 是输出缓冲区列表。

  • attrs 是属性字典。

  • ctx 是执行上下文,包括 CUDA 流。

输入和输出缓冲区既不是 JAX 数组,也不是 Warp 数组。它们是公开 __cuda_array_interface__ 的对象,可以直接传递给 Warp 内核。这是一个例子

from warp.jax_experimental.ffi import register_ffi_callback

@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s

@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    tid = wp.tid()
    output[tid] = a[tid] * s

# the Python function to call
def warp_func(inputs, outputs, attrs, ctx):
    # input arrays
    a = inputs[0]
    b = inputs[1]

    # scalar attributes
    s = attrs["scale"]

    # output arrays
    c = outputs[0]
    d = outputs[1]

    device = wp.device_from_jax(get_jax_device())
    stream = wp.Stream(device, cuda_stream=ctx.stream)

    with wp.ScopedStream(stream):
        # launch with arrays of scalars
        wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

        # launch with arrays of vec2
        # NOTE: the input shapes are from JAX arrays, so we need to strip the inner dimension for vec2 arrays
        wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

# register callback
register_ffi_callback("warp_func", warp_func)

n = 10

# inputs
a = jnp.arange(n, dtype=jnp.float32)
b = jnp.arange(n, dtype=jnp.float32).reshape((n // 2, 2))  # array of wp.vec2
s = 2.0

# set up the call
out_types = [
    jax.ShapeDtypeStruct(a.shape, jnp.float32),
    jax.ShapeDtypeStruct(b.shape, jnp.float32),  # array of wp.vec2
]
call = jax.ffi.ffi_call("warp_func", out_types)

# call it
c, d = call(a, b, scale=s)

print(c)
print(d)

这是一种更底层的 JAX FFI 回调方法。有人提出了一项建议,将这种机制纳入 JAX 中,但目前我们在这里有一个原型。这种方法将很多工作留给了用户,例如验证参数类型和形状,但是当其他实用程序(如 jax_kernel()jax_callable())不足以满足需求时,可以使用它。

有关示例,请参见 example_jax_ffi_callback.py

warp.jax_experimental.ffi.register_ffi_callback(name, func, graph_compatible=True)[source]#

从 Python 函数创建一个 JAX 回调。

Python 函数必须具有 func(inputs, outputs, attrs, ctx) 的形式。

注意:这是一个正在开发中的实验性功能。

参数:
  • name (str) – 唯一的 FFI 回调名称。

  • func (Callable) – 要调用的 Python 函数。

  • graph_compatible (bool) – 可选。该函数是否可以在 CUDA 图形捕获期间调用。

返回类型:

分布式计算#

Warp 可以与 JAX 的 shard_map 结合使用,以执行分布式多 GPU 计算。

为此,必须初始化 JAX 分布式环境(有关更多详细信息,请参见 分布式数组和自动并行化)。

import jax
jax.distributed.initialize()

必须在程序开始时,在任何其他 JAX 操作之前调用此初始化。

这是一个如何将 shard_map 与 Warp 内核一起使用的示例

import warp as wp
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather as allgather
from jax.experimental.shard_map import shard_map
from warp.jax_experimental import jax_kernel
import numpy as np

# Initialize JAX distributed environment
jax.distributed.initialize()
num_gpus = jax.device_count()

def print_on_process_0(*args, **kwargs):
    if jax.process_index() == 0:
        print(*args, **kwargs)

print_on_process_0(f"Running on {num_gpus} GPU(s)")

@wp.kernel
def multiply_by_two_kernel(
    a_in: wp.array(dtype=wp.float32),
    a_out: wp.array(dtype=wp.float32),
):
    index = wp.tid()
    a_out[index] = a_in[index] * 2.0

jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

def warp_multiply(x):
    result = jax_warp_multiply(x)
    return result

    # a_in here is the full sharded array with shape (M,)
    # The output will also be a sharded array with shape (M,)
def warp_distributed_operator(a_in):
    def _sharded_operator(a_in):
        # Inside the sharded operator, a_in is a local shard on each device
        # If we have N devices and input size M, each shard has shape (M/N,)

        # warp_multiply applies the Warp kernel to the local shard
        result = warp_multiply(a_in)[0]

        # result has the same shape as the input shard (M/N,)
        return result

    # shard_map distributes the computation across devices
    return shard_map(
        _sharded_operator,
        mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
        in_specs=(P("x"),),  # Input is sharded along the 'x' axis
        out_specs=P("x"),    # Output is also sharded along the 'x' axis
        check_rep=False,
    )(a_in)

print_on_process_0("Test distributed multiplication using JAX + Warp")

devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), "x")
sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))

input_size = num_gpus * 5  # 5 elements per device
single_device_arrays = jnp.arange(input_size, dtype=jnp.float32)

# Define the shape of the input array based on the total input size
shape = (input_size,)

# Create a list of arrays by distributing the single_device_arrays across the available devices
# Each device will receive a portion of the input data
arrays = [
    jax.device_put(single_device_arrays[index], d)  # Place each element on the corresponding device
    for d, index in sharding_spec.addressable_devices_indices_map(shape).items()
]

# Combine the individual device arrays into a single sharded array
sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays)

# sharded_array has shape (input_size,) but is distributed across devices
print_on_process_0(f"Input array: {allgather(sharded_array)}")

# warp_result has the same shape and sharding as sharded_array
warp_result = warp_distributed_operator(sharded_array)

# allgather collects results from all devices, resulting in a full array of shape (input_size,)
print_on_process_0("Warp Output:", allgather(warp_result))

在这个例子中,shard_map 用于在可用设备上分配计算。输入数组 a_in 沿着 ‘x’ 轴进行分片,每个设备处理其本地分片。Warp 内核 multiply_by_two_kernel 应用于每个分片,并将结果组合起来形成最终输出。

这种方法可以有效地并行处理大型数组,因为每个设备同时处理数据的一部分。

要在多个 GPU 上运行此程序,您必须安装 Open MPI。您可以查阅 OpenMPI 安装指南,了解如何安装它。安装 Open MPI 后,您可以使用 mpirun 以及以下命令

mpirun -np <NUM_OF_GPUS> python <filename>.py

DLPack#

Warp 支持 Python Array API 标准 v2022.12 中包含的 DLPack 协议。有关参考,请参阅 DLPack 的 Python 规范

将外部数组导入 Warp 的规范方法是使用 warp.from_dlpack() 函数

warp_array = wp.from_dlpack(external_array)

外部数组可以是 PyTorch 张量、Jax 数组或任何其他与此版本的 DLPack 协议兼容的数组类型。对于 CUDA 数组,此方法要求生产者执行流同步,以确保数组上的操作顺序正确。warp.from_dlpack() 函数要求生产者在数组所在的设备上同步当前的 Warp 流。因此,在该设备上使用 Warp 内核中的数组应该是安全的,无需任何额外的同步。

将 Warp 数组导出到外部框架的规范方法是使用该框架中的 from_dlpack() 函数

jax_array = jax.dlpack.from_dlpack(warp_array)
torch_tensor = torch.utils.dlpack.from_dlpack(warp_array)
paddle_tensor = paddle.utils.dlpack.from_dlpack(warp_array)

对于 CUDA 数组,这会将消费者框架的当前流与数组设备上的当前 Warp 流同步。因此,即使该数组之前已在设备上的 Warp 内核中使用过,也可以安全地在消费者框架中使用包装的数组。

或者,可以使用生产者框架提供的 to_dlpack() 函数显式创建 PyCapsule 来共享数组。此方法可用于不支持 v2022.12 标准的旧版本框架

warp_array1 = wp.from_dlpack(jax.dlpack.to_dlpack(jax_array))
warp_array2 = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
warp_array3 = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(paddle_tensor))

jax_array = jax.dlpack.from_dlpack(wp.to_dlpack(warp_array))
torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))
paddle_tensor = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))

这种方法通常更快,因为它跳过了任何流同步,但必须使用另一种解决方案来确保操作的正确排序。在不需要同步的情况下,使用这种方法可以产生更好的性能。在以下情况下,这可能是一个不错的选择

  • 外部框架正在使用同步 CUDA 默认流。

  • Warp 和外部框架正在使用相同的 CUDA 流。

  • 已经存在另一种同步机制。

warp.from_dlpack(source, dtype=None)[source]#

将源数组或 DLPack capsule 转换为 Warp 数组,而无需复制。

参数:
  • source – 兼容 DLPack 的数组或 PyCapsule

  • dtype – 用于解释源数据的可选 Warp 数据类型。

Returns:

一个新的 Warp 数组,它使用与输入 pycapsule 相同的底层内存。

返回类型:

array

warp.to_dlpack(wp_array)[source]#

将 Warp 数组转换为另一种类型的兼容 DLPack 的数组。

参数:

wp_array (array) – 将要转换的源 Warp 数组。

Returns:

包含 DLManagedTensor 的 capsule,可以将其转换为另一种数组类型,而无需复制底层内存。

Paddle#

Warp 提供了辅助函数来与 Paddle 相互转换数组

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")

# convert to Paddle tensor
t = wp.to_paddle(w)

# convert from Paddle tensor
w = wp.from_paddle(t)

这些辅助函数允许在 Warp 数组和 Paddle 张量之间转换,而无需复制底层数据。同时,如果可用,梯度数组和张量会与 Paddle autograd 张量相互转换,从而允许在 Paddle autograd 计算中使用 Warp 数组。

warp.from_paddle(
t,
dtype=None,
requires_grad=None,
grad=None,
return_ctype=False,
)[source]#

将 Paddle 张量转换为 Warp 数组,而无需复制数据。

参数:
  • t (paddle.Tensor) – 要包装的 paddle 张量。

  • dtype (warp.dtype, optional) – 生成的 Warp 数组的目标数据类型。 默认为映射到 Warp 数组值类型的张量值类型。

  • requires_grad (bool, optional) – 结果数组是否应包装张量的梯度(如果存在)(否则将分配 grad 张量)。 默认为张量的 requires_grad 值。

  • grad (paddle.Tensor, optional) – 附加到给定张量的 grad。默认为 None。

  • return_ctype (bool, optional) – 是否返回低级数组描述符而不是 wp.array 对象(更快)。 描述符可以传递给 Warp 内核。

Returns:

包装的数组或数组描述符。

返回类型:

warp.array

warp.to_paddle(a, requires_grad=None)[source]#

将 Warp 数组转换为 Paddle 张量,而无需复制数据。

参数:
  • a (warp.array) – 要转换的 Warp 数组。

  • requires_grad (bool, optional) – 结果张量是否应将数组的梯度(如果存在)转换为 grad 张量。 默认为数组的 requires_grad 值。

Returns:

转换后的张量。

返回类型:

paddle.Tensor

warp.device_from_paddle(paddle_device)[source]#

返回与 Paddle 设备对应的 Warp 设备。

参数:

paddle_device (Place, CPUPlace, CUDAPinnedPlace, CUDAPlace, or str) – Paddle 设备标识符

Raises:

RuntimeError – Paddle 设备没有相应的 Warp 设备

返回类型:

warp.context.Device

warp.device_to_paddle(warp_device)[source]#

返回与 Warp 设备对应的 Paddle 设备字符串。

参数:

warp_device (Device | str | None) – 可以解析为 warp.context.Device 的标识符。

Raises:

RuntimeError – Warp 设备与 PyPaddle 不兼容。

返回类型:

str

warp.dtype_from_paddle(paddle_dtype)[source]#

返回与 Paddle dtype 对应的 Warp dtype。

参数:

paddle_dtype – 具有相应 Warp 数据类型的 paddle.dtype。目前不支持 paddle.bfloat16paddle.complex64paddle.complex128

Raises:

TypeError – 无法找到相应的 Warp 数据类型。

warp.dtype_to_paddle(warp_dtype)[source]#

返回与 Warp dtype 对应的 Paddle dtype。

参数:

warp_dtype – 具有相应 paddle.dtype 的 Warp 数据类型。warp.uint16warp.uint32warp.uint64 映射到相同宽度的带符号整数 paddle.dtype

Raises:

TypeError – 无法找到相应的 PyPaddle 数据类型。

要将 Paddle CUDA 流转换为 Warp CUDA 流,反之亦然,Warp 提供了以下函数

warp.stream_from_paddle(stream_or_device=None)[source]#

从 Paddle CUDA 流转换为 Warp CUDA 流。

示例:使用 warp.from_paddle() 进行优化#

以下是一个示例,展示了如何使用 warp.from_paddle() 通过 Paddle 的 Adam 优化器最小化 Warp 中编写的 2D 点数组的损失函数

import warp as wp
import paddle

# init warp context at beginning
wp.context.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# indicate requires_grad so that Warp can accumulate gradients in the grad buffers
xs = paddle.randn([100, 2])
xs.stop_gradient = False
l = paddle.zeros([1])
l.stop_gradient = False
opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[xs])

wp_xs = wp.from_paddle(xs)
wp_l = wp.from_paddle(l)

tape = wp.Tape()
with tape:
    # record the loss function kernel launch on the tape
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=wp_l)  # compute gradients
    # now xs.grad will be populated with the gradients computed by Warp
    opt.step()  # update xs (and thereby wp_xs)

    # these lines are only needed for evaluating the loss
    # (the optimization just needs the gradient, not the loss value)
    wp_l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
    print(f"{i}\tloss: {l.item()}")

示例:使用 warp.to_paddle 进行优化#

当我们在 Warp 中直接声明优化变量并使用 warp.to_paddle() 将它们转换为 Paddle 张量时,需要的代码更少。在这里,我们重新审视了上面的相同示例,现在只需要一次转换为 Paddle 张量即可为 Adam 提供优化变量

import warp as wp
import numpy as np
import paddle

# init warp context at beginning
wp.context.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# initialize the optimization variables in Warp
xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True)
l = wp.zeros(1, dtype=wp.float32, requires_grad=True)
# just a single wp.to_paddle call is needed, Adam optimizes using the Warp array gradients
opt = paddle.optimizer.Adam(learning_rate=0.1, parameters=[wp.to_paddle(xs)])

tape = wp.Tape()
with tape:
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=l)
    opt.step()

    l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
    print(f"{i}\tloss: {l.numpy()[0]}")

性能说明#

wp.from_paddle() 函数创建一个 Warp 数组对象,该对象与 Paddle 张量共享数据。虽然此函数不复制数据,但在转换过程中总会产生一些 CPU 开销。如果这些转换频繁发生,则整体程序性能可能会受到影响。作为一般规则,最好避免重复转换相同的张量。而不是

x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))

for i in range(10):
    x_w = wp.from_paddle(x_t)
    y_w = wp.from_paddle(y_t)
    wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

尝试仅转换一次数组并重复使用它们

x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))

x_w = wp.from_paddle(x_t)
y_w = wp.from_paddle(y_t)

for i in range(10):
    wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

如果无法重用数组(例如,每次迭代都会构造一个新的 Paddle 张量),则将 return_ctype=True 传递给 wp.from_paddle() 应该会产生更快的性能。 将此参数设置为 True 可以避免构造 wp.array 对象,而是返回一个底层数组描述符。 该描述符是一个简单的 C 结构,可以传递给 Warp 内核,而不是 wp.array,但不能在其他需要 wp.array 的地方使用。

for n in range(1, 10):
    # get Paddle tensors for this iteration
    x_t = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
    y_t = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))

    # get Warp array descriptors
    x_ctype = wp.from_paddle(x_t, return_ctype=True)
    y_ctype = wp.from_paddle(y_t, return_ctype=True)

    wp.launch(saxpy, dim=n, inputs=[x_ctype, y_ctype, 1.0], device=device)

另一种方法是将 Paddle 张量直接传递给 Warp 内核。 这避免了构造临时 Warp 数组,利用了 Paddle 和 Warp 都支持的标准数组接口(如 __cuda_array_interface__)。 这种方法的主要优点是方便,因为无需调用任何转换函数。 主要限制是它不处理梯度,因为梯度信息未包含在标准数组接口中。 因此,这种技术最适合于不涉及微分的算法。

x = paddle.arange(n, dtype=paddle.float32).to(device=wp.device_to_paddle(device))
y = paddle.ones([n], dtype=paddle.float32).to(device=wp.device_to_paddle(device))

for i in range(10):
    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)
python -m warp.examples.benchmarks.benchmark_interop_paddle

示例输出

13990 ms  from_paddle(...)
 5990 ms  from_paddle(..., return_ctype=True)
35167 ms  direct from paddle

默认的 wp.from_paddle() 转换速度最慢。 传递 return_ctype=True 速度最快,因为它跳过了创建临时 Warp 数组对象。 将 Paddle 张量直接传递给 Warp 内核的速度介于两者之间。 它跳过了创建临时 Warp 数组,但访问 Paddle 张量的 __cuda_array_interface__ 属性会增加开销,因为它们是按需初始化的。