GKE AI 系列:在 TPU 上使用 JAX 运行第一个训练作业
训练大型机器学习模型需要专用硬件。Google 的 Tensor Processing Units (TPU) 是其中最强大的芯片之一。然而,入门 TPU 并不容易——你需要了解芯片布局、配置节点池,并管理资源。在翻遍大量文档、排错数小时后,很容易迷失方向。
本系列记录我的学习旅程,并用浅显语言说明 TPU。我假设你熟悉基本 GCP 使用方式(gcloud 命令)以及 Kubernetes 的 Pod、Deployment、Node 与资源管理等概念。文中将使用以下术语。
术语
- TPU:Google 专为机器学习打造的 AI 加速芯片。TPU 擅长矩阵运算,因此在神经网络训练与推理上比 CPU 或 GPU 更快。TPU v6e(Trillium)是目前 GKE 的最新世代。TPU 以通过 ICI(Inter-Chip Interconnect)互连的“切片”组织,支持多芯片训练。
- GKE:Google Kubernetes Engine——Google Cloud 的托管 Kubernetes 服务,用于部署与扩展容器化应用。针对 TPU 工作负载,GKE 提供 TPU 节点池、拓扑感知,以及与 Kueue 等调度工具的整合。它在容错与配置简化方面表现良好。
- JAX:Google 开发的 Python 库,适用于高性能数值计算与机器学习。它结合 NumPy 的 API、自动微分与 XLA 编译,非常适合 TPU 工作负载。JAX 会自动将计算并行化到多个设备,特别适合多芯片 TPU 配置。
理解 TPU 拓扑(简化版)
容易混淆的点: TPU 拓扑听起来很复杂,但其实就是芯片如何排列。
可以这样理解:
- 一颗 芯片 = 一个处理单元(像一个大脑)
- 拓扑 = 芯片如何以行 × 列排列(像剧院座位)
- 2x2 = 2 行 × 2 列 = 共 4 颗芯片
常见 TPU v6e 配置:
| 拓扑 | TPU 芯片 | 主机 | VM | 机型 (GKE API) | 范围 |
|---|---|---|---|---|---|
| 1x1 | 1 | 1/8 | 1 | ct6e-standard-1t | 子主机 |
| 2x2 | 4 | 1/2 | 1 | ct6e-standard-4t | 子主机 |
| 2x4 | 8 | 1 | 1 | ct6e-standard-8t | 单主机 |
| 2x4 | 8 | 1 | 2 | ct6e-standard-4t | 单主机 |
先从基础开始。TPU 芯片是为机器学习工作负载设计的专用处理器。当你听到“2x2 拓扑”,代表 4 颗芯片按 2 行 × 2 列排列。实际上这些芯片未必真的呈现 2x2 方格,它们可能分布在同一主机内不同位置。重点是逻辑拓扑:TPU 运行时如何组织并寻址这 4 颗芯片以进行分布式计算。
下图为 Google 数据中心部署的 TPU v6e(Trillium)托盘。每个托盘包含 4 颗芯片,作为单个机器学习作业的计算单元(来源 1):

实际上,数据中心可以将多个托盘互连成更大的切片,以支持分布式训练工作负载。下图展示芯片与托盘的关系,每个蓝色方块代表一颗芯片,四颗芯片组成一个托盘:

了解芯片如何分配,有助于理解拓扑形态:
- 在 1x1 配置中,你只使用单一主机上的 8 颗芯片中的 1 颗。这是最小的 TPU 配置,适合开发、测试或轻量推理工作负载。
- 在 2x2 配置中,你只使用 8 颗芯片中的 4 颗(以蓝色标示)。这种子主机(sub-host)配置适合较小实验或开发工作,无需租用整台主机。
- 在 2x4 配置中,你会使用单一主机上的 全部 8 颗芯片(物理上可能跨两个托盘)。
下图展示 8 颗 TPU 芯片(蓝色方块)在不同拓扑下的排列方式。

在 GKE 中使用 TPU
1. 解开 TPU 迷思(以 v6e 为例)
在开始写代码之前,我们先了解 TPU 的租用方式。由于本文聚焦在 GKE 上使用 TPU,我将从 GKE 的视角说明 TPU 拓扑以及资源如何分配到节点。
为了说明,我以 GKE Standard 及其相关机型为例。
机型:ct6e-standard-4t
在 GKE 中,TPU 会挂载到虚拟机(节点)上。Trillium(v6e)最容易获得的入门机型是 ct6e-standard-4t。
- 1 个节点(VM) = 1、4 或 8 颗 TPU 芯片。每个 TPU v6e VM 可以包含 1、4 或 8 颗芯片。默认情况下,一个 VM 管理 4 颗芯片,除非你指定其他配置。
- “全有或全无”规则: 与 CPU 可以申请 0.5 核不同,在 GKE 上使用 TPU 时通常必须一次请求节点上的全部芯片。你无法只请求 1 颗。如果你租用 4 芯片机型,你的 Pod 必须请求 4 颗芯片。
支持 TPU 资源的 GKE 机型遵循命名规则:<version>-hightpu-<chip-count>t。芯片数量表示该 VM 可使用的 TPU 芯片数。例如:
ct6e-standard-4t表示 4 颗 TPU 芯片ct6e-standard-8t表示 8 颗 TPU 芯片tpu7x-standard-4t支持 Ironwood(TPU7x),包含 4 颗 TPU 芯片
完整的支持机型清单请见 Plan TPUs in GKE 2。
同一机型,不同拓扑的差异
先回到这张表:
| 拓扑 | TPU 芯片 | 主机 | VM | 机型 (GKE API) | 范围 |
|---|---|---|---|---|---|
| 2x2 | 4 | 1/2 | 1 | ct6e-standard-4t | 子主机 |
| 2x4 | 8 | 1 | 2 | ct6e-standard-4t | 单主机 |
你可能注意到上表中 ct6e-standard-4t 出现两次,对应不同拓扑(2x2 与 2x4)。这对新手很容易造成困惑,因此我们来拆解一下。
关键观念: 机型(ct6e-standard-4t)告诉你 VM 可以使用多少芯片,而拓扑描述这些芯片在分布式计算中的逻辑排列。
每个 VM 都能看到它所分配的芯片(通常是 4 颗),让你能够编写程序在这些芯片间协同计算。
示例 1:ct6e-standard-4t + 2x2 拓扑
- 你获得的资源: 1 台 VM,可使用 4 颗 TPU 芯片
- 物理现实: 这 4 颗芯片来自同一台主机(总共 8 颗),但你只租到了其中一半
- 逻辑排列: TPU 运行时将其视为 2×2 网格(2 行、2 列)
- 适用场景: 较小的训练作业
示例 2:ct6e-standard-4t + 2x4 拓扑(多 VM)
- 你获得的资源: 2 台 VM,每台可使用 4 颗 TPU 芯片(共 8 颗)
- 物理现实: 这 8 颗芯片来自同一台物理主机,但为了编排被拆成 2 台 VM
- 逻辑排列: TPU 运行时将其视为 2×4 网格(2 行、4 列)
- 适用场景: 更大的训练作业

在 2x2 配置中,你的程序可以与全部 4 颗芯片通信;在 2x4 多 VM 配置中,每台 VM 各自可访问 4 颗芯片,TPU 运行时会在两台 VM 之间协作,对外呈现为统一的 8 芯片系统。
为什么这很重要?
在 GKE 中,每台 VM 都是一个 GKE 节点。对于 2x4 多主机配置,你需要跨 2 个节点请求资源,而不是只在 1 个节点上。
下表显示不同拓扑对应的机型,以及工作负载需要的 GKE 节点数量。
| 拓扑 | 机型 | TPU 芯片总数 | TPU 主机 | 范围 | TPU VM | 需要的 GKE 节点数? |
|---|---|---|---|---|---|---|
| 1x1 | ct6e-standard-1t | 1 | 1/4 | 子主机 | 1 | 1 |
| 2x2 | ct6e-standard-4t | 4 | 1/2 | 子主机 | 1 | 1 |
| 2x4 | ct6e-standard-4t | 8 | 1 | 单主机 | 2 | 2 |
| 2x4 | ct6e-standard-8t | 8 | 1 | 单主机 | 1 | 1 |
理解这个映射关系对于配置节点池非常重要,因为你必须确保集群有足够的节点来承载所选拓扑。
2. 配置 GKE 集群
现在我们已经覆盖了 TPU 拓扑与机型的基础概念,接下来一步步完成实际部署。
我们使用 2x4 拓扑(8 颗芯片)运行一个小型训练作业。首先创建标准 GKE 集群,再为 TPU 添加专用节点池。
CLUSTER_NAME=tpu-quickstart # GKE 集群名称
CONTROL_PLANE_LOCATION=us-central1 # 请确保所用区域支持 v6e,详见: https://docs.cloud.google.com/tpu/docs/regions-zones
gcloud container clusters create $CLUSTER_NAME --location=$CONTROL_PLANE_LOCATION
3. 创建带 Flex-Start 的节点池
接下来添加一个专门配置 2x4 TPU 工作负载的节点池。此节点池将使用 Flex-Start 预配模型,确保在需求高峰时也能获得 TPU 容量。
NODE_POOL_NAME=nodepool # GKE 节点池名称
NODE_ZONES=us-central1-b # 请确保所用 zone 支持 v6e,详见: https://docs.cloud.google.com/tpu/docs/regions-zones
gcloud container node-pools create $NODE_POOL_NAME \
--cluster=$CLUSTER_NAME \
--location=$CONTROL_PLANE_LOCATION \
--node-locations=$NODE_ZONES \
--machine-type=ct6e-standard-4t \
--tpu-topology=2x4 \
--reservation-affinity=none \
--enable-autoscaling \
--num-nodes=0 --min-nodes=0 --max-nodes=2 \
--flex-start
--flex-start 标志非常关键。它会让 GKE 使用 Flex-Start 预配模型,将节点池创建请求排入队列,并在资源可用时进行配置,而不是因为容量不足而立即失败。
关键参数说明:
--tpu-topology=2x4:指定 TPU 芯片的逻辑排列(2×4 网格,共 8 颗芯片)--machine-type=ct6e-standard-4t:每个节点可使用 4 颗 TPU 芯片--enable-autoscaling搭配--num-nodes=0:起始为 0 个节点,只有在工作负载被调度时才扩容--max-nodes=2:2×4 拓扑需要 2 台 VM(节点),因此上限设为 2
注意:对 2x4 拓扑而言,--max-nodes=2 是必要条件。如果使用 ct6e-standard-4t 且 --max-nodes=1,将无法满足拓扑要求,只会获得 4 颗芯片,也就退化为 2x2 拓扑。
预配模型:Flex-Start
TPU 属于稀缺资源,当 Google 机器不足时,部署失败是常见痛点。Flex-Start 是专门用来缓解这个问题的模式。
| Feature | Standard On-Demand | Flex-Start |
|---|---|---|
| Cost | 标准价格 | 折扣价(通常约 6 折) |
| Availability | 立即(若有容量) | 成功率更高(进入队列) |
4. 运行批处理工作负载
现在集群与节点池都已配置完成,我们来部署一个简单的训练作业。
我们将创建一个 Job,把它调度到带 Flex-Start VM 的 TPU 节点。Kubernetes 的 Job 控制器会创建一个或多个 Pod,并确保其完成指定任务。
以下是示例 YAML,展示如何在 TPU 节点池上部署训练作业:
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-training-job
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-training-job
spec:
backoffLimit: 0
completions: 2
parallelism: 2
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x4
containers:
- name: tpu-job
image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
python -c 'import jax; print("TPU cores:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
当你使用 ct6e-standard-4t 创建 2x4 拓扑的节点池时,GKE 会预配 2 个独立节点(VM),每个节点包含 4 颗 TPU 芯片。
google.com/tpu: "4" 的资源请求告诉 GKE 将作业调度到拥有 4 颗 TPU 芯片的节点,以匹配 2x4 拓扑。配合 parallelism: 2,两个作业将并行运行——每个节点一个。当你应用该 manifest 时,如果节点尚未运行,GKE 的自动扩缩会从 Flex-Start 节点池自动预配所需节点。
Kubernetes 视角:节点调度
运行以下命令查看节点:
kubectl get nodes -o wide
你会看到类似输出:
NAME STATUS ROLES AGE VERSION
gke-cluster-tpu-pool-a1b2c3d4-node-1 Ready <none> 5m v1.28.3-gke.1234
gke-cluster-tpu-pool-a1b2c3d4-node-2 Ready <none> 5m v1.28.3-gke.1234
要查看这些节点上的 TPU 特定标签:
kubectl get nodes -l cloud.google.com/gke-tpu-topology=2x4 -o json | jq '.items[].metadata.labels'
输出会类似:
{
"cloud.google.com/gke-tpu-accelerator": "tpu-v6e-slice",
"cloud.google.com/gke-tpu-topology": "2x4",
"google.com/tpu": "4",
...
}
关键观察:
- 加速器标签:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice用于标识 TPU 世代。该标签与 GKE Autopilot 中的枚举配置一致,详见 Plan TPUs on GKE 3(Validate TPU availability in GKE > Autopilot)。 - 两个独立节点: 2x4 拓扑配合
ct6e-standard-4t会得到 2 个 Kubernetes 节点 - 每个节点 4 颗芯片:
google.com/tpu: "4"表示每个节点可用的 TPU 资源数 - 相同拓扑标签: 两个节点都带有
cloud.google.com/gke-tpu-topology: 2x4
将工作负载分布到两个节点
要完整使用 2x4 拓扑(8 颗芯片),Job 必须将 Pod 调度到两个节点,所以在 Job 中设置 parallelism: 2 与 completions: 2(每个节点一个 Pod)。
调度过程如下:
- Job 创建 2 个 Pod: 使用
completionMode: Indexed,为每个 Pod 分配唯一索引(0 和 1) - 每个 Pod 请求 4 颗芯片: 在资源请求中设置
google.com/tpu: "4" - NodeSelector 匹配拓扑:
cloud.google.com/gke-tpu-topology: 2x4确保 Pod 落到正确的节点池 - 调度器分配 Pod: Kubernetes 调度器将每个 Pod 放到不同节点,满足资源需求

你可以通过以下命令验证 Pod 的落点:
kubectl get pods -o wide
你应该会看到:
NAME READY STATUS NODE
tpu-training-job-0-abc 1/1 Running gke-cluster-tpu-pool-a1b2c3d4-node-1
tpu-training-job-1-def 1/1 Running gke-cluster-tpu-pool-a1b2c3d4-node-2
-0 与 -1 的后缀来自 Indexed 完成模式,每个 Pod 运行在不同节点上,让你可以使用 2x4 拓扑的全部 8 颗芯片。
5. 查看训练作业
部署作业后,你可以监控进度并确认 TPU 资源是否被正确使用。下面检查作业状态并查看日志,确认一切正常。
检查 Job 状态
首先确认 Job 已创建并运行:
kubectl get jobs
你会看到类似输出:
NAME COMPLETIONS DURATION AGE
tpu-training-job 0/2 5s 5s
COMPLETIONS 显示 0/2,表示 2 个 Pod 中有 0 个已完成。待两个 Pod 成功完成后会变为 2/2。
查看 Pod 日志
查看每个 Pod 的输出:
kubectl logs tpu-training-job-0-<pod-suffix>
kubectl logs tpu-training-job-1-<pod-suffix>
你应该会看到 JAX 检测到 TPU 核心的输出:
TPU cores: 4
每个 Pod 看到 4 个 TPU 核心,因为它们分别运行在拥有 4 颗芯片的节点上。合起来就是 2x4 拓扑的 8 颗芯片。
常见问题排查
如果 Pod 一直停在 Pending,请查看事件:
kubectl describe pod tpu-training-job-0-<pod-suffix>
常见问题包括:
- TPU 资源不足: Flex-Start 节点池仍在预配中,可能需要几分钟
- 节点选择器不匹配: 确认 nodeSelector 标签与节点实际标签一致
- 资源配额超限: 检查项目的 TPU 配额限制
进阶:在单一主机启动 8 颗芯片(单 VM)
如果你希望更简化的部署,或想让所有 TPU 芯片都在单一 VM 上,可以使用提供 8 颗芯片的更大机型来配置 2x4 拓扑。
不要使用会创建 2 台 VM(每台 4 颗芯片)的 ct6e-standard-4t,改用 ct6e-standard-8t 获取单一 VM 的 8 颗芯片:
# 在 us-central2-b 的区域集群 "my-cluster" 创建 TPU 节点池
gcloud container node-pools create tpu-pool \
--cluster=my-cluster \
--location=us-central2-b \
--node-locations=us-central2-b \
--machine-type=ct6e-standard-8t \
--tpu-topology=2x4 \
--reservation-affinity=none \
--enable-autoscaling \
--num-nodes=0 \
--min-nodes=0 \
--max-nodes=1 \
--flex-start
主要差异:
- 机型: 使用
ct6e-standard-8t,而非ct6e-standard-4t - 最大节点数:
--max-nodes=1,因为只需要一台 VM - 单一 Kubernetes 节点:
kubectl get nodes只会看到一个节点 - 可用 8 颗芯片: 单个 Pod 即可访问全部 8 颗芯片
将 Job manifest 改为只运行一个 Pod:
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-training-job-single-host
spec:
backoffLimit: 0
completions: 1
parallelism: 1
template:
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x4
containers:
- name: tpu-job
image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
securityContext:
privileged: true
command:
- bash
- -c
- |
python -c 'import jax; print("TPU cores:", jax.device_count())'
resources:
requests:
google.com/tpu: 8
limits:
google.com/tpu: 8
此配置更简单,原因包括:
- 只需管理一个 Pod,无需协调多个 Pod
- 不需要 headless Service 进行 Pod 间通信
- 更易排查问题,因为所有日志都集中在单个 Pod
不过,多 VM 方案(ct6e-standard-4t)在需要明确控制工作负载分布于不同 VM 时,能提供更灵活的分布式训练方式。
结论
本文完整走过在 GKE 上使用 TPU v6e 运行第一个 JAX 训练作业的流程:创建 GKE 集群、配置带 Flex-Start 的 TPU 节点池、部署训练作业,并理解 Kubernetes 如何在 TPU 节点之间调度工作负载。无论你选择多 VM 的分布式训练灵活性,还是单 VM 的简化部署,GKE 都能提供可扩展的基础设施,把 ML 工作负载从实验原型推向生产。
参考资料