Releases: pytorch/xla
PyTorch/XLA 2.5.1: Readme update
PyTorch/XLA 2.5 Release
Cloud TPUs now support the PyTorch 2.5 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.5 release, this release introduces several features and PyTorch/XLA-specific bug fixes.
Highlights
We are excited to announce the release of PyTorch XLA 2.5! PyTorch/XLA 2.5 adds the torch_xla.compile
function, which improves the debugging experience during development, and aligns distributed APIs with upstream PyTorch through traceable collective support for both Dynamo and non-Dynamo cases. Starting from PyTorch/XLA 2.5, we propose a clarified deprecation path for the older torch_xla API in favor of moving towards the existing PyTorch API, providing a simplified developer experience.
If you've used vLLM for serving models on GPUs, you'll now be able to seamlessly switch to its TPU backend. vLLM is a widely adopted inference framework that also serves as an excellent way to drive accelerator interoperability. With vLLM on TPU, users will retain the same vLLM interface we've grown to love, with direct integration with Hugging Face Models to make model experimentation easy.
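Below is a minimal sketch of the torch_xla.compile workflow highlighted above, assuming torch_xla.compile can directly wrap a Python step function; consult the official API documentation for the exact signature.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)

def step_fn(x):
    return model(x).sum()

# Wrap the step function so each call is traced and compiled as one graph.
compiled_step = torch_xla.compile(step_fn)

x = torch.randn(8, 128, device=device)
loss = compiled_step(x)
```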
Stable Features
Eager
- Increase max in-flight operations to accommodate eager mode [#7263]
- Unify the logic to check eager mode [#7709]
- Update eager.md [#7710]
- Optimize execution for ops that have multiple outputs in eager mode [#7680]
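A hedged sketch of the eager-mode flow these changes support; the torch_xla.experimental.eager_mode entry point is an assumption based on eager.md.

```python
import torch
import torch_xla
import torch_xla.experimental
import torch_xla.core.xla_model as xm

# Compile and execute each op as it is traced, instead of deferring to mark_step.
torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = (a @ a).relu()   # executed eagerly, no explicit mark_step needed
print(b.cpu())
```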
Quantization / Low Precision
- Asymmetric quantized matmul support [#7626]
- Add blockwise quantized dot support [#7605]
- Support int4 weight in quantized matmul / linear [#7235]
- Support fp8e5m2 dtype [#7740]
- Add fp8e4m3fn support [#7842]
- Support dynamic activation quant for per-channel quantized matmul [#7867]
- Enable cross entropy loss for XLA autocast with FP32 precision [#8094]
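An illustrative sketch of moving fp8 tensors to the XLA device, assuming torch.float8_e5m2 / torch.float8_e4m3fn tensors can be created and transferred like any other dtype (per the fp8 PRs above); the compute dtype choice here is purely illustrative.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# fp8 used for storage/transfer; upcast before compute.
w = torch.randn(64, 64).to(torch.float8_e5m2).to(device)
x = torch.randn(8, 64, device=device)

y = x.to(torch.bfloat16) @ w.to(torch.bfloat16)
```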
Pallas Kernels
- Support ab for flash_attention [#7840]; the actual kernel is implemented in JAX
- Support logits_soft_cap parameter in paged_attention [#7704]; the actual kernel is implemented in JAX
- Support gmm and tgmm trace_pallas caching [#7921]
- Cache flash attention tracing [#8026]
- Improve the user guide [#7625]
- Update pallas doc with paged_attention [#7591]
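A sketch of calling the Pallas FlashAttention wrapper referenced above; the import path and the keyword arguments are assumptions based on torch_xla.experimental.custom_kernel and the PR titles.

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64, device=device)
k = torch.randn(1, 8, 128, 64, device=device)
v = torch.randn(1, 8, 128, 64, device=device)

out = flash_attention(q, k, v, causal=True)
```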
StableHLO
- Add user guide for stablehlo composite op [#7826]
gSPMD
- Handle the parameter wrapping for SPMD [#7604]
- Add helper function to get 1d mesh [#7577]
- Support manual all-reduce [#7576]
- Expose apply_backward_optimization_barrier [#7477]
- Support reduce-scatter in manual sharding [#7231]
- Allow MpDeviceLoader to shard dictionaries of tensors [#8202] (see the sketch below)
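A hedged sketch of sharding dictionary-valued batches with MpDeviceLoader, assuming input_sharding accepts a dict of ShardingSpec objects keyed by batch entry, as described in #8202.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.parallel_loader import MpDeviceLoader

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# Any iterable of dict batches works; a real DataLoader would be used in practice.
batches = [{'x': torch.randn(16, 128), 'y': torch.randint(0, 10, (16,))}]

loader = MpDeviceLoader(
    batches,
    xm.xla_device(),
    input_sharding={
        'x': xs.ShardingSpec(mesh, ('data', None)),
        'y': xs.ShardingSpec(mesh, ('data',)),
    },
)

for batch in loader:
    x, y = batch['x'], batch['y']   # already on device, sharded along 'data'
```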
Dynamo
- Optimize dynamo dynamic shape caching [#7726]
- Add support for dynamic shape in dynamo [#7676]
- In dynamo optim_mode avoid unnecessary set_attr [#7915]
- Fix the crash with copy op in dynamo [#7902]
- Optimize _split_xla_args_tensor_sym_constant [#7900]
- Dynamo RNG seed update optimization [#7884]
- Support mark_dynamic [#7812] (see the sketch below)
- Support gmm as a custom op for dynamo [#7672]
- Fix dynamo inplace copy [#7933]
- CPU time optimization for GraphInputMatcher [#7895]
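A sketch of the dynamic-shape path with mark_dynamic and the openxla Dynamo backend, assuming torch._dynamo.mark_dynamic behaves on XLA tensors as it does upstream.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)
compiled = torch.compile(model, backend="openxla")

for bs in (4, 8, 16):
    x = torch.randn(bs, 128, device=device)
    torch._dynamo.mark_dynamic(x, 0)   # mark the batch dimension as dynamic
    out = compiled(x)
```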
PJRT
- Improve device auto-detection [#7787]
- Move _xla_register_custom_call_target implementation into PjRtComputationClient [#7801]
- Handle SPMD case inside of ComputationClient::WaitDeviceOps [#7796]
GKE
Functionalization
- Add 1-layer gradient accumulation test to check aliasing [#7692]
AMP
- Fix norm data-type when using AMP [#7878]
Beta Features
Op Lowering
- Lower aten::_linalg_eigh [#7674]
- Fallback _embedding_bag_backward and force sparse=false [#7584]
- Support trilinear by using upstream decomp [#7586]
Higher order ops
- [Fori_loop] Update randint max range to support bool dtype [#7632]
TorchBench Integration
- [benchmarks] API alignment with PyTorch profiler events [#7930]
- [benchmarks] Add IR dump option when running torchbench [#7927]
- [benchmarks] Use the same matmul precision between PyTorch and PyTorch/XLA [#7748]
- [benchmarks] Introduce verifier to verify the model output correctness against native PyTorch [#7724, #7777]
- [benchmarks] Fix moco model issue on XLA [#7257, #7598]
- Type annotation for benchmarks/ [#7289]
- Default with CUDAGraphs on for inductor [#7749]
GPU
- Deprecate XRT for XLA:CUDA [#8006]
Experimental Features
Backward Compatibility & APIs that will be removed in the 2.7 release:
- Deprecate APIs (deprecated → new; see the migration sketch below):
  - xla_model.xrt_world_size() → runtime.world_size() [#7679, #7743]
  - xla_model.get_ordinal() → runtime.global_ordinal() [#7679]
  - xla_model.get_local_ordinal() → runtime.local_ordinal() [#7679]
- Internalize APIs:
  - xla_model.parse_xla_device() [#7675]
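A short migration sketch for the deprecations listed above: prefer torch_xla.runtime over the older xla_model helpers (the old calls keep working until the 2.7 release).

```python
import torch_xla.runtime as xr

world_size = xr.world_size()       # was: xm.xrt_world_size()
rank = xr.global_ordinal()         # was: xm.get_ordinal()
local_rank = xr.local_ordinal()    # was: xm.get_local_ordinal()
```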
- Improvement
  - Automatic PJRT device detection when importing torch_xla [#7787]
  - Add deprecated decorator [#7703]
Distributed
Distributed API
We have aligned our distributed APIs with upstream PyTorch. Previously, we implemented custom distributed APIs, such as torch_xla.xla_model.all_reduce. With traceable collective support, we now enable torch.distributed.all_reduce and similar functions for both Dynamo and non-Dynamo cases in torch_xla.
- Support upstream distributed APIs (torch.distributed.*) such as all_reduce, all_gather, reduce_scatter_tensor, and all_to_all. Previously we used XLA-specific distributed APIs in xla_model [#7860, #7950, #8064].
- Introduce torch_xla.launch() to launch multiprocessing, unifying torchrun and torch_xla.distributed.xla_multiprocessing.spawn() [#7764, #7648, #7695]. See the sketch below.
- torch.distributed.reduce_scatter_tensor() [#7950]
- Register sdp lower precision autocast [#7299]
- Add Python binding for xla::DotGeneral [#7863]
- Fix input output alias for custom inplace ops [#7822]
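A hedged sketch of the upstream-aligned collectives together with torch_xla.launch(); the "xla" process-group backend and the xla:// init method follow the existing torch.distributed integration, and the torch_xla.launch() call signature is an assumption based on the PRs above.

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" backend

def _mp_fn(index):
    dist.init_process_group("xla", init_method="xla://")
    device = xm.xla_device()
    t = torch.ones(4, device=device) * dist.get_rank()
    dist.all_reduce(t)   # upstream API, traceable for both Dynamo and non-Dynamo
    xm.mark_step()
    print(t.cpu())

if __name__ == "__main__":
    torch_xla.launch(_mp_fn)   # unifies torchrun and xmp.spawn()
```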
torch_xla.compile
- Support full_graph which will error out if there is more than one ...
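A sketch of the full_graph option, assuming it is passed through torch_xla.compile; with full_graph=True, a step that would split into multiple graphs raises an error instead of silently recompiling.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)

def step(x):
    return model(x).sum()

# Error out if tracing produces more than one graph for this function.
compiled_step = torch_xla.compile(step, full_graph=True)
compiled_step(torch.randn(2, 16, device=device))
```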
PyTorch/XLA 2.4 Release
Cloud TPUs now support the PyTorch 2.4 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.4 release, this release introduces several features and PyTorch/XLA-specific bug fixes.
The PyTorch/XLA 2.4 release delivers a 4% speedup (geometric mean) on torchbench evaluation benchmarks using the openxla_eval dynamo backend on TPUs, compared to the 2.3 release.
Highlights
We are excited to announce the release of PyTorch XLA 2.4! PyTorch 2.4 offers improved support for custom kernels using Pallas, including kernels like FlashAttention and Group Matrix Multiplication that can be used like any other torch operator, as well as inference support for the PagedAttention kernel. We also add experimental support for eager mode, which compiles and executes each operator individually for a better debugging and development experience.
Stable Features
PJRT
- Enable dynamic plugins by default #7270
GSPMD
- Support manual sharding and introduce high level manual sharding APIs #6915, #6931
- Support SPMDFullToShardShape, SPMDShardToFullShape #6922, #6925
Torch Compile
- Add a DynamoSyncInputExecuteTime counter #6813
- Fix runtime error when running dynamo with a profiler scope #6913
Export
- Add fx passes to support unbounded dynamism #6653
- Add dynamism support to conv1d, view, softmax #6653
- Add dynamism support to aten.embedding and aten.split_with_sizes #6781
- Inline all scalars by default in export path #6803
- Run shape propagation for inserted fx nodes #6805
- Add an option to not generate weights #6909
- Support export custom op to stablehlo custom call #7017
- Support array attribute in stablehlo composite #6840
- Add option to export FX Node metadata to StableHLO #7046
Beta Features
Pallas
- Support FlashAttention backward kernels #6870
- Make FlashAttention a torch.autograd.Function #6886
- Remove torch.empty in tracing to avoid allocating extra memory #6897
- Integrate FlashAttention with SPMD #6935
- Support scaling factor for attention weights in FlashAttention #7035
- Support segment ids in FlashAttention #6943
- Enable PagedAttention through Pallas #6912
- Properly support PagedAttention dynamo code path #7022
- Support megacore_mode in PagedAttention #7060
- Add Megablocks' Group Matrix Multiplication kernel #6940, #7117, #7120, #7119, #7133, #7151
- Support histogram #7115, #7202
- Support tgmm #7137
- Make repeat_with_fixed_output_size not OOM on VMEM #7145
- Introduce GMM torch.autograd.function #7152
CoreAtenOpSet
- Lower embedding_bag_forward_only #6951
- Implement Repeat with fixed output shape #7114
- Add int8 per channel weight-only quantized matmul #7201
FSDP via SPMD
- Support multislice #7044
- Allow sharding on the maximal dimension of the weights #7134
- Apply optimization-barrier to all params and buffers during grad checkpointing #7206
Distributed Checkpoint
- Add optimizer priming for distributed checkpointing #6572
Usability
- Add xla.sync as a better name for mark_step. See #6399. #6914
- Add xla.step context manager to handle exceptions better. See #6751. #7068
- Implement ComputationClient::GetMemoryInfo for getting TPU memory allocation #7086
- Dump HLO HBM usage info #7085
- Add function for retrieving fallback operations #7116
- Deprecate XLA_USE_BF16 and XLA_USE_FP16 #7150
- Add PT_XLA_DEBUG_LEVEL to make it easier to distinguish between execution cause and compilation cause #7149
- Warn when using persistent cache with debug env vars #7175
- Add experimental MLIR debuginfo writer API #6799
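A hedged sketch of the new usability helpers above: torch_xla.sync() as a clearer name for xm.mark_step(), and torch_xla.step() as a context manager with better exception handling. The exact entry points are assumptions based on #6914 and #7068.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)

with torch_xla.step():   # cuts the graph at the end of the block
    y = model(torch.randn(2, 8, device=device))

z = model(torch.randn(2, 8, device=device))
torch_xla.sync()         # equivalent to xm.mark_step()
```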
GPU CUDA Fallback
- Add dlpack support #7025
- Make from_dlpack handle cuda synchronization implicitly for input tensors that have __dlpack__ and __dlpack_device__ attributes. #7125
Distributed
- Switch all_reduce to use the new functional collective op #6887
- Allow user to configure distributed runtime service. #7204
- Use dest_offsets directly in LoadPlanner #7243
Experimental Features
Eager Mode
- Enable Eager mode for PyTorch/XLA #7611
- Support eager mode with torch.compile #7649
- Eagerly execute inplace ops in eager mode #7666
- Support eager mode for multi-process training #7668
- Handle random seed for eager mode #7669
- Enable SPMD with eager mode #7673
Triton
While Loop
- Prepare for torch while_loop signature change. #6872
- Implement fori_loop as a wrapper around while_loop #6850
- Complete fori_loop/while_loop and additional test case #7306
Bug Fixes and Improvements
- Fix type promotion for pow. (#6745)
- Fix vector norm lowering #6883
- Manually init absl log to avoid log spam #6890
- Fix pixel_shuffle return empty #6907
- Make nms fallback to CPU implementation by default #6933
- Fix torch.full scalar type #7010
- Handle multiple inplace update input output aliasing #7023
- Fix overflow for div arguments. #7081
- Add data_type promotion to gelu_backward, stack #7090, #7091
- Fix index of 0-element tensor by 0-element tensor #7113
- Fix output data-type for upsample_bilinear #7168
- Fix a data-type related problem for mul operation by converting inputs to result type #7130
- Make clip_grad_norm_ follow input's dtype #7205
PyTorch/XLA 2.3 Release Notes
Highlights
We are excited to announce the release of PyTorch XLA 2.3! PyTorch 2.3 offers experimental support for SPMD auto-sharding on a single TPU host, which allows users to shard their models on TPU with a single config change. We also add experimental support for Pallas custom kernels for inference, which enables users to make use of popular custom kernels like flash attention and paged attention on TPU.
Stable Features
PJRT
- Experimental GPU PJRT Plugin (#6240)
- Define PJRT plugin interface in C++ (#6360)
- Add limit to max inflight TPU computations (#6533)
- Remove TPU_C_API device type (#6435)
GSPMD
Torch Compile
- Support activation sharding within torch.compile (#6524)
- Do not cache FX input args in dynamo bridge to avoid memory leak (#6553)
- Ignore non-XLA nodes and their direct dependents. (#6170)
Export
- Support of implicit broadcasting with unbounded dynamism (#6219)
- Support multiple StableHLO Composite outputs (#6295)
- Add support of dynamism for add (#6443)
- Enable unbounded dynamism on conv, softmax, addmm, slice (#6494)
- Handle constant variable (#6510)
Beta Features
CoreAtenOpSet
Support all Core Aten Ops used by torch.export
- Lower reflection_pad1d, reflection_pad1d_backward, reflection_pad3d and reflection_pad3d_backward (#6588)
- lower replication_pad3d and replication_pad3d_backward (#6566)
- Lower the embedding op (#6495)
- Lowering for _pdist_forward (#6507)
- Support mixed precision for torch.where (#6303)
Benchmark
- Unify PyTorch/XLA and PyTorch torchbench model configuration using the same torchbench.yaml (#6881)
- Align model data precision settings with PyTorch HUD (#6447, #6518, #6555)
- Fix some torchbench models configuration to make it runnable using XLA (#6509, #6542, #6558, #6612).
FSDP via SPMD
Distributed Checkpoint
Usability
GPU
- Fix global_device_count(), local_device_count() for single process on CUDA (#6022)
- Automatically use XLA:GPU if on a GPU machine (#6605)
- Add SPMD on GPU instructions (#6684)
- Build XLA:GPU as a separate Plugin (#6825)
Distributed
Experimental Features
Pallas
- Introduce Flash Attention kernel using Pallas (#6827)
- Support Flash Attention kernel with causal mask (#6837)
- Support Flash Attention kernel with torch.compile (#6875)
- Support Pallas kernel (#6340)
- Support programmatically extracting the payload from Pallas kernel (#6696)
- Support Pallas kernel with torch.compile (#6477)
- Introduce helper to convert Pallas kernel to PyTorch/XLA callable (#6713)
GSPMD Auto-Sharding
Input Output Aliasing
- Support torch.compile for dynamo_set_buffer_donor
- Use XLA's new API to alias graph input and output (#6855)
While Loop
Bug Fixes and Improvements
- Propagates requires_grad over to AllReduce output (#6326)
- Avoid fallback for avg_pool (#6409)
- Fix output tensor shape for argmin and argmax where keepdim=True and dim=None (#6536)
- Fix preserve_rng_state for activation checkpointing (#4690)
- Allow int data-type for Embedding indices (#6718)
- Don't terminate the whole process when Compile fails (#6707)
- Fix an incorrect assert on frame count for PT_XLA_DEBUG=1 (#6466)
- Refactor nms into TorchVision variant. (#6814)
PyTorch/XLA 2.2 Release Notes
Cloud TPUs now support the PyTorch 2.2 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.2 release, this release introduces several features, and PyTorch/XLA specific bug fixes.
Installing PyTorch and PyTorch/XLA 2.2.0 wheel:
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
Please note that you might have to re-install the libtpu on your TPUVM depending on your previous installation:
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
- Note: If you encounter the error RuntimeError: operator torchvision::nms does not exist when using torchvision in the 2.2.0 docker image, please try the following command to fix the issue:
pip uninstall torch -y; pip install torch==2.2.0
Stable Features
PJRT
- PJRT_DEVICE=GPU has been renamed to PJRT_DEVICE=CUDA (#5754). PJRT_DEVICE=GPU will be removed in the 2.3 release.
- Optimize Host to Device transfer (#5772) and device to host transfer (#5825).
- Miscellaneous low-level refactoring and performance improvements (#5799, #5737, #5794, #5793, #5546).
Beta Features
GSPMD
- Support DTensor API integration and move GSPMD out of experimental (#5776).
- Enable debug visualization func visualize_tensor_sharding (#5742), added doc.
- Support mark_shard scalar tensors (#6158).
- Add apply_backward_optimization_barrier (#6157).
Export
- Handled lifted constants in torch export (#6111).
- Run decomp before processing (#5713).
- Support export to tf.saved_model for models with unused params (#5694).
- Add an option to not save the weights (#5964).
- Experimental support for dynamic dimension sizes in torch export to StableHLO (#5790, openxla/xla#6897).
CoreAtenOpSet
- PyTorch/XLA aims to support all PyTorch core ATen ops in the 2.3 release. We're actively working on this; remaining issues to be closed can be found at the issue list.
Benchmark
- Support of benchmark running automation and metric report analysis on both TPU and GPU (doc).
Experimental Features
FSDP via SPMD
- Introduce FSDP via SPMD, or FSDPv2 (#6187). The RFC can be found (#6379).
- Add FSDPv2 user guide (#6386).
Distributed Op
Persistent Compilation
- Enable persistent compilation caching (#6065).
- Document and introduce the xr.initialize_cache Python API (#6046).
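A sketch of enabling the persistent compilation cache via the xr.initialize_cache API introduced above; the readonly flag and the requirement to call it early are assumptions based on the docs.

```python
import torch_xla.runtime as xr

# Must be called before any compilation or execution happens.
xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)
```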
Checkpointing
- Support auto checkpointing for TPU preemption (#5753).
- Support Async checkpointing through CheckpointManager (#5697).
Usability
Quantization
- Lower quant/dequant torch op to StableHLO (#5763).
GPU
Bug Fixes and Improvements
- Pow precision issue (#6103).
- Handle negative dim for Diagonal Scatter (#6123).
- Fix as_strided for inputs smaller than the arguments specification (#5914).
- Fix squeeze op lowering issue when dim is not in sorted order (#5751).
- Optimize RNG seed dtype for better memory utilization (#5710).
- Lowering _prelu_kernel_backward (#5724).
PyTorch/XLA 2.1 Release
Cloud TPUs now support the PyTorch 2.1 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.1 release, this release introduces several features, and PyTorch/XLA specific bug fixes.
PJRT is now PyTorch/XLA's officially supported runtime! PJRT brings improved performance, superior usability, and broader device support. PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package. In most cases, we expect the migration to PJRT to require minimal changes. For more information, see our PJRT documentation.
GSPMD support has been added as an experimental feature to the PyTorch/XLA 2.1 release. GSPMD will transform the single device program into a partitioned one with proper collectives, based on the user provided sharding hints. This feature allows developers to write PyTorch programs as if they are on a single large device without any custom sharded computation ops and/or collective communications to scale. We published a blog post explaining the technical details and expected usage, you can also find more detail in this user guide.
PyTorch/XLA has transitioned from depending on TensorFlow to depending on the new OpenXLA repo. This allows us to reduce our binary size and simplify our build system. Starting from 2.1, PyTorch/XLA will release our TPU wheels on PyPI.
To install PyTorch/XLA 2.1.0 wheels, please find the installation instructions below.
Installing PyTorch and PyTorch/XLA 2.1.0 wheel:
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
Please note that you might have to re-install the libtpu on your TPUVM depending on your previous installation:
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
Stable Features
OpenXLA
- Migrate to pull XLA from TensorFlow to OpenXLA, TF pin dependency sunset (#5202)
- Instructions to build PyTorch/XLA with OpenXLA can be found in this doc.
PjRt Runtime
- Move PJRT APIs from experimental to torch_xla.runtime (#5011)
- Enable PJRT C API Client and other changes for Neuron (#5428)
- Enable PJRT C API Client for Intel XPU (#4891)
- Change pjrt:// init method to xla:// (#5560)
- Make TPU detection more robust (#5271)
- Add runtime.host_index (#5283)
Functionalization
Improvements and additions
- Op Lowering
- Build System
- Migrate the build system to Bazel (#4528)
Beta Features
AMP (Automatic Mixed Precision)
TorchDynamo
- Support CPU eager fallback in Dynamo bridge (#5000)
- Support torch.compile with SPMD for inference (#5002)
- Update the dynamo backend name to openxla and openxla_eval (#5402)
- Inference optimization for SPMD inference + torch.compile (#5447, #5446)
Traceable Collectives
Experimental Features
GSPMD
- Add SPMD user guide
- Enable Input-output aliasing (#5320)
- Introduce global_runtime_device_count to query the runtime device count (#5129)
- Support partial replication (#5411)
- Support tuple partition spec (#5488)
- Support mark_sharding on IRs (#5301)
- Make IR sharding custom sharding op (#5433)
- Introduce Hybrid Device mesh creation (#5147)
- Introduce SPMD-friendly patched nn.Linear (#5491)
- Allow dumping post optimizations HLO (#5302)
- Allow sharding n-d tensor on (n+1)-d Mesh (#5268)
- Support synchronous distributed checkpointing (#5130, #5170)
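A hedged sketch of the experimental GSPMD flow in this release, assuming the torch_xla.experimental.xla_sharding module and XLA_USE_SPMD environment variable from the 2.1-era SPMD user guide (later releases move this under torch_xla.distributed.spmd).

```python
import os
os.environ["XLA_USE_SPMD"] = "1"   # enable the virtual SPMD device (per the user guide)

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

t = torch.randn(16, 128, device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))   # shard dim 0 across the first mesh axis
```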
Serving Support
- SavedModel
- Added a script stablehlo-to-saved-model (#5493)
- docs: https://github.com/pytorch/xla/blob/r2.1/docs/stablehlo.md#convert-saved-stablehlo-for-serving
StableHLO
- Add StableHLO user guide (#5523)
- Add save_as_stablehlo and save_torch_model_as_stablehlo APIs (#5493)
- Make StableHLO executable (#5476)
Ongoing Development
TorchDynamo
- Enable single step graph for training
- Avoid inter-graph reshapes from aot_autograd
- Support GSPMD for activation checkpointing
GSPMD
- Support auto-sharding
- Benchmark and improving GSPMD for XLA:GPU
- Integrating with PyTorch's Distributed Tensor API
GPU
- Support Multi-host GPU for PJRT runtime
- Improve performance on torchbench models
Quantization
- Support PyTorch PT2E quantization workflow
Bug Fixes and Improvements
PyTorch/XLA 2.0 release
Cloud TPUs now support the PyTorch 2.0 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch's 2.0 release, this release introduces several features, and PyTorch/XLA specific bug fixes.
Beta Features
PJRT runtime
- Check out our newest document; PJRT is the default runtime in 2.0.
- New Implementation of xm.rendezvous with XLA collective communication which scales better (#4181)
- New PJRT TPU backend through the C-API (#4077)
- Use PJRT by default if no runtime is configured (#4599)
- Experimental support for torch.distributed and DDP on TPU v2 and v3 (#4520)
FSDP
- Add auto_wrap_policy into XLA FSDP for automatic wrapping (#4318)
Stable Features
Lazy Tensor Core Migration
- Migration is completed; check out this dev discussion for more detail.
- Naively inherits LazyTensor (#4271)
- Adopt even more LazyTensor interfaces (#4317)
- Introduce XLAGraphExecutor (#4270)
- Inherits LazyGraphExecutor (#4296)
- Adopt more LazyGraphExecutor virtual interfaces (#4314)
- Rollback to use xla::Shape instead of torch::lazy::Shape (#4111)
- Use TORCH_LAZY_COUNTER/METRIC (#4208)
Improvements & Additions
- Add an option to increase the worker thread efficiency for data loading (#4727)
- Improve numerical stability of torch.sigmoid (#4311)
- Add an api to clear counter and metrics (#4109)
- Add met.short_metrics_report to display more concise metrics report (#4148)
- Document environment variables (#4273)
- Op Lowering
Experimental Features
TorchDynamo (torch.compile) support
- Check out our newest doc.
- Dynamo bridge python binding (#4119)
- Dynamo bridge backend implementation (#4523)
- Training optimization: make execution async (#4425)
- Training optimization: reduce graph execution per step (#4523)
PyTorch/XLA GSPMD on single host
- Preserve parameter sharding with sharded data placeholder (#4721)
- Transfer shards from server to host (#4508)
- Store the sharding annotation within XLATensor (#4390)
- Use d2d replication for more efficient input sharding (#4336)
- Mesh to support custom device order. (#4162)
- Introduce virtual SPMD device to avoid unpartitioned data transfer (#4091)
Ongoing development
Ongoing Dynamic Shape implementation
- Implement missing XLASymNodeImpl::Sub (#4551)
- Make empty_symint support dynamism (#4550)
- Add dynamic shape support to SigmoidBackward (#4322)
- Add a forward pass NN model with dynamism test (#4256)
Ongoing SPMD multi host execution (#4573)
Bug fixes & improvements
PyTorch/XLA 1.13 release
Cloud TPUs now support the PyTorch 1.13 release, via PyTorch/XLA integration. The release has daily automated testing for the supported models: Torchvision ResNet, FairSeq Transformer and RoBERTa, HuggingFace GLUE and LM, and Facebook Research DLRM.
On top of the underlying improvements and bug fixes in PyTorch's 1.13 release, this release adds several features and PyTorch/XLA-specific bug fixes.
New Features
- GPU enhancement
- FSDP enhancement
- Lower torch::einsum using xla::einsum, which provides significant speedup (#3843)
- Support large models with >3200 graph input on TPU + PJRT (#3920)
Experimental Features
- PJRT experimental support on Cloud TPU v4
- Check the instruction and example code in here
- DDP experimental support on Cloud TPU and GPU
- Check the instruction, analysis and example code in here
Ongoing development
- Ongoing Dynamic Shape implementation (POC completed)
- Ongoing SPMD implementation (POC completed)
- Ongoing LTC migration
Bug fixes and improvements
- Make XLA_HLO_DEBUG populate the scope metadata (#3985)
PyTorch/XLA 1.12 release
Cloud TPUs now support the PyTorch 1.12 release, via PyTorch/XLA integration. The release has daily automated testing for the supported models: Torchvision ResNet, FairSeq Transformer and RoBERTa, HuggingFace GLUE and LM, and Facebook Research DLRM.
On top of the underlying improvements and bug fixes in PyTorch's 1.12 release, this release adds several features and PyTorch/XLA-specific bug fixes.
New feature
- FSDP
- PyTorch/XLA gradient checkpoint API (#3524)
- Optimization_barrier which enables gradient checkpointing (#3482)
- Ongoing LTC migration
- Device lock position optimization to speed up tracing (#3457)
- Experimental support for PJRT TPU client (#3550)
- Send/Recv CC op support (#3494)
- Performance profiling tool enhancement (#3498)
- TPU-V4 pod official support (#3440)
- Roll lowering (#3505)
- Celu, celu_, selu, selu_ lowering (#3547)
Bug fixes and improvements
- Fixed a view bug which will create unnecessary IR graph (#3411)
PyTorch/XLA 1.11 release
Cloud TPUs now support the PyTorch 1.11 release, via PyTorch/XLA integration. The release has daily automated testing for the supported models: Torchvision ResNet, FairSeq Transformer and RoBERTa, HuggingFace GLUE and LM, and Facebook Research DLRM.
On top of the underlying improvements and bug fixes in PyTorch's 1.11 release, this release adds several features and PyTorch/XLA-specific bug fixes.
New feature
- Enable asynchronous RNG seed sending by environment variable
XLA_TRANSFER_SEED_ASYNC
- Add a native torch.distributed backend
- Introduce an Eager debug mode by environment variable
XLA_USE_EAGER_DEBUG_MODE
- Add synchronous free Adam and AdamW optimizers for PyTorch/XLA:GPU AMP
- Add synchronous free SGD optimizers for PyTorch/XLA:GPU AMP
- linspace lowering
- mish lowering
- prelu lowering
- slogdet lowering
- stable sort lowering
- index_add with alpha scaling lowering
Bug fixes && improvements
- Improve torch.var performance and numerical stability on TPU
- Improve torch.pow performance
- Fix the incorrect output dtype when dividing an f32 by an f64
- Fix the incorrect result of nll_loss when reduction = "mean" and the whole target is equal to ignore_index