PyTorch runs on XLA devices, like TPUs, with the torch_xla package. This document describes how to run your models on these devices.
PyTorch/XLA adds a new `xla` device type to PyTorch. This device type works just like other PyTorch device types. For example, here's how to create and print an XLA tensor:
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
```
This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and `xm.xla_device()` returns the current XLA device. This may be a CPU or TPU depending on your environment.
PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.
For example, XLA tensors can be added together:
```python
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
```
Or matrix multiplied:
```python
print(t0.mm(t1))
```
Or used with neural network modules:
```python
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
```
Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like

```python
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)  # module left on the CPU
l_out = linear(l_in)
print(l_out)
# Fails with: Input tensor is not an XLA tensor: torch.FloatTensor
```

throws an error, because the `torch.nn.Linear` module is on the CPU. Moving the module to the XLA device with `.to(xm.xla_device())`, as in the previous snippet, avoids the error.
Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code. The following snippets highlight these lines when running on a single device and multiple devices with XLA multi-processing.
The following snippet shows a network training on a single XLA device:
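```python
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

# MNIST, train_loader, lr and momentum are defined elsewhere,
# exactly as in a regular PyTorch training script.
device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
    optimizer.zero_grad()
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
    xm.mark_step()  # execute the graph recorded for this iteration
```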
This snippet highlights how easy it is to switch your model to run on XLA. The model definition, dataloader, optimizer and training loop can work on any device. The only XLA-specific code is a couple of lines that acquire the XLA device and mark the step. Calling `xm.mark_step()` at the end of each training iteration causes XLA to execute its current graph and update the model's parameters. See XLA Tensor Deep Dive for more on how XLA creates graphs and runs operations.
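To make the role of `xm.mark_step()` more concrete, here is a toy, pure-Python model of lazy execution. It is not how PyTorch/XLA is implemented: `LazyContext`, `LazyValue` and their `mark_step` method are hypothetical stand-ins that only record operations and run the recorded graph when a step is marked.

```python
class LazyContext:
    """Hypothetical stand-in for an XLA device: records ops instead of running them."""

    def __init__(self):
        self.pending = []    # operations recorded but not yet executed
        self.executions = 0  # number of graphs "compiled and run" so far

    def mark_step(self):
        """Run every pending op in recording order, then start a fresh graph."""
        if self.pending:
            self.executions += 1
        for op in self.pending:
            op()
        self.pending.clear()


class LazyValue:
    """Hypothetical lazy scalar: its value stays None until mark_step() runs."""

    def __init__(self, ctx, value=None):
        self.ctx = ctx
        self.value = value

    def _record(self, fn, other):
        out = LazyValue(self.ctx)

        def op():
            # Inputs are read at execution time; recording order guarantees
            # that any op producing them has already run.
            out.value = fn(self.value, other.value)

        self.ctx.pending.append(op)
        return out

    def __add__(self, other):
        return self._record(lambda a, b: a + b, other)

    def __mul__(self, other):
        return self._record(lambda a, b: a * b, other)


ctx = LazyContext()
a = LazyValue(ctx, 2.0)
b = LazyValue(ctx, 3.0)
c = (a + b) * b   # only records two operations
print(c.value)    # None: nothing has been computed yet
ctx.mark_step()   # the recorded graph runs once
print(c.value)    # 15.0
```

Recording a whole iteration before running it is what gives XLA the chance to optimize the graph as a unit rather than dispatching each operation separately.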
PyTorch/XLA makes it easy to accelerate training by running on multiple XLA devices. The following snippet shows how:
```python
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# MNIST, train_loader, lr and momentum are defined elsewhere.
def _mp_fn(index):
    device = xm.xla_device()
    mp_device_loader = pl.MpDeviceLoader(train_loader, device)

    model = MNIST().train().to(device)
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for data, target in mp_device_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```
There are three differences between this multi-device snippet and the previous single-device snippet:

- `xmp.spawn()` creates the processes that each run an XLA device.
- `MpDeviceLoader` loads the training data onto each device.
- `xm.optimizer_step(optimizer)` consolidates the gradients between cores and issues the XLA device step computation.
The model definition, optimizer definition and training loop remain the same.
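In data-parallel training each process computes gradients on its own shard of the data, and consolidating them is commonly an all-reduce that averages the gradients so every replica applies the same update. The pure-Python sketch below assumes that averaging behavior; `all_reduce_mean` and `sgd_step` are hypothetical helpers, and the real collective runs across XLA devices rather than over Python lists.

```python
from typing import List


def all_reduce_mean(per_replica_grads: List[List[float]]) -> List[List[float]]:
    """Average gradients element-wise across replicas; every replica gets the result."""
    num_replicas = len(per_replica_grads)
    num_params = len(per_replica_grads[0])
    mean = [sum(g[i] for g in per_replica_grads) / num_replicas
            for i in range(num_params)]
    # Each replica receives its own copy of the same averaged gradients.
    return [list(mean) for _ in range(num_replicas)]


def sgd_step(params: List[float], grads: List[float], lr: float) -> List[float]:
    """Plain SGD update, applied independently on each replica."""
    return [p - lr * g for p, g in zip(params, grads)]


params = [1.0, -2.0]                 # identical starting parameters on both replicas
grads = [[0.5, 1.0], [1.5, 3.0]]     # two replicas, each saw a different data shard
reduced = all_reduce_mean(grads)
updated = [sgd_step(params, g, lr=0.1) for g in reduced]
print(reduced)   # [[1.0, 2.0], [1.0, 2.0]]
```

Because every replica applies the same averaged gradient, the replicas' parameters stay in sync after each step.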
NOTE: When using multi-processing, the user can retrieve and access XLA devices only from within the target function of `xmp.spawn()` (or any function that has `xmp.spawn()` as a parent in the call stack).
See the full multiprocessing example for more on training a network on multiple XLA devices with multi-processing.
Using XLA tensors and devices requires changing only a few lines of code. But even though XLA tensors act a lot like CPU and CUDA tensors, their internals are different. This section describes what makes XLA tensors unique.
CPU and CUDA tensors launch operations immediately or eagerly. XLA tensors, on the other hand, are lazy. They record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. A graph of multiple separate operations might be fused into a single optimized operation, for example.
Lazy execution is generally invisible to the caller. PyTorch/XLA automatically constructs the graphs, sends them to XLA devices, and synchronizes when copying data between an XLA device and the CPU. Inserting a barrier when taking an optimizer step explicitly synchronizes the CPU and the XLA device. For more information about our lazy tensor design, you can read this paper.
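A toy illustration of why deferring execution enables fusion (this is not XLA's compiler): once a chain of elementwise operations has been recorded, it can be applied to each element in a single pass over the data instead of one pass, with one intermediate result, per operation. `run_unfused` and `run_fused` are hypothetical names.

```python
from typing import Callable, List, Tuple

Op = Callable[[float], float]


def run_unfused(data: List[float], ops: List[Op]) -> Tuple[List[float], int]:
    """Eager style: one full pass (and one intermediate list) per operation."""
    passes = 0
    for op in ops:
        data = [op(x) for x in data]
        passes += 1
    return data, passes


def run_fused(data: List[float], ops: List[Op]) -> Tuple[List[float], int]:
    """Fused: apply the whole recorded chain to each element in one pass."""
    def fused(x: float) -> float:
        for op in ops:
            x = op(x)
        return x
    return [fused(x) for x in data], 1


graph = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: x - 3.0]
data = [0.0, 1.0, 2.0]
unfused_out, unfused_passes = run_unfused(data, graph)
fused_out, fused_passes = run_fused(data, graph)
print(fused_out, unfused_passes, fused_passes)   # [-1.0, 1.0, 3.0] 3 1
```

Both strategies produce the same values; the fused version simply touches the data once.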
PyTorch/XLA can use the `bfloat16` datatype when running on TPUs. In fact, PyTorch/XLA handles float types (`torch.float` and `torch.double`) differently on TPUs. This behavior is controlled by the `XLA_USE_BF16` and `XLA_DOWNCAST_BF16` environment variables:
- By default, both `torch.float` and `torch.double` are `torch.float` on TPUs.
- If `XLA_USE_BF16` is set, then `torch.float` and `torch.double` are both `bfloat16` on TPUs.
- If `XLA_DOWNCAST_BF16` is set, then `torch.float` is `bfloat16` on TPUs and `torch.double` is `float32` on TPUs.
- If a PyTorch tensor has the `torch.bfloat16` data type, it is mapped directly to the TPU `bfloat16` (XLA `BF16` primitive type).
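The rules in the list above can be summarized as a small lookup. This is a sketch, not PyTorch/XLA code: `effective_tpu_dtype` is a hypothetical helper, dtype names are plain strings (`"float32"` for `torch.float`, `"float64"` for `torch.double`), and two details the text does not specify are assumptions: which values count as "set", and that `XLA_USE_BF16` wins if both variables are set.

```python
from typing import Mapping


def _is_set(env: Mapping[str, str], name: str) -> bool:
    # Assumption: any value other than empty or "0" counts as "set".
    return env.get(name, "0") not in ("", "0")


def effective_tpu_dtype(torch_dtype: str, env: Mapping[str, str]) -> str:
    """Storage type a float tensor would use on a TPU, per the rules above.

    torch_dtype is "float32" (torch.float), "float64" (torch.double) or "bfloat16".
    """
    if torch_dtype == "bfloat16":
        return "bfloat16"                # mapped directly to XLA BF16
    if torch_dtype not in ("float32", "float64"):
        raise ValueError(f"not a float type covered by these rules: {torch_dtype}")
    if _is_set(env, "XLA_USE_BF16"):     # assumption: checked first if both are set
        return "bfloat16"
    if _is_set(env, "XLA_DOWNCAST_BF16"):
        return "bfloat16" if torch_dtype == "float32" else "float32"
    return "float32"                     # default: both behave as torch.float


for env in ({}, {"XLA_USE_BF16": "1"}, {"XLA_DOWNCAST_BF16": "1"}):
    print(env, effective_tpu_dtype("float32", env), effective_tpu_dtype("float64", env))
```

In a real process you would pass `os.environ` as `env`.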
Developers should note that XLA tensors on TPUs will always report their PyTorch datatype regardless of the actual datatype they're using. This conversion is automatic and opaque. If an XLA tensor on a TPU is moved back to the CPU it will be converted from its actual datatype to its PyTorch datatype. Depending on how your code operates, this conversion triggered by the type of processing unit can be important.
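To see why this hidden conversion can matter, the sketch below emulates a `float32` value being stored as `bfloat16` and read back as `float32`, using only the standard `struct` module. It assumes round-to-nearest-even, a common choice; the exact rounding a TPU performs is not stated here, and `round_to_bfloat16` is a hypothetical helper that only handles values in the `float32` range.

```python
import math
import struct


def round_to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even) and return it as a float.

    bfloat16 keeps the sign, the 8 exponent bits and the top 7 mantissa bits of a
    float32, so this rounds away the low 16 bits of the float32 bit pattern.
    """
    if math.isnan(x) or math.isinf(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                  # lowest surviving bit, used for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(round_to_bfloat16(math.pi))          # 3.140625
print(round_to_bfloat16(1.0 + 2 ** -10))   # 1.0: the small increment is lost
```

The tensor still reports a float dtype, but values such as `1.0 + 2 ** -10` silently collapse to `1.0`, which can change comparisons or accumulations in your code.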
The internal data representation of XLA tensors is opaque to the user. They do not expose their storage and they always appear to be contiguous, unlike CPU and CUDA tensors. This allows XLA to adjust a tensor's memory layout for better performance.
XLA tensors can be moved from the CPU to an XLA device and from an XLA device to the CPU. If a view is moved, then the data it's viewing is also copied to the other device and the view relationship is not preserved. Put another way, once data is copied to another device it has no relationship with its previous device or any tensors on it. Again, depending on how your code operates, appreciating and accommodating this transition can be important.
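As a standard-library analogy only (this is not PyTorch/XLA code), a `memoryview` behaves like a tensor view that shares memory with its base, while copying it into a new `bytearray` plays the role of moving the view to another device: the copied data keeps no link to the original.

```python
# Analogy: memoryview ~ tensor view, bytearray(view) ~ moving the view to another device.
base = bytearray(b"\x01\x02\x03\x04")
view = memoryview(base)[1:3]   # shares memory with base
moved = bytearray(view)        # copies only the viewed bytes; no link remains

base[1] = 99
print(view[0])    # 99: the view still sees changes to base
print(moved[0])   # 2: the copy is independent of base
```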
XLA tensors should be moved to the CPU before saving, as in the following snippet:
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

# Move the tensors to the CPU before saving.
tensors = (t0.cpu(), t1.cpu())
torch.save(tensors, 'tensors.pt')

# Load on the CPU, then move to whichever device you need.
tensors = torch.load('tensors.pt')
t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
```
This lets you put the loaded tensors on any available device, not just the one on which they were initialized.
Per the above note on moving XLA tensors to the CPU, care must be taken when working with views. Instead of saving views it is recommended that you recreate them after the tensors have been loaded and moved to their destination device(s).
A utility API, `xm.save()`, saves data after first moving it to the CPU:
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# model and path are defined elsewhere.
xm.save(model.state_dict(), path)
```
With multiple devices, the above API saves only the data of the master device (ordinal 0).
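The master-only pattern can be sketched in plain Python: every replica calls the same save function, but only ordinal 0 writes the file. One copy is enough because, in typical synchronized data-parallel training, all replicas hold the same parameters. `save_on_master` is a hypothetical helper, and it leaves out the CPU transfer and any cross-process synchronization the real `xm.save()` may perform.

```python
import os
import pickle
import tempfile


def save_on_master(data, path: str, ordinal: int) -> bool:
    """Hypothetical helper: every replica calls it, but only ordinal 0 writes."""
    if ordinal != 0:
        return False
    with open(path, "wb") as f:
        pickle.dump(data, f)
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state.pkl")
    # Simulate four replicas, each holding identical, already synchronized state.
    wrote = [save_on_master({"w": [1.0, 2.0]}, path, ordinal=i) for i in range(4)]
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(wrote)   # [True, False, False, False]
```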
When host memory is limited compared to the size of the model parameters, an API is provided that reduces the memory footprint on the host:
```python
import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)
```
This API streams XLA tensors to CPU one at a time, reducing the amount of host memory used, but it requires a matching load API to restore:
```python
import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)
```
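Why a streaming save needs its own loader can be seen in a small stdlib sketch: each entry is serialized and written before the next one is touched, and a manifest records the layout, so only a loader that knows that layout can reassemble the state. The one-file-per-entry layout and the names `streaming_save` and `streaming_load` are assumptions for illustration, not the actual `xser` file format.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def streaming_save(state: Dict[str, Any], directory: str) -> None:
    """Write each entry to its own file, one entry at a time."""
    os.makedirs(directory, exist_ok=True)
    keys = list(state)
    for i, key in enumerate(keys):
        # With XLA tensors, this is where a single tensor would be copied to the host.
        with open(os.path.join(directory, f"entry_{i}.pkl"), "wb") as f:
            pickle.dump(state[key], f)
    # The manifest records which file holds which key.
    with open(os.path.join(directory, "manifest.pkl"), "wb") as f:
        pickle.dump(keys, f)


def streaming_load(directory: str) -> Dict[str, Any]:
    """Matching loader: it must know the layout that streaming_save produced."""
    with open(os.path.join(directory, "manifest.pkl"), "rb") as f:
        keys = pickle.load(f)
    state = {}
    for i, key in enumerate(keys):
        with open(os.path.join(directory, f"entry_{i}.pkl"), "rb") as f:
            state[key] = pickle.load(f)
    return state


with tempfile.TemporaryDirectory() as tmp:
    original = {"layer.weight": [[0.1, 0.2]], "layer.bias": [0.0]}
    streaming_save(original, os.path.join(tmp, "ckpt"))
    restored = streaming_load(os.path.join(tmp, "ckpt"))
```

A generic loader that expects a single pickled dictionary could not read this directory, which mirrors why `xser.save()` output must be restored with `xser.load()`.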
Directly saving XLA tensors is possible but not recommended. XLA tensors are always loaded back to the device they were saved from, and if that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.
Additional documentation is available at the PyTorch/XLA repo. More examples of running networks on TPUs are available here.