GKE AI 系列:在 TPU 上使用 JAX 运行第一个训练作业

GKE AI 系列:在 TPU 上使用 JAX 运行第一个训练作业

训练大型机器学习模型需要专用硬件。Google 的 Tensor Processing Units (TPU) 是其中最强大的芯片之一。然而,入门 TPU 并不容易——你需要了解芯片布局、配置节点池,并管理资源。在翻遍大量文档、排错数小时后,很容易迷失方向。

本系列记录我的学习旅程,并用浅显语言说明 TPU。我假设你熟悉基本 GCP 使用方式(gcloud 命令)以及 Kubernetes 的 Pod、Deployment、Node 与资源管理等概念。文中将使用以下术语。

术语

  • TPU:Google 专为机器学习打造的 AI 加速芯片。TPU 擅长矩阵运算,因此在神经网络训练与推理上比 CPU 或 GPU 更快。TPU v6e(Trillium)是目前 GKE 的最新世代。TPU 以通过 ICI(Inter-Chip Interconnect)互连的“切片”组织,支持多芯片训练。
  • GKE:Google Kubernetes Engine——Google Cloud 的托管 Kubernetes 服务,用于部署与扩展容器化应用。针对 TPU 工作负载,GKE 提供 TPU 节点池、拓扑感知,以及与 Kueue 等调度工具的整合。它在容错与配置简化方面表现良好。
  • JAX:Google 开发的 Python 库,适用于高性能数值计算与机器学习。它结合 NumPy 的 API、自动微分与 XLA 编译,非常适合 TPU 工作负载。JAX 会自动将计算并行化到多个设备,特别适合多芯片 TPU 配置。

理解 TPU 拓扑(简化版)

容易混淆的点: TPU 拓扑听起来很复杂,但其实就是芯片如何排列。

可以这样理解:

  • 一颗 芯片 = 一个处理单元(像一个大脑)
  • 拓扑 = 芯片如何以行 × 列排列(像剧院座位)
  • 2x2 = 2 行 × 2 列 = 共 4 颗芯片

常见 TPU v6e 配置:

拓扑 TPU 芯片 主机 VM 机型 (GKE API) 范围
1x1 1 1/8 1 ct6e-standard-1t 子主机
2x2 4 1/2 1 ct6e-standard-4t 子主机
2x4 8 1 1 ct6e-standard-8t 单主机
2x4 8 1 2 ct6e-standard-4t 单主机

先从基础开始。TPU 芯片是为机器学习工作负载设计的专用处理器。当你听到“2x2 拓扑”,代表 4 颗芯片按 2 行 × 2 列排列。实际上这些芯片未必真的呈现 2x2 方格,它们可能分布在同一主机内不同位置。重点是逻辑拓扑:TPU 运行时如何组织并寻址这 4 颗芯片以进行分布式计算。

下图为 Google 数据中心部署的 TPU v6e(Trillium)托盘。每个托盘包含 4 颗芯片,作为单个机器学习作业的计算单元(来源 1):

实际上,数据中心可以将多个托盘互连成更大的切片,以支持分布式训练工作负载。下图展示芯片与托盘的关系,每个蓝色方块代表一颗芯片,四颗芯片组成一个托盘:

了解芯片如何分配,有助于理解拓扑形态:

  • 1x1 配置中,你只使用单一主机上的 8 颗芯片中的 1 颗。这是最小的 TPU 配置,适合开发、测试或轻量推理工作负载。
  • 2x2 配置中,你只使用 8 颗芯片中的 4 颗(以蓝色标示)。这种子主机(sub-host)配置适合较小实验或开发工作,无需租用整台主机。
  • 2x4 配置中,你会使用单一主机上的 全部 8 颗芯片(物理上可能跨两个托盘)。

下图展示 8 颗 TPU 芯片(蓝色方块)在不同拓扑下的排列方式。

在 GKE 中使用 TPU

1. 解开 TPU 迷思(以 v6e 为例)

在开始写代码之前,我们先了解 TPU 的租用方式。由于本文聚焦在 GKE 上使用 TPU,我将从 GKE 的视角说明 TPU 拓扑以及资源如何分配到节点。

为了说明,我以 GKE Standard 及其相关机型为例。

机型:ct6e-standard-4t

在 GKE 中,TPU 会挂载到虚拟机(节点)上。Trillium(v6e)最容易获得的入门机型是 ct6e-standard-4t

  • 1 个节点(VM) = 1、4 或 8 颗 TPU 芯片。每个 TPU v6e VM 可以包含 1、4 或 8 颗芯片。默认情况下,一个 VM 管理 4 颗芯片,除非你指定其他配置。
  • “全有或全无”规则: 与 CPU 可以申请 0.5 核不同,在 GKE 上使用 TPU 时通常必须一次请求节点上的全部芯片。你无法只请求 1 颗。如果你租用 4 芯片机型,你的 Pod 必须请求 4 颗芯片。

支持 TPU 资源的 GKE 机型遵循命名规则:<version>-hightpu-<chip-count>t。芯片数量表示该 VM 可使用的 TPU 芯片数。例如:

  • ct6e-standard-4t 表示 4 颗 TPU 芯片
  • ct6e-standard-8t 表示 8 颗 TPU 芯片
  • tpu7x-standard-4t 支持 Ironwood(TPU7x),包含 4 颗 TPU 芯片

完整的支持机型清单请见 Plan TPUs in GKE 2

同一机型,不同拓扑的差异

先回到这张表:

拓扑 TPU 芯片 主机 VM 机型 (GKE API) 范围
2x2 4 1/2 1 ct6e-standard-4t 子主机
2x4 8 1 2 ct6e-standard-4t 单主机

你可能注意到上表中 ct6e-standard-4t 出现两次,对应不同拓扑(2x2 与 2x4)。这对新手很容易造成困惑,因此我们来拆解一下。

关键观念: 机型(ct6e-standard-4t)告诉你 VM 可以使用多少芯片,而拓扑描述这些芯片在分布式计算中的逻辑排列

每个 VM 都能看到它所分配的芯片(通常是 4 颗),让你能够编写程序在这些芯片间协同计算。

示例 1:ct6e-standard-4t + 2x2 拓扑

  • 你获得的资源: 1 台 VM,可使用 4 颗 TPU 芯片
  • 物理现实: 这 4 颗芯片来自同一台主机(总共 8 颗),但你只租到了其中一半
  • 逻辑排列: TPU 运行时将其视为 2×2 网格(2 行、2 列)
  • 适用场景: 较小的训练作业

示例 2:ct6e-standard-4t + 2x4 拓扑(多 VM)

  • 你获得的资源: 2 台 VM,每台可使用 4 颗 TPU 芯片(共 8 颗)
  • 物理现实: 这 8 颗芯片来自同一台物理主机,但为了编排被拆成 2 台 VM
  • 逻辑排列: TPU 运行时将其视为 2×4 网格(2 行、4 列)
  • 适用场景: 更大的训练作业

在 2x2 配置中,你的程序可以与全部 4 颗芯片通信;在 2x4 多 VM 配置中,每台 VM 各自可访问 4 颗芯片,TPU 运行时会在两台 VM 之间协作,对外呈现为统一的 8 芯片系统。

为什么这很重要?

在 GKE 中,每台 VM 都是一个 GKE 节点。对于 2x4 多主机配置,你需要跨 2 个节点请求资源,而不是只在 1 个节点上。

下表显示不同拓扑对应的机型,以及工作负载需要的 GKE 节点数量。

拓扑 机型 TPU 芯片总数 TPU 主机 范围 TPU VM 需要的 GKE 节点数?
1x1 ct6e-standard-1t 1 1/4 子主机 1 1
2x2 ct6e-standard-4t 4 1/2 子主机 1 1
2x4 ct6e-standard-4t 8 1 单主机 2 2
2x4 ct6e-standard-8t 8 1 单主机 1 1

理解这个映射关系对于配置节点池非常重要,因为你必须确保集群有足够的节点来承载所选拓扑。

2. 配置 GKE 集群

现在我们已经覆盖了 TPU 拓扑与机型的基础概念,接下来一步步完成实际部署。

我们使用 2x4 拓扑(8 颗芯片)运行一个小型训练作业。首先创建标准 GKE 集群,再为 TPU 添加专用节点池。

CLUSTER_NAME=tpu-quickstart         # GKE 集群名称
CONTROL_PLANE_LOCATION=us-central1  # 请确保所用区域支持 v6e,详见: https://docs.cloud.google.com/tpu/docs/regions-zones

gcloud container clusters create $CLUSTER_NAME --location=$CONTROL_PLANE_LOCATION 

3. 创建带 Flex-Start 的节点池

接下来添加一个专门配置 2x4 TPU 工作负载的节点池。此节点池将使用 Flex-Start 预配模型,确保在需求高峰时也能获得 TPU 容量。

NODE_POOL_NAME=nodepool   # GKE 节点池名称
NODE_ZONES=us-central1-b  # 请确保所用 zone 支持 v6e,详见: https://docs.cloud.google.com/tpu/docs/regions-zones

gcloud container node-pools create $NODE_POOL_NAME \
    --cluster=$CLUSTER_NAME \
    --location=$CONTROL_PLANE_LOCATION \
    --node-locations=$NODE_ZONES \
    --machine-type=ct6e-standard-4t \
    --tpu-topology=2x4 \
    --reservation-affinity=none \
    --enable-autoscaling \
    --num-nodes=0 --min-nodes=0 --max-nodes=2 \
    --flex-start

--flex-start 标志非常关键。它会让 GKE 使用 Flex-Start 预配模型,将节点池创建请求排入队列,并在资源可用时进行配置,而不是因为容量不足而立即失败。

关键参数说明:

  • --tpu-topology=2x4:指定 TPU 芯片的逻辑排列(2×4 网格,共 8 颗芯片)
  • --machine-type=ct6e-standard-4t:每个节点可使用 4 颗 TPU 芯片
  • --enable-autoscaling 搭配 --num-nodes=0:起始为 0 个节点,只有在工作负载被调度时才扩容
  • --max-nodes=2:2×4 拓扑需要 2 台 VM(节点),因此上限设为 2

注意:对 2x4 拓扑而言,--max-nodes=2 是必要条件。如果使用 ct6e-standard-4t--max-nodes=1,将无法满足拓扑要求,只会获得 4 颗芯片,也就退化为 2x2 拓扑。

预配模型:Flex-Start

TPU 属于稀缺资源,当 Google 机器不足时,部署失败是常见痛点。Flex-Start 是专门用来缓解这个问题的模式。

Feature Standard On-Demand Flex-Start
Cost 标准价格 折扣价(通常约 6 折)
Availability 立即(若有容量) 成功率更高(进入队列)

4. 运行批处理工作负载

现在集群与节点池都已配置完成,我们来部署一个简单的训练作业。

我们将创建一个 Job,把它调度到带 Flex-Start VM 的 TPU 节点。Kubernetes 的 Job 控制器会创建一个或多个 Pod,并确保其完成指定任务。

以下是示例 YAML,展示如何在 TPU 节点池上部署训练作业:

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-training-job
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-training-job
spec:
  backoffLimit: 0
  completions: 2
  parallelism: 2
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
          cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: 2x4
      containers:
      - name: tpu-job
        image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

当你使用 ct6e-standard-4t 创建 2x4 拓扑的节点池时,GKE 会预配 2 个独立节点(VM),每个节点包含 4 颗 TPU 芯片。

google.com/tpu: "4" 的资源请求告诉 GKE 将作业调度到拥有 4 颗 TPU 芯片的节点,以匹配 2x4 拓扑。配合 parallelism: 2,两个作业将并行运行——每个节点一个。当你应用该 manifest 时,如果节点尚未运行,GKE 的自动扩缩会从 Flex-Start 节点池自动预配所需节点。

Kubernetes 视角:节点调度

运行以下命令查看节点:

kubectl get nodes -o wide

你会看到类似输出:

NAME                                      STATUS   ROLES    AGE   VERSION
gke-cluster-tpu-pool-a1b2c3d4-node-1      Ready    <none>   5m    v1.28.3-gke.1234
gke-cluster-tpu-pool-a1b2c3d4-node-2      Ready    <none>   5m    v1.28.3-gke.1234

要查看这些节点上的 TPU 特定标签:

kubectl get nodes -l cloud.google.com/gke-tpu-topology=2x4 -o json | jq '.items[].metadata.labels'

输出会类似:

{
  "cloud.google.com/gke-tpu-accelerator": "tpu-v6e-slice",
  "cloud.google.com/gke-tpu-topology": "2x4",
  "google.com/tpu": "4",
  ...
}

关键观察:

  • 加速器标签cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice 用于标识 TPU 世代。该标签与 GKE Autopilot 中的枚举配置一致,详见 Plan TPUs on GKE 3(Validate TPU availability in GKE > Autopilot)。
  • 两个独立节点: 2x4 拓扑配合 ct6e-standard-4t 会得到 2 个 Kubernetes 节点
  • 每个节点 4 颗芯片: google.com/tpu: "4" 表示每个节点可用的 TPU 资源数
  • 相同拓扑标签: 两个节点都带有 cloud.google.com/gke-tpu-topology: 2x4

将工作负载分布到两个节点

要完整使用 2x4 拓扑(8 颗芯片),Job 必须将 Pod 调度到两个节点,所以在 Job 中设置 parallelism: 2completions: 2(每个节点一个 Pod)。

调度过程如下:

  1. Job 创建 2 个 Pod: 使用 completionMode: Indexed,为每个 Pod 分配唯一索引(0 和 1)
  2. 每个 Pod 请求 4 颗芯片: 在资源请求中设置 google.com/tpu: "4"
  3. NodeSelector 匹配拓扑: cloud.google.com/gke-tpu-topology: 2x4 确保 Pod 落到正确的节点池
  4. 调度器分配 Pod: Kubernetes 调度器将每个 Pod 放到不同节点,满足资源需求

你可以通过以下命令验证 Pod 的落点:

kubectl get pods -o wide

你应该会看到:

NAME                      READY   STATUS    NODE
tpu-training-job-0-abc    1/1     Running   gke-cluster-tpu-pool-a1b2c3d4-node-1
tpu-training-job-1-def    1/1     Running   gke-cluster-tpu-pool-a1b2c3d4-node-2

-0-1 的后缀来自 Indexed 完成模式,每个 Pod 运行在不同节点上,让你可以使用 2x4 拓扑的全部 8 颗芯片。

5. 查看训练作业

部署作业后,你可以监控进度并确认 TPU 资源是否被正确使用。下面检查作业状态并查看日志,确认一切正常。

检查 Job 状态

首先确认 Job 已创建并运行:

kubectl get jobs

你会看到类似输出:

NAME                COMPLETIONS   DURATION   AGE
tpu-training-job    0/2           5s         5s

COMPLETIONS 显示 0/2,表示 2 个 Pod 中有 0 个已完成。待两个 Pod 成功完成后会变为 2/2

查看 Pod 日志

查看每个 Pod 的输出:

kubectl logs tpu-training-job-0-<pod-suffix>
kubectl logs tpu-training-job-1-<pod-suffix>

你应该会看到 JAX 检测到 TPU 核心的输出:

TPU cores: 4

每个 Pod 看到 4 个 TPU 核心,因为它们分别运行在拥有 4 颗芯片的节点上。合起来就是 2x4 拓扑的 8 颗芯片。

常见问题排查

如果 Pod 一直停在 Pending,请查看事件:

kubectl describe pod tpu-training-job-0-<pod-suffix>

常见问题包括:

  • TPU 资源不足: Flex-Start 节点池仍在预配中,可能需要几分钟
  • 节点选择器不匹配: 确认 nodeSelector 标签与节点实际标签一致
  • 资源配额超限: 检查项目的 TPU 配额限制

进阶:在单一主机启动 8 颗芯片(单 VM)

如果你希望更简化的部署,或想让所有 TPU 芯片都在单一 VM 上,可以使用提供 8 颗芯片的更大机型来配置 2x4 拓扑。

不要使用会创建 2 台 VM(每台 4 颗芯片)的 ct6e-standard-4t,改用 ct6e-standard-8t 获取单一 VM 的 8 颗芯片:

# 在 us-central2-b 的区域集群 "my-cluster" 创建 TPU 节点池
gcloud container node-pools create tpu-pool \
    --cluster=my-cluster \
    --location=us-central2-b \
    --node-locations=us-central2-b \
    --machine-type=ct6e-standard-8t \
    --tpu-topology=2x4 \
    --reservation-affinity=none \
    --enable-autoscaling \
    --num-nodes=0 \
    --min-nodes=0 \
    --max-nodes=1 \
    --flex-start

主要差异:

  • 机型: 使用 ct6e-standard-8t,而非 ct6e-standard-4t
  • 最大节点数: --max-nodes=1,因为只需要一台 VM
  • 单一 Kubernetes 节点: kubectl get nodes 只会看到一个节点
  • 可用 8 颗芯片: 单个 Pod 即可访问全部 8 颗芯片

将 Job manifest 改为只运行一个 Pod:

apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-training-job-single-host
spec:
  backoffLimit: 0
  completions: 1
  parallelism: 1
  template:
    spec:
      restartPolicy: Never
      nodeSelector:
          cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: 2x4
      containers:
      - name: tpu-job
        image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 8
          limits:
            google.com/tpu: 8

此配置更简单,原因包括:

  • 只需管理一个 Pod,无需协调多个 Pod
  • 不需要 headless Service 进行 Pod 间通信
  • 更易排查问题,因为所有日志都集中在单个 Pod

不过,多 VM 方案(ct6e-standard-4t)在需要明确控制工作负载分布于不同 VM 时,能提供更灵活的分布式训练方式。

结论

本文完整走过在 GKE 上使用 TPU v6e 运行第一个 JAX 训练作业的流程:创建 GKE 集群、配置带 Flex-Start 的 TPU 节点池、部署训练作业,并理解 Kubernetes 如何在 TPU 节点之间调度工作负载。无论你选择多 VM 的分布式训练灵活性,还是单 VM 的简化部署,GKE 都能提供可扩展的基础设施,把 ML 工作负载从实验原型推向生产。

参考资料

Eason Cao
Eason Cao Eason is an engineer working at FANNG and living in Europe. He was accredited as AWS Professional Solution Architect, AWS Professional DevOps Engineer and CNCF Certified Kubernetes Administrator. He started his Kubernetes journey in 2017 and enjoys solving real-world business problems.
comments powered by Disqus