计算库层
在 XLA(Accelerated Linear Algebra)中使用自定义调用(Custom Call)机制,结合 XLA FFI(外部函数接口,Foreign Function Interface)来实现用户定义的操作。使用自定义调用在 CPU 上计算 A[i] = B[i % 128]+ C[i]
。
xla::XlaBuilder
:XLA 提供的用于构建计算图的类,这里实例化了一个名为 "do_it" 的构建器b
。xla::Parameter
:定义两个输入参数param0
和param1
。其中param0
是一个长度为 128 的 1D 浮点型(F32)数组,param1
是长度为 2048 的 1D 浮点型数组。xla::CustomCall
:这是 XLA 中执行自定义操作的关键调用。通过传递"do_custom_call"
字符串来指定自定义调用的名称,表示需要调用一个外部定义的函数。该自定义操作接收两个输入(param0
和param1
),输出结果的形状是一个长度为 2048 的 F32 数组。BufferF32
:这是 XLA FFI 中的类型别名,表示一个 1D 的浮点型(F32)缓冲区。- in0
和
in1是输入缓冲区(分别为 param0 和 param1 的数据),它们的数据类型为
BufferF32out
是输出缓冲区,存储结果。该函数的逻辑为:将in0
和in1
中的数据进行逐元素相加,并将结果写入输出缓冲区。注意这里通过i % d0
来处理in0
,使得其在计算时按顺序重复。assert
检查输出缓冲区的维度,确保与in1
的维度相同。 - 定义了一个处理器
handler
,并将它绑定到do_custom_call
函数上。通过这种绑定,FFI 可以知道自定义调用应该如何匹配到 C++ 函数。绑定过程中明确指定了函数的参数类型和返回值类型为Buffer
(即 1D 缓冲区)。 - 将处理器
handler
注册到 XLA FFI,表示它将在 "Host" 平台上可用。 "do_custom_call"
是自定义调用的名称,与xla::CustomCall
中使用的名称一致。xla::ffi::GetXlaFfiApi()
获取当前的 XLA FFI API 实例,确保处理器能够正确注册到 XLA。
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
在原有的 XLA 的自定义调用实现上进行了扩展,增加了 GPU 加速部分,主要通过并行处理自定义操作的逻辑,计算 A[i] = B[i % 128] + C[i]
。
- 构建了 XLA 的计算图,通过
xla::CustomCall
调用了名为"do_custom_call"
的自定义操作。它定义了两个输入参数param0
和param1
,并设置输出为长度为 2048 的浮点数数组。 const float* in0, const float* in1, float* out
:输入in0
和in1
是常量浮点型数组指针,out
是输出数组指针。size_t idx = blockIdx.x * blockDim.x + threadIdx.x
:计算当前线程的全局索引idx
。blockIdx.x
是当前线程块的索引,blockDim.x
是每个线程块的大小,threadIdx.x
是当前线程在块内的索引。out[idx] = in0[idx % 128] + in1[idx]
:对于每个线程,执行in0[idx % 128] + in1[idx]
,并将结果写入out[idx]
。in0
的大小为 128,因此使用% 128
使得in0
的数据循环重复使用,而in1
和out
都是长度为 2048。block_dim
和grid_dim
:用于定义kernel 的执行配置。block_dim
设置为 64,表示每个线程块中有 64 个线程,grid_dim
设置为2048 / 64
,即 32 个线程块。每个线程块并行处理 64 个数据元素,共 2048 个数据元素。custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(in0.data, in1.data, out->data)
:通过启动custom_call_kernel
内核,传入输入和输出数据指针,以及流stream
,让 GPU 并行执行数据计算。XLA_FFI_DEFINE_HANDLER
:定义一个新的 XLA FFI 处理器handler
,并将其绑定到do_custom_call
函数。.Ctx<xla::ffi::PlatformStream<CUstream>>()
:这行代码表明do_custom_call
函数需要接受一个流CUstream
作为上下文,以便在 GPU 上执行自定义调用。.Arg<BufferF32>()
:定义两个参数,类型为BufferF32
(浮点数组)。.Ret<BufferF32>()
:定义返回值为BufferF32
。XLA_FFI_REGISTER_HANDLER
:将定义好的handler
注册到 XLA FFI 中,使得 XLA 可以识别并调用这个自定义操作。
#include <hip/hip_runtime.h>
#include <cassert>
// 定义与原 CUDA 代码相同的实现
void do_it() { /* same implementation as above */ }
// 自定义核函数,使用 HIP 语法
__global__ void custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(hipStream_t stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
// 使用 HIP 语法调用核函数
hipLaunchKernelGGL(custom_call_kernel, dim3(grid_dim), dim3(block_dim), 0, stream, in0.data, in1.data, out->data);
}
// 使用 ROCm 注册自定义调用处理程序
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<hipStream_t>>() // 使用 hipStream_t
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"ROCm", handler); // 将 "CUDA" 更改为 "ROCm"
为 TensorFlow 启用 XLA,使用@tf.function(jit_compile=True)进行显式编译,显式编译 API 提供精细的控制,用于选择应编译哪些函数。例如,以下执行 MNIST 训练的 TensorFlow 函数使用 XLA 进行编译:
@tf.function(jit_compile=True)
def train_mnist(images, labels):
images, labels = cast(images, labels)
with tf.GradientTape() as tape:
predicted_labels = layer(images)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=predicted_labels, labels=labels
))
layer_variables = layer.trainable_variables
grads = tape.gradient(loss, layer_variables)
optimizer.apply_gradients(zip(grads, layer_variables))
tfcompile
是 XLA编译器工具,可以将 TensorFlow 图进行提前(AOT)编译,生成可执行代码。它有助于减少二进制文件的整体大小,并能避免部分运行时开销。tfcompile
接收一个子图(通过 TensorFlow 的 Feed 和 Fetch 概念定义),并生成实现该子图的函数。Feed 对应函数的输入参数,Fetch 对应函数的输出参数。所有输入必须通过 Feed 完全指定,生成的子图不能包含占位符或变量节点。通常的做法是将所有占位符和变量标记为 Feed,以确保最终生成的子图中没有这些节点。
使用 tfcompile 编译 TensorFlow 子图,首先,需要定义一个简单的 TensorFlow 模型或子图。以下是一个定义子图的示例,输入为标量,输出为其平方。
import tensorflow as tf
# 创建计算图
def simple_graph(x):
return tf.math.square(x)
# 输入符号化
x = tf.placeholder(dtype=tf.float32, shape=(), name='input')
# 定义子图
y = simple_graph(x)
# 将计算图保存到文件
with tf.Session() as sess:
tf.io.write_graph(sess.graph_def, './', 'simple_graph.pbtxt')
tfcompile
需要一个配置文件,指定输入、输出及其他信息。配置文件 config.pbtxt
示例:
# config.pbtxt
feed {
id { node_name: "input" }
shape { dim { size: 1 } } # 指定输入张量的形状
}
fetch {
id { node_name: "Square" } # 这是子图输出节点的名称
}
使用 tfcompile
编译器编译生成可执行二进制文件。生成的 .o
文件还需要链接到可执行程序。下面是 C++ 示例,展示如何使用生成的二进制文件:
#include <iostream>
#include "compiled_graph.o"
int main() {
// 创建输入张量
MyCompiledGraph computation;
float input_value = 3.0;
float output_value;
// 执行计算
computation.compute(&input_value, &output_value);
std::cout << "输入值: " << input_value << " 的平方是: " << output_value << std::endl;
return 0;
}
编译运行后输出如下内容:
输入值: 3 的平方是: 9
为 pytorch启用 XLA,PyTorch/XLA 使用与常规 PyTorch 相同的接口,但有一些附加功能。导入会torch_xla
初始化 PyTorch/XLA,并 xm.xla_device()
返回当前 XLA 设备。
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
结果
xla:0
tensor([[ 0.1028, -1.4783],
[-0.4271, 1.3415]], device='xla:0')
与其他设备类型一样,XLA 张量仅与同一设备上的其他 XLA 张量配合使用。
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
张量从 CPU 移动到 XLA 设备:当张量从 CPU 移动到 XLA 设备(如 TPU、GPU)时,数据会被复制到目标设备的内存中。这意味着可以在加速硬件上执行计算。同样,XLA 设备上的张量可以移动回 CPU,在这个过程中,数据会从设备复制回 CPU 的内存。一旦张量数据被复制到另一台设备,两个设备上的张量副本之间不会有任何联系。每个副本在各自的设备内存中独立存在。
应在保存之前将 XLA 张量移至 CPU,如以下代码片段所示:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
print(t0)
print(t1)
结果
tensor([[ 0.1028, -1.4783],
[-0.4271, 1.3415]], device='xla:0')
tensor([[ 0.8257, 0.3266],
[ 0.9146, -0.2747]], device='xla:0')