GKE AI Series: First Training Job With JAX on TPUs

Training large machine learning models requires specialized hardware. Google’s Tensor Processing Units (TPUs) are among the best chips for this work. However, getting started with TPUs can be challenging—you need to understand chip layouts, configure node pools, and manage resources. It’s easy to get lost after browsing countless documents and troubleshooting errors for hours.

This series documents my learning journey and explains TPUs in plain language. I assume you’re familiar with basic GCP usage (the gcloud command) and Kubernetes concepts like Pods, Deployments, Nodes, and resource management. We’ll use the following terms throughout.

Terminology

  • TPU: Google’s specialized AI accelerator chip for machine learning. TPUs excel at matrix operations, making them faster than CPUs or GPUs for neural network training and inference. TPU v6e (Trillium) is one of the generations available in GKE. TPUs are organized into “slices” connected via ICI (Inter-Chip Interconnect) for multi-chip training.
  • GKE: Google Kubernetes Engine—Google Cloud’s managed Kubernetes service for deploying and scaling containerized applications. For TPU workloads, GKE offers TPU node pools, topology awareness, and integration with scheduling tools like Kueue. It’s recommended for its fault tolerance and simplified configuration.
  • JAX: A Python library developed by Google for high-performance numerical computing and machine learning. It combines NumPy’s API with automatic differentiation and XLA compilation, making it ideal for TPU workloads. JAX automatically parallelizes computations across multiple devices, essential for multi-chip TPU configurations.

Understanding TPU Topology (Simplified)

The confusion: TPU topology sounds complex, but it’s simply about how chips are arranged. We’ll use TPU v6e (Trillium) as our example—it’s easier to understand with its 2D layout and offers excellent performance for training workloads.

Think of it this way:

  • A chip = one processing unit (like a single brain)
  • A topology = how chips are arranged in rows × columns (like seats in a theater)
  • 2x2 = 2 rows × 2 columns = 4 chips total
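The row × column arithmetic is all there is to it. A minimal sketch (the helper name is mine, not a GKE or JAX API) that multiplies out a topology string:

```python
def chips_in_topology(topology: str) -> int:
    """Multiply the dimensions of a topology string like '2x4'."""
    rows, cols = (int(d) for d in topology.split("x"))
    return rows * cols

print(chips_in_topology("2x2"))  # 4
print(chips_in_topology("2x4"))  # 8
```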

Common TPU v6e configurations:

| Topology | TPU chips | Hosts | VMs | Machine type (GKE API) | Scope |
|----------|-----------|-------|-----|------------------------|-------|
| 1x1      | 1         | 1/8   | 1   | ct6e-standard-1t       | Sub-host |
| 2x2      | 4         | 1/2   | 1   | ct6e-standard-4t       | Sub-host |
| 2x4      | 8         | 1     | 1   | ct6e-standard-8t       | Single-host |
| 2x4      | 8         | 1     | 2   | ct6e-standard-4t       | Single-host |

Let’s start with the basics. A TPU chip is a specialized processor designed for machine learning workloads. When you hear “2x2 topology”, that simply means 4 chips arranged in a 2-row by 2-column grid. Physically, these chips may not form an actual 2x2 grid—they could be distributed across different locations within the host machine. What matters is the logical topology: how the TPU runtime organizes and addresses these 4 chips for distributed computation.

Below is a tray for TPU v6e (Trillium) deployed in Google datacenters. Each tray contains 4 chips and functions as an individual compute unit for machine learning jobs (source 1):

In reality, multiple trays can be connected in the datacenter to form larger slices for distributed training workloads. The diagram below shows chips and trays, where each blue block represents a chip. Four chips form one tray:

Knowing how these chips are partitioned helps you understand what each topology means:

  • In a 1x1 configuration, you’re using only 1 out of 8 chips on a single host. This is the smallest TPU configuration, ideal for development, testing, or running lightweight inference workloads.
  • In a 2x2 configuration, you’re using only 4 out of 8 chips (highlighted in blue). This sub-host setup is perfect for smaller experiments or development work without renting the full machine.
  • In a 2x4 configuration, you’re using all 8 chips on a single host machine (this may physically span two trays).

The following diagram shows how 8 TPU chips (represented as blue blocks) are arranged in different topologies.
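Since a v6e host carries 8 chips, the sub-host vs single-host split falls out of the topology alone. A sketch of that arithmetic (the helper name is mine, not a GKE API):

```python
from fractions import Fraction

V6E_CHIPS_PER_HOST = 8  # a TPU v6e host machine carries 8 chips

def host_share(topology: str) -> Fraction:
    """Fraction of one v6e host that a topology occupies."""
    rows, cols = (int(d) for d in topology.split("x"))
    return Fraction(rows * cols, V6E_CHIPS_PER_HOST)

for t in ("1x1", "2x2", "2x4"):
    share = host_share(t)
    scope = "Single-host" if share == 1 else "Sub-host"
    print(f"{t}: {share} of a host -> {scope}")
```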

How to use TPU in GKE

1. De-mystifying TPU (Uses v6e as example)

Before we touch any code, let’s understand how TPU rental actually works. Since this post focuses on using TPUs with GKE, I’ll start from the GKE perspective to explain TPU topology and how resources are allocated across nodes.

To understand how this works, I’ll use GKE Standard and its related machine types as examples.

The Machine: ct6e-standard-4t

In GKE, TPUs are attached to Virtual Machines (Nodes). The most accessible entry point for Trillium (v6e) is the ct6e-standard-4t machine type.

  • 1 Node (VM) = 1, 4, or 8 TPU chips. The machine type determines how many chips each TPU v6e VM manages; ct6e-standard-4t, the common choice, gives each VM 4 chips.
  • The “All-or-Nothing” Rule: Unlike CPUs where you can request 0.5 cores, with TPUs on GKE, you generally must request all the chips on the node. You cannot request just 1 chip. If you rent the 4-chip machine, your Pod must request 4 chips.

GKE machine types that support TPU resources embed the chip count in their name. For v6e the pattern is ct6e-standard-<chip-count>t (earlier TPU generations use <version>-hightpu-<chip-count>t). The chip count indicates how many TPU chips are available to that VM. For example:

  • ct6e-standard-4t means 4 TPU chips
  • ct6e-standard-8t means 8 TPU chips
  • tpu7x-standard-4t supports Ironwood (TPU7x) and contains four TPU chips

For a complete list of supported machine types in GKE, see Plan TPUs in GKE 2.

Understanding the Difference: Same Machine Type, Different Topologies

Let’s revisit the table:

| Topology | TPU chips | Hosts | VMs | Machine type (GKE API) | Scope |
|----------|-----------|-------|-----|------------------------|-------|
| 2x2      | 4         | 1/2   | 1   | ct6e-standard-4t       | Sub-host |
| 2x4      | 8         | 1     | 2   | ct6e-standard-4t       | Single-host |

You might notice in the table above that ct6e-standard-4t appears twice with different topologies (2x2 and 2x4). This can be confusing for beginners, so let’s break it down.

The Key Insight: The machine type (ct6e-standard-4t) tells you how many chips your VM has access to, but the topology tells you how those chips are logically arranged for distributed computation.

Each VM sees its allocated chips (typically 4), allowing you to write programs that access and coordinate across all chips within that allocation.

Example 1: ct6e-standard-4t with 2x2 Topology

  • What you get: 1 VM with access to 4 TPU chips
  • Physical reality: These 4 chips are part of a single host machine (which has 8 chips total), but you’re only renting half of it
  • Logical arrangement: The TPU runtime treats these as a 2×2 grid (2 rows, 2 columns)
  • Use case: Smaller training jobs

Example 2: ct6e-standard-4t with 2x4 Topology (Multi-VM)

  • What you get: 2 VMs, each with access to 4 TPU chips (8 chips total)
  • Physical reality: These 8 chips come from a single physical host machine, but they’re exposed as 2 separate VMs for orchestration purposes
  • Logical arrangement: The TPU runtime treats these as a 2×4 grid (2 rows, 4 columns)
  • Use case: Larger training jobs

In a 2x2 configuration, your VM’s program can communicate with all 4 chips. In a 2x4 multi-VM setup, each VM can access its own 4 chips, and the TPU runtime handles coordination between the two VMs to present them as a unified 8-chip system.

Why does this matter?

In GKE, each VM is a separate GKE node. For the 2x4 setup on ct6e-standard-4t (two VMs), you need to request resources across 2 nodes, not just 1.

This table shows how different topologies map to machine types and the number of GKE nodes your workload needs.

| Topology | Machine type     | Total TPU chips | TPU hosts | Scope       | TPU VMs | GKE nodes needed |
|----------|------------------|-----------------|-----------|-------------|---------|------------------|
| 1x1      | ct6e-standard-1t | 1               | 1/8       | Sub-host    | 1       | 1 |
| 2x2      | ct6e-standard-4t | 4               | 1/2       | Sub-host    | 1       | 1 |
| 2x4      | ct6e-standard-4t | 8               | 1         | Single-host | 2       | 2 |
| 2x4      | ct6e-standard-8t | 8               | 1         | Single-host | 1       | 1 |

Understanding this mapping is essential when configuring your node pools, as you’ll need to ensure your cluster has enough nodes to accommodate your chosen topology.
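The mapping above reduces to simple division: total chips in the topology divided by chips per VM, where the chips per VM is the trailing number in the machine-type name. A sketch with hypothetical helper names:

```python
def chips_per_vm(machine_type: str) -> int:
    """Parse the chip count from a name like 'ct6e-standard-4t'."""
    return int(machine_type.rsplit("-", 1)[-1].rstrip("t"))

def gke_nodes_needed(topology: str, machine_type: str) -> int:
    """One GKE node per TPU VM required to cover the topology."""
    rows, cols = (int(d) for d in topology.split("x"))
    total_chips = rows * cols
    per_vm = chips_per_vm(machine_type)
    # The all-or-nothing rule: the topology must tile exactly onto whole VMs.
    assert total_chips % per_vm == 0, "topology does not fit this machine type"
    return total_chips // per_vm

print(gke_nodes_needed("2x2", "ct6e-standard-4t"))  # 1
print(gke_nodes_needed("2x4", "ct6e-standard-4t"))  # 2
print(gke_nodes_needed("2x4", "ct6e-standard-8t"))  # 1
```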

2. Setting Up the GKE Cluster

Now that we’ve covered the fundamentals of TPU topology and machine types, let’s walk through the actual deployment process step by step.

We’ll use a 2x4 topology (8 chips) to run a small training job. First, we’ll create a standard GKE cluster and add a specialized node pool for our TPUs.

CLUSTER_NAME=tpu-quickstart         # Name of the GKE cluster
CONTROL_PLANE_LOCATION=us-central1  # Use a region where v6e is available, see: https://docs.cloud.google.com/tpu/docs/regions-zones

gcloud container clusters create $CLUSTER_NAME --location=$CONTROL_PLANE_LOCATION 

3. Create a node pool with flex-start

Next, we’ll add a node pool specifically configured for a 2x4 TPU workload. This node pool uses the Flex-Start provisioning model so we can secure TPU capacity even when demand is high.

NODE_POOL_NAME=nodepool   # Name of the GKE node pool
NODE_ZONES=us-central1-b  # Use a zone where v6e is available, see: https://docs.cloud.google.com/tpu/docs/regions-zones

gcloud container node-pools create $NODE_POOL_NAME \
    --cluster=$CLUSTER_NAME \
    --location=$CONTROL_PLANE_LOCATION \
    --node-locations=$NODE_ZONES \
    --machine-type=ct6e-standard-4t \
    --tpu-topology=2x4 \
    --reservation-affinity=none \
    --enable-autoscaling \
    --num-nodes=0 --min-nodes=0 --max-nodes=2 \
    --flex-start

The --flex-start flag is crucial here. It tells GKE to use the Flex-Start provisioning model, which queues your node pool creation request and provisions resources when available, rather than failing immediately if capacity is unavailable.

Let’s break down the key parameters:

  • --tpu-topology=2x4: Specifies the logical arrangement of TPU chips (8 chips total in a 2×4 grid)
  • --machine-type=ct6e-standard-4t: Each node will have access to 4 TPU chips
  • --enable-autoscaling with --num-nodes=0: Starts with zero nodes and scales up only when workloads are scheduled
  • --max-nodes=2: For a 2×4 topology, you need 2 VMs (nodes), so set max-nodes accordingly

Note: --max-nodes=2 is required for a 2x4 topology to acquire 8 chips. If you use --max-nodes=1 with ct6e-standard-4t, the node pool cannot meet the topology requirement and only acquires 4 chips, which is effectively a 2x2 topology.
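You can sanity-check a node-pool configuration against a target topology with the same chip arithmetic. A sketch (the helper name is mine, not part of any GKE tooling):

```python
def pool_satisfies_topology(topology: str, machine_type: str, max_nodes: int) -> bool:
    """True if max_nodes VMs of this machine type can supply the topology's chips."""
    rows, cols = (int(d) for d in topology.split("x"))
    needed_chips = rows * cols
    per_vm = int(machine_type.rsplit("-", 1)[-1].rstrip("t"))  # 'ct6e-standard-4t' -> 4
    return max_nodes * per_vm >= needed_chips

print(pool_satisfies_topology("2x4", "ct6e-standard-4t", max_nodes=2))  # True
print(pool_satisfies_topology("2x4", "ct6e-standard-4t", max_nodes=1))  # False: only 4 chips
```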

The Provisioning Model: Flex-Start

Because TPUs are a limited resource, deployment failures due to unavailable capacity (when Google runs out of machines) are a common pain point. Flex-Start is a specialized provisioning mode that helps avoid this.

| Feature      | Standard On-Demand       | Flex-Start |
|--------------|--------------------------|------------|
| Cost         | Standard price           | Discounted (typically ~40% off) |
| Availability | Immediate (if available) | Higher probability (you join a queue) |

4. Run a batch workload

Now that we have our GKE cluster and node pool configured, let’s deploy a simple training workload.

We’ll create a Job whose Pods are scheduled onto the flex-start TPU nodes. A Job controller in Kubernetes creates one or more Pods and ensures they successfully complete a specific task.

Here’s a simple YAML manifest that demonstrates how to deploy a training job on our TPU node pool:

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-training-job
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-training-job
spec:
  backoffLimit: 0
  completions: 2
  parallelism: 2
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
          cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: 2x4
      containers:
      - name: tpu-job
        image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

When you create a node pool with a 2x4 topology using ct6e-standard-4t, GKE provisions 2 separate nodes (VMs), each with 4 TPU chips.

The google.com/tpu: "4" resource request tells GKE to schedule the job on nodes with 4 TPU chips, matching our 2x4 topology. With parallelism: 2, two jobs run in parallel—one on each node. When you apply this manifest, GKE’s autoscaler automatically provisions the required nodes from your flex-start node pool if they aren’t already running.

Understanding the Kubernetes View: Node Scheduling

Let’s see how this looks from Kubernetes’ perspective. Run the following command to view your nodes:

kubectl get nodes -o wide

You’ll see output similar to this:

NAME                                      STATUS   ROLES    AGE   VERSION
gke-cluster-tpu-pool-a1b2c3d4-node-1      Ready    <none>   5m    v1.28.3-gke.1234
gke-cluster-tpu-pool-a1b2c3d4-node-2      Ready    <none>   5m    v1.28.3-gke.1234

To see the TPU-specific labels on these nodes, use:

kubectl get nodes -l cloud.google.com/gke-tpu-topology=2x4 -o json | jq '.items[].metadata.labels'

This will show labels like:

{
  "cloud.google.com/gke-tpu-accelerator": "tpu-v6e-slice",
  "cloud.google.com/gke-tpu-topology": "2x4",
  "google.com/tpu": "4",
  ...
}

Key observations:

  • Accelerator label: cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice identifies the TPU hardware generation. This label takes the same enum values as in GKE Autopilot, documented in Plan TPUs on GKE 3 (Validate TPU availability in GKE > Autopilot).
  • Two separate nodes: For a 2x4 topology with ct6e-standard-4t, you get 2 Kubernetes nodes
  • Each node has 4 chips: The google.com/tpu: "4" label indicates the available TPU resources per node
  • Same topology label: Both nodes share the cloud.google.com/gke-tpu-topology: 2x4 label

Scheduling Your Workload Across Two Nodes

To utilize the full 2x4 topology (8 chips), your Job must schedule Pods on both nodes. This is why we set parallelism: 2 and completions: 2 in the Job manifest: one Pod per node.

Here’s how the scheduling works:

  1. Job creates 2 Pods: The Job controller creates 2 Pods with completionMode: Indexed, assigning each a unique index (0 and 1)
  2. Each Pod requests 4 chips: google.com/tpu: "4" in the resource request
  3. NodeSelector matches topology: The cloud.google.com/gke-tpu-topology: 2x4 nodeSelector ensures Pods land on the correct node pool
  4. Scheduler distributes Pods: Kubernetes scheduler places one Pod on each node, satisfying the resource requirements
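Inside each Pod, the Indexed completion mode exposes the Pod’s index as the JOB_COMPLETION_INDEX environment variable, and indexed Pods get the hostname <job-name>-<index>, which the headless Service makes resolvable as <job-name>-<index>.<service>. Multi-host JAX setups typically use this to agree on a coordinator (by convention, index 0, e.g. for jax.distributed.initialize). A sketch of that derivation, simulating the injected environment; the port is the TPU VM communication port from the manifest:

```python
import os

# In a real Pod, the Indexed Job controller injects this; set it here to simulate Pod 1.
os.environ["JOB_COMPLETION_INDEX"] = "1"

JOB_NAME = "tpu-training-job"  # matches the Job name in the manifest above
SERVICE = "headless-svc"       # matches the headless Service name

index = int(os.environ["JOB_COMPLETION_INDEX"])

# Stable per-Pod DNS name via the Pod's `subdomain` and the headless Service.
my_hostname = f"{JOB_NAME}-{index}.{SERVICE}"
# Convention: the Pod with index 0 acts as the coordinator for the slice.
coordinator = f"{JOB_NAME}-0.{SERVICE}:8471"  # 8471: TPU VM communication port

print(my_hostname)   # tpu-training-job-1.headless-svc
print(coordinator)   # tpu-training-job-0.headless-svc:8471
```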

You can verify the Pod placement with:

kubectl get pods -o wide

You should see:

NAME                      READY   STATUS    NODE
tpu-training-job-0-abc    1/1     Running   gke-cluster-tpu-pool-a1b2c3d4-node-1
tpu-training-job-1-def    1/1     Running   gke-cluster-tpu-pool-a1b2c3d4-node-2

The -0 and -1 suffixes come from the Indexed completion mode, and each Pod runs on a different node, giving you access to all 8 chips in your 2x4 topology.

5. Review Training Job

After deploying the job, you can monitor its progress and verify that the TPU resources are being utilized correctly. Let’s check the job status and examine the logs to confirm everything is working as expected.

Check Job Status

First, verify that your job has been created and is running:

kubectl get jobs

You should see output like:

NAME                COMPLETIONS   DURATION   AGE
tpu-training-job    0/2           5s         5s

The COMPLETIONS field shows 0/2, meaning 0 out of 2 pods have completed. Once both pods finish successfully, it will show 2/2.

View Pod Logs

To see the output from your training job, check the logs of each pod:

kubectl logs tpu-training-job-0-<pod-suffix>
kubectl logs tpu-training-job-1-<pod-suffix>

You should see output confirming that JAX has detected the TPU cores:

TPU cores: 4

Each pod sees 4 TPU cores because each runs on a node with 4 chips. Together, they form the complete 2x4 topology with 8 total chips.
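In JAX terms, each Pod’s view is jax.local_device_count() (the chips on its own VM), while jax.device_count() reports the whole slice once the two processes are connected (e.g. via jax.distributed.initialize). The arithmetic as a plain-Python sketch, with numbers taken from this 2x4 setup:

```python
# Per-Pod view vs slice-wide view for the 2x4 topology on ct6e-standard-4t.
local_devices_per_pod = 4   # what jax.local_device_count() reports on each node
num_pods = 2                # parallelism in the Job manifest, one Pod per node
global_devices = local_devices_per_pod * num_pods

print(f"local: {local_devices_per_pod}, global: {global_devices}")  # local: 4, global: 8
```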

Troubleshooting Common Issues

If your pods remain in Pending state, check the events:

kubectl describe pod tpu-training-job-0-<pod-suffix>

Common issues include:

  • Insufficient TPU resources: The flex-start node pool is still provisioning nodes. This can take several minutes.
  • Node selector mismatch: Verify that your nodeSelector labels match the actual node labels in your node pool.
  • Resource quota exceeded: Check your project’s TPU quota limits.

Bonus: Launch 8 Chips in a Single Host (Single VM)

For simpler deployments or when you want all TPU chips on a single VM, you can configure a 2x4 topology using a larger machine type that provides 8 chips in one node.

Instead of using ct6e-standard-4t which creates 2 separate VMs with 4 chips each, use ct6e-standard-8t to get all 8 chips on a single VM:

# Create TPU node pool for zonal cluster "my-cluster" in us-central2-b
gcloud container node-pools create tpu-pool \
    --cluster=my-cluster \
    --location=us-central2-b \
    --node-locations=us-central2-b \
    --machine-type=ct6e-standard-8t \
    --tpu-topology=2x4 \
    --reservation-affinity=none \
    --enable-autoscaling \
    --num-nodes=0 \
    --min-nodes=0 \
    --max-nodes=1 \
    --flex-start

Key differences:

  • Machine type: ct6e-standard-8t instead of ct6e-standard-4t
  • Max nodes: --max-nodes=1 since you only need one VM
  • Single Kubernetes node: You’ll see only one node in kubectl get nodes
  • All 8 chips accessible: Your workload can access all 8 chips from a single Pod

Update your Job manifest to run a single Pod instead of two:

apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-training-job-single-host
spec:
  backoffLimit: 0
  completions: 1
  parallelism: 1
  template:
    spec:
      restartPolicy: Never
      nodeSelector:
          cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: 2x4
      containers:
      - name: tpu-job
        image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 8
          limits:
            google.com/tpu: 8

This configuration is simpler because:

  • You only need to manage one Pod instead of coordinating across multiple Pods
  • No need for a headless Service for inter-Pod communication
  • Easier debugging since all logs are in a single Pod

However, the multi-VM approach (ct6e-standard-4t) offers better flexibility for distributed training patterns when you need explicit control over workload distribution across VMs.

Conclusion

In this post, we walked through the complete process of running your first JAX training job on GKE with TPU v6e. We covered creating a GKE cluster, configuring a flex-start node pool with TPU topology, deploying a training job, and understanding how Kubernetes schedules workloads across TPU nodes. Whether you choose a multi-VM setup for distributed training flexibility or a single-VM configuration for simplicity, GKE provides the infrastructure to scale your ML workloads from experimental prototypes to production deployments.

References

Eason Cao
Eason is an engineer working at a FAANG company and living in Europe. He was accredited as an AWS Professional Solutions Architect, AWS Professional DevOps Engineer, and CNCF Certified Kubernetes Administrator. He started his Kubernetes journey in 2017 and enjoys solving real-world business problems.