计算库层

在 XLA(Accelerated Linear Algebra)中使用自定义调用(Custom Call)机制,结合 XLA FFI(外部函数接口,Foreign Function Interface)来实现用户定义的操作。使用自定义调用在 CPU 上计算 A[i] = B[i % 128]+ C[i]

  • xla::XlaBuilder:XLA 提供的用于构建计算图的类,这里实例化了一个名为 "do_it" 的构建器 b
  • xla::Parameter:定义两个输入参数 param0param1。其中 param0 是一个长度为 128 的 1D 浮点型(F32)数组,param1 是长度为 2048 的 1D 浮点型数组。
  • xla::CustomCall:这是 XLA 中执行自定义操作的关键调用。通过传递 "do_custom_call" 字符串来指定自定义调用的名称,表示需要调用一个外部定义的函数。该自定义操作接收两个输入(param0param1),输出结果的形状是一个长度为 2048 的 F32 数组。
  • BufferF32:这是 XLA FFI 中的类型别名,表示一个 1D 的浮点型(F32)缓冲区。
  • in0in1是输入缓冲区(分别为 param0 和 param1 的数据),它们的数据类型为BufferF32out 是输出缓冲区,存储结果。该函数的逻辑为:将 in0in1 中的数据进行逐元素相加,并将结果写入输出缓冲区。注意这里通过 i % d0 来处理 in0,使得其在计算时按顺序重复。assert 检查输出缓冲区的维度,确保与 in1 的维度相同。
  • 定义了一个处理器 handler,并将它绑定到 do_custom_call 函数上。通过这种绑定,FFI 可以知道自定义调用应该如何匹配到 C++ 函数。绑定过程中明确指定了函数的参数类型和返回值类型为 Buffer(即 1D 缓冲区)。
  • 将处理器 handler 注册到 XLA FFI,表示它将在 "Host" 平台上可用。
  • "do_custom_call" 是自定义调用的名称,与 xla::CustomCall 中使用的名称一致。
  • xla::ffi::GetXlaFfiApi() 获取当前的 XLA FFI API 实例,确保处理器能够正确注册到 XLA。
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

在原有的 XLA 的自定义调用实现上进行了扩展,增加了 GPU 加速部分,主要通过 CUDA 来并行处理自定义操作的逻辑,计算 A[i] = B[i % 128] + C[i]

  • 构建了 XLA 的计算图,通过 xla::CustomCall 调用了名为 "do_custom_call" 的自定义操作。它定义了两个输入参数 param0param1,并设置输出为长度为 2048 的浮点数数组。
  • const float* in0, const float* in1, float* out:输入 in0in1 是常量浮点型数组指针,out 是输出数组指针。size_t idx = blockIdx.x * blockDim.x + threadIdx.x:计算当前线程的全局索引 idxblockIdx.x 是当前线程块的索引,blockDim.x 是每个线程块的大小,threadIdx.x 是当前线程在块内的索引。out[idx] = in0[idx % 128] + in1[idx]:对于每个线程,执行 in0[idx % 128] + in1[idx],并将结果写入 out[idx]in0 的大小为 128,因此使用 % 128 使得 in0 的数据循环重复使用,而 in1out 都是长度为 2048。
  • block_dimgrid_dim:用于定义 CUDA kernel 的执行配置。block_dim 设置为 64,表示每个线程块中有 64 个线程,grid_dim 设置为 2048 / 64,即 32 个线程块。每个线程块并行处理 64 个数据元素,共 2048 个数据元素。
  • custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(in0.data, in1.data, out->data):通过 CUDA 启动 custom_call_kernel 内核,传入输入和输出数据指针,以及 CUDA 流 stream,让 GPU 并行执行数据计算。
  • XLA_FFI_DEFINE_HANDLER:定义一个新的 XLA FFI 处理器 handler,并将其绑定到 do_custom_call 函数。
  • .Ctx<xla::ffi::PlatformStream<CUstream>>():这行代码表明 do_custom_call 函数需要接受一个 CUDA 流 CUstream 作为上下文,以便在 GPU 上执行自定义调用。
  • .Arg<BufferF32>():定义两个参数,类型为 BufferF32(浮点数组)。.Ret<BufferF32>():定义返回值为 BufferF32
  • XLA_FFI_REGISTER_HANDLER:将定义好的 handler 注册到 XLA FFI 中,使得 XLA 可以识别并调用这个自定义操作。
void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

为 TensorFlow 启用 XLA,使用@tf.function(jit_compile=True)进行显式编译,显式编译 API 提供精细的控制,用于选择应编译哪些函数。例如,以下执行 MNIST 训练的 TensorFlow 函数使用 XLA 进行编译:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

tfcompile 是 XLA编译器工具,可以将 TensorFlow 图进行提前(AOT)编译,生成可执行代码。它有助于减少二进制文件的整体大小,并能避免部分运行时开销。tfcompile 接收一个子图(通过 TensorFlow 的 Feed 和 Fetch 概念定义),并生成实现该子图的函数。Feed 对应函数的输入参数,Fetch 对应函数的输出参数。所有输入必须通过 Feed 完全指定,生成的子图不能包含占位符或变量节点。通常的做法是将所有占位符和变量标记为 Feed,以确保最终生成的子图中没有这些节点。

使用 tfcompile 编译 TensorFlow 子图,首先,需要定义一个简单的 TensorFlow 模型或子图。以下是一个定义子图的示例,输入为标量,输出为其平方。

import tensorflow as tf

# 创建计算图
def simple_graph(x):
    return tf.math.square(x)

# 输入符号化
x = tf.placeholder(dtype=tf.float32, shape=(), name='input')

# 定义子图
y = simple_graph(x)

# 将计算图保存到文件
with tf.Session() as sess:
    tf.io.write_graph(sess.graph_def, './', 'simple_graph.pbtxt')

tfcompile 需要一个配置文件,指定输入、输出及其他信息。配置文件 config.pbtxt 示例:

# config.pbtxt
feed {
  id { node_name: "input" }
  shape { dim { size: 1 } }  # 指定输入张量的形状
}
fetch {
  id { node_name: "Square" }  # 这是子图输出节点的名称
}

使用 tfcompile 编译器编译生成可执行二进制文件。生成的 .o 文件还需要链接到可执行程序。下面是 C++ 示例,展示如何使用生成的二进制文件:

#include <iostream>
#include "compiled_graph.o"

int main() {
    // 创建输入张量
    MyCompiledGraph computation;
    float input_value = 3.0;
    float output_value;

    // 执行计算
    computation.compute(&input_value, &output_value);

    std::cout << "输入值: " << input_value << " 的平方是: " << output_value << std::endl;
    return 0;
}

编译运行后输出如下内容:

输入值: 3 的平方是: 9

为 pytorch启用 XLA,PyTorch/XLA 使用与常规 PyTorch 相同的接口,但有一些附加功能。导入会torch_xla初始化 PyTorch/XLA,并 xm.xla_device()返回当前 XLA 设备。

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

结果

xla:0
tensor([[ 0.1028, -1.4783],
        [-0.4271,  1.3415]], device='xla:0')

与其他设备类型一样,XLA 张量仅与同一设备上的其他 XLA 张量配合使用。

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

张量从 CPU 移动到 XLA 设备:当张量从 CPU 移动到 XLA 设备(如 TPU、GPU)时,数据会被复制到目标设备的内存中。这意味着可以在加速硬件上执行计算。同样,XLA 设备上的张量可以移动回 CPU,在这个过程中,数据会从设备复制回 CPU 的内存。一旦张量数据被复制到另一台设备,两个设备上的张量副本之间不会有任何联系。每个副本在各自的设备内存中独立存在。

应在保存之前将 XLA 张量移至 CPU,如以下代码片段所示:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
print(t0)
print(t1)

结果

tensor([[ 0.1028, -1.4783],
        [-0.4271,  1.3415]], device='xla:0')
tensor([[ 0.8257,  0.3266],
        [ 0.9146, -0.2747]], device='xla:0')