Practical TPU VM on GCP

Authors
  • Acknowledgement : Google Cloud credits are provided for this project :: #AISprint

Setting up a Practical TPU VM on GCP

Kaggle competitions very generously offer the use of TPUv5-8 instances. However, the pool of available TPUs is apparently rather limited ('a few dozen' according to Twitter), which leads to unpredictable waiting times in a queue for allocation. Moreover, while the maximum run-time per session is 9 hours (seems fine), the maximum run-time per week is 20 hours - which is a little limiting.

So: to make the development cycle more predictable, I wanted to set up a workflow on GCP that would be directly transferable to Kaggle, but that would also:

  • Create instances faster
  • Build the instance cleanly from code
  • Manage code via a GitHub repo
  • Run on both TPUv5-1 (for code sanity testing) and TPUv5-8 (like Kaggle itself)
  • Be 'understandable' instead of being a black box

This is a quick guide to the process I came up with. The actual code/scripts are also available in my 'Aha' repo.

Finally, TPUv5-1 and TPUv5-8 can each be driven by a single CPU Host machine. Going bigger than -8 requires a more complex set-up, and I'm not covering that here.

Lesson 1 : Start with a Clean Project

Something was deeply borked with my existing 'kitchen sink' GCP project - with all kinds of indecipherable errors.

SOLUTION: Just start a new project on GCP (connected to the same billing account). The quotas will be immediately fine for experimentation, and whatever was broken in the old project is silently left behind.

Lesson 2 : TPUs come with temporary disks

Typically, my go-to solution for a utility machine would involve mounting an extra drive with all the messy details on it. But the unavailability of TPUs in different regions meant that this would have been a forever moving target. Moreover, the TPU documentation seems to suggest that 'building from scratch, and disposing of' was also a sensible plan.

So my set-up involves building the VM from scratch each time - i.e. it is fully self-contained. There's a lot of space on the TPU's host machine (>88GB of HD, >380GB of RAM), so it has plenty of capacity...

Lesson 3 : GCP documentation is very scattered

To start working on your local machine, install gcloud so that you can run :

gcloud auth login
gcloud config set project <PROJECT_ID>
gcloud services enable tpu.googleapis.com

These commands assume a Blank Project (I only discovered that this was one of the unspoken rules of Google Cloud after wasting a lot of time); the last one gives the project access to the "Cloud TPU API" (GCP documentation).

To check whether you have quotas for TPUs, look in your GCP console for 'Quotas' (under IAM & Admin / Quotas), and search for "tpu" "v5". With a blank project, I found I had quota for 32 TPUv5e everywhere (YMMV).

Then (following the GCP documentation), you need to :

  • Add essential IAM Roles to <YOUR_GCP_USER> :
    • Your account needs specific permissions to create and manage TPUs. The most critical role to add is TPU Admin.
      • Service Account Admin: Needed to create a service account
      • Project IAM Admin: Needed to grant a role in a project
      • TPU Admin: Needed to create a TPU
  • Create a Cloud TPU service agent:
    • gcloud beta services identity create --service tpu.googleapis.com
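
The three roles listed above can also be granted from the command line. The following is a minimal sketch, assuming the role names map onto the standard role IDs roles/tpu.admin, roles/iam.serviceAccountAdmin and roles/resourcemanager.projectIamAdmin. The function name print_iam_grants is hypothetical, and it only prints the gcloud commands so that they can be checked before anything is changed:

```shell
# Hypothetical helper: prints (does not run) one gcloud command per role.
print_iam_grants() {
  project_id="$1"   # e.g. the <PROJECT_ID> used above
  user_email="$2"   # the Google account that will create the TPUs
  for role in roles/tpu.admin \
              roles/iam.serviceAccountAdmin \
              roles/resourcemanager.projectIamAdmin; do
    echo "gcloud projects add-iam-policy-binding ${project_id}" \
         "--member=user:${user_email} --role=${role}"
  done
}
```

Running print_iam_grants with your project ID and email address shows the three commands; once they look right, they can be pasted into the terminal (or piped to sh).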

The other stuff (extra service accounts, etc) was not required to get this working initially. Perhaps its importance will be revealed later, but doing those steps was not required to actually get my hands on a TPU instance.

Set up local variables:

export TPU_NAME="kaggle-tpu-paid"
export TPU_TYPE="v5litepod-8"  #  Same size as Kaggle
export TPU_TYPE="v5litepod-1"  #  Lower cost while testing out startup scripts (this later line wins : comment it out to get the -8)
export TPU_SOFTWARE="v2-alpha-tpuv5-lite"

# There are lots of zones - please find your own, and don't starve this one...
export TPU_ZONE="us-west1-c"        # $1.20 per hour (Oregon) : WORKED!
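
As a rough budgeting aid, the cost of a session is just the hourly rate multiplied by the hours used. The sketch below takes the $1.20 per hour figure from the comment above at face value (actual pricing depends on the TPU type and zone); the function name tpu_cost is hypothetical:

```shell
# Hypothetical helper: estimated cost = hourly rate x hours.
tpu_cost() {
  rate_per_hour="$1"
  hours="$2"
  awk -v r="$rate_per_hour" -v h="$hours" 'BEGIN { printf "%.2f\n", r * h }'
}
```

For example, tpu_cost 1.20 20 prints 24.00, which is what matching Kaggle's 20 hours per week would cost at that rate.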

Check what TPUs are available in that zone:

gcloud compute tpus accelerator-types list --zone=${TPU_ZONE}
gcloud compute tpus tpu-vm accelerator-types describe ${TPU_TYPE} --zone=${TPU_ZONE}

Lesson 4 : TPUs come with outdated software

The startup-script is key

While TPUs are bootable, they require updates in order to get recent versions of JAX, and other frameworks. In particular, their Python 3.10 needs updating. That means that the startup-script is going to do a lot of heavy lifting - it's listed below.

Several things to note about this startup-script, it:

  • does not upgrade system Python, since this is a big headache
    • it's easier to do it per virtual environment
  • installs uv via snap (rather than a cursed curl | sh)
  • runs mostly as a newly created tpu_user
  • provides/runs Jupyter Lab which is accessed via ssh tunnel (see Lesson 8)
  • installs in stages (which uv makes really fast)
    • When you actually log in, you can see logs of installation progress in *pip*.log files
  • has careful escaping of different variables in the tpu_user EOF here-document

#!/bin/bash

# NB: Not using actual GCP username here : 
TPU_USER=tpu_user

JUPYTER_USER=${TPU_USER}
JUPYTER_PORT=8585


# Ensure the user exists (create it if it is missing)
if ! id ${TPU_USER} &>/dev/null; then
  useradd -m ${TPU_USER}
fi

# Nicer way of getting 'uv' than curl|sh ...
snap install astral-uv --classic

# Switch to 'tpu_user' and execute commands
sudo -u ${TPU_USER} bash << EOF
  cd /home/${TPU_USER}

  # https://docs.astral.sh/uv/guides/install-python/
  uv venv --python 3.12 ./env-tpu  # Auto-installs correct python version

  source ./env-tpu/bin/activate
  uv pip freeze | sort > 0-pip-freeze.log  # Empty : a fresh venv has nothing installed

  # Install Jupyter and necessary packages
  uv pip install jupyter jupyterlab jupytext "tqdm[notebook]" ipywidgets
  uv pip freeze | sort > 1-pip-freeze_with_jupyter.log

  uv pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  # Resolves to jax 0.6.2 under Python 3.10 (== BAD CHOICE)
  # Resolves to jax 0.8.1 under Python 3.12
  uv pip freeze | sort > 2-pip-freeze_with_jax.log

  uv pip install git+https://github.com/google/flax.git
  uv pip freeze | sort > 3-pip-freeze_with_flax.log

  # NB: this install also gives us kagglehub and dotenv!
  uv pip install git+https://github.com/google/tunix git+https://github.com/google/qwix
  uv pip freeze | sort > 4-pip-freeze_with_tunix-qwix.log

  uv pip install seaborn pandas 
  uv pip freeze | sort > 5-pip-freeze_with_misc.log


  # https://docs.cloud.google.com/compute/docs/instances/startup-scripts/linux#accessing-metadata
  JUPYTER_TOKEN=\$(curl http://metadata.google.internal/computeMetadata/v1/instance/attributes/JUPYTER_TOKEN -H "Metadata-Flavor: Google")

  # curl http://metadata.google.internal/computeMetadata/v1/instance/ -H "Metadata-Flavor: Google"
  #   gives us back a bunch of information in (apparently) a nice structure
  # eg: .../instance/machine-type -> "projects/714NNNNNNN0/machineTypes/n2d-48-24-v5lite-tpu"

  # Start JupyterLab server in the background as the user
  nohup jupyter lab --no-browser --ip=0.0.0.0 --port=${JUPYTER_PORT} --ServerApp.token=\${JUPYTER_TOKEN} --allow-root &

EOF

# Does nothing ( just here as a placeholder / comment )
cat << EOF
  
EOF
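
The escaping in the here-document above can be seen in isolation. This is a small sketch with made-up variable names (OUTER_VALUE, INNER_VALUE), not part of the startup-script: because the EOF delimiter is unquoted, a plain ${...} is filled in by the outer shell before the inner bash runs, whereas \${...} is passed through for the inner bash to evaluate (just as the script does with \$(curl ...) and \${JUPYTER_TOKEN}):

```shell
#!/bin/bash
OUTER_VALUE="from-outer-shell"

heredoc_demo() {
  bash << EOF
  INNER_VALUE="from-inner-shell"
  echo "outer: ${OUTER_VALUE}"    # filled in by the outer shell
  echo "inner: \${INNER_VALUE}"   # backslash defers this to the inner bash
EOF
}
```

Calling heredoc_demo prints "outer: from-outer-shell" followed by "inner: from-inner-shell". Without the backslash, the second echo would be expanded too early, using whatever (probably empty) value the outer shell has for INNER_VALUE.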

Lesson 5 : Launch with an alert when it's done...

To launch the TPU instance, using the startup-script above (saved locally as startup_script.sh), run :

TPU_SECRETS="foo=barbar,JUPYTER_TOKEN=''" # security-via-GCP-auth
gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone=${TPU_ZONE} \
  --accelerator-type=${TPU_TYPE} \
  --version=${TPU_SOFTWARE} \
  --metadata-from-file=startup-script=startup_script.sh \
  --metadata=${TPU_SECRETS} \
  && ./bell_tpu service-login

Note that bell_tpu is a simple local script that plays a system sound to let you know that the TPU is now online. Startup can take several minutes, and Google will happily start charging you whether or not you notice, so the alternative is watching the terminal like a hawk.
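
The bell_tpu script itself isn't listed in this post. As a minimal sketch of what such a script might contain, the following assumes a Linux desktop with the freedesktop sound theme in /usr/share/sounds/freedesktop/stereo and the paplay player. Both of those, and the idea that the argument is simply a sound name, are my assumptions here rather than a copy of the real script:

```shell
#!/bin/bash
# Hypothetical stand-in for ./bell_tpu : plays a named sound, or rings the terminal bell.
SOUND_DIR="${SOUND_DIR:-/usr/share/sounds/freedesktop/stereo}"  # assumed location

sound_file_for() {
  # Map an event name (default: bell) to a sound file path.
  echo "${SOUND_DIR}/${1:-bell}.oga"
}

ring() {
  file="$(sound_file_for "$1")"
  if command -v paplay >/dev/null 2>&1 && [ -f "$file" ]; then
    paplay "$file"
  else
    printf '\a'   # fallback: plain terminal bell
  fi
}
```

Saved as bell_tpu with a final line of ring "$1", this would make ./bell_tpu service-login play the matching sound if it exists, and fall back to the terminal bell otherwise.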

Lesson 6 : DESTROYING THE INSTANCE

Google examples don't highlight this...

  • MAKE SURE TO DO THIS ONCE YOU HAVE FINISHED!
    • keeping this close to the tpu-vm create command just to be sure it isn't missed

This must be done on the local machine when you're finished, to avoid GCP silently charging you for the instance...

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --zone=${TPU_ZONE} \
   --quiet \
   && ./bell_tpu service-logout \
   && gcloud compute tpus tpu-vm list --zone=${TPU_ZONE}
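
Since a forgotten instance keeps costing money, it can be worth making the final list step shout if anything is left. A minimal sketch, assuming the TPU names arrive one per line on stdin; the function name warn_if_tpus_remain is hypothetical:

```shell
# Hypothetical helper: reads TPU names (one per line) on stdin and
# returns non-zero, listing them, if any remain.
warn_if_tpus_remain() {
  remaining="$(grep -v '^[[:space:]]*$' || true)"
  if [ -n "$remaining" ]; then
    echo "WARNING: these TPUs still exist (and may still be billing):"
    echo "$remaining"
    return 1
  fi
  echo "No TPUs left in this zone."
}
```

It could be fed from gcloud compute tpus tpu-vm list --zone=${TPU_ZONE} --format="value(name)" via a pipe, so that an empty listing gives a clear all-clear message.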

Lesson 7 : Running git requires ssh keys

(and other auth stuff)

The following should be run locally. It creates a dedicated tpu_user key pair, so that the key.pub can be given to GitHub (for instance) just once, and your own local key information isn't being sent up to GCP.

The key pair is then pushed up to the instance, along with a local configuration file (copied securely via scp), which can be loaded by dotenv in the Jupyter notebook to get that configuration done quickly.

mkdir -p ./tpu_ssh/
TPU_KEY_PATH="./tpu_ssh/id_ed25519"
if [ ! -f ${TPU_KEY_PATH} ]; then  # Do this only once for stability
  ssh-keygen -t ed25519 -f ${TPU_KEY_PATH} -C "TPU-machine" -N ""
  echo "Created new ssh keys : Upload ${TPU_KEY_PATH}.pub to github"
fi

gcloud compute tpus tpu-vm scp --zone=${TPU_ZONE} \
  ${TPU_KEY_PATH}* tpu_user@${TPU_NAME}:~/.ssh/
# Actually, this works!  Because the first connection to the 
#   tpu-vm propagates SSH public keys first

TPU_DOTENV="./tpu_dotenv/dotenv"
gcloud compute tpus tpu-vm scp --zone=${TPU_ZONE} \
  ${TPU_DOTENV} tpu_user@${TPU_NAME}:~/.env

./bell_tpu

The following is a template for the ./tpu_dotenv/dotenv file:

# Need a Kaggle username...
KAGGLE_USERNAME="YOUR_KAGGLE_USERNAME"

# This is from "Legacy API Credentials"
#   NB: this isn't the same as the KAGGLE_API_KEY thing
KAGGLE_KEY="YOUR_KAGGLE_API_KEY"
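
Before copying the dotenv file up, it is easy to forget to replace the template values. The following sketch checks for that locally; the function name check_dotenv is hypothetical, and it assumes that the placeholders still begin with YOUR_ as in the template above:

```shell
# Hypothetical helper: checks that a dotenv file defines the expected keys
# and that the template placeholders have been replaced.
check_dotenv() {
  file="$1"
  rc=0
  for key in KAGGLE_USERNAME KAGGLE_KEY; do
    if ! grep -q "^${key}=" "$file"; then
      echo "missing: ${key}"
      rc=1
    elif grep -Eq "^${key}=\"?YOUR_" "$file"; then
      echo "placeholder not replaced: ${key}"
      rc=1
    fi
  done
  return "$rc"
}
```

Running check_dotenv ./tpu_dotenv/dotenv before the scp step returns non-zero (and says why) if either key is missing or still set to its placeholder.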

Lesson 8 : Finally log into the TPU host!

The following logs you in, and also sets up the port forwarding for Jupyter Lab (so that it can be reached at http://localhost:8585 on your local machine):

gcloud compute tpus tpu-vm ssh tpu_user@${TPU_NAME} --zone=${TPU_ZONE} \
  -- -L 8585:localhost:8585
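
Once logged in, the numbered *pip*.log files written by the startup-script show how far the installation has got. A small sketch for checking this at a glance; the function name latest_install_stage is hypothetical, and it assumes the logs sit in the home directory under the N-pip-freeze*.log names used in the script:

```shell
# Hypothetical helper: reports the highest-numbered N-pip-freeze*.log in a directory.
latest_install_stage() {
  dir="${1:-$HOME}"
  latest="$(ls "$dir" 2>/dev/null | grep -E '^[0-9]+-pip-freeze.*\.log$' | sort -n | tail -n 1)"
  if [ -n "$latest" ]; then
    echo "latest stage log: ${latest}"
  else
    echo "no stage logs yet"
  fi
}
```

When latest_install_stage reports 5-pip-freeze_with_misc.log, the last install stage in the script has finished.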

Lesson 9 : Securely access GitHub

On the TPU host machine...

The following pre-populates known_hosts with GitHub's published ed25519 public host key, so that the first git clone doesn't stop to ask you to verify the host.

TPU_REPO_USER="mdda"                 # Your GitHub username here...
TPU_REPO="getting-to-aha-with-tpus"  # And your Repo

touch ~/.ssh/known_hosts
echo "github.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl" >> ~/.ssh/known_hosts
git clone git@github.com:${TPU_REPO_USER}/${TPU_REPO}.git
cd ${TPU_REPO}

git config --global user.email "TPU-Machine@example.com"
git config --global user.name "TPU-Machine"
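
If the snippet above gets re-run in the same session, the echo line appends a duplicate known_hosts entry each time. That is harmless, but an append-only-if-missing helper avoids it. A minimal sketch, with the hypothetical name append_once:

```shell
# Hypothetical helper: append a line to a file only if it is not already there.
append_once() {
  line="$1"
  file="$2"
  touch "$file"
  grep -qxF -- "$line" "$file" || printf '%s\n' "$line" >> "$file"
}
```

It would be called with the full "github.com ssh-ed25519 ..." line as the first argument and ~/.ssh/known_hosts as the second.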

Lesson 10 : Have fun with TPUs!

The set-up has done the following:

  • Installed current versions of the JAX/nnx software
  • Given you secure JupyterLab access to your own GitHub repo
  • Enabled you to write-back repo updates and download data/outputs

Note that once the instance is deleted (which it must be, as in Lesson 6 above), all data on it will be gone. Before that, make sure to git push your changes, download your model results, etc.

AND REMEMBER TO KILL THE TPU INSTANCE (see Lesson 6 above)