diff --git a/big_vision/configs/proj/jet/imagenet64.py b/big_vision/configs/proj/jet/imagenet64.py
index f4de148..6293193 100644
--- a/big_vision/configs/proj/jet/imagenet64.py
+++ b/big_vision/configs/proj/jet/imagenet64.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # pytype: disable=attribute-error,line-too-long
-r"""JetFormer for imagenet64.
+r"""Jet config for imagenet64.
 
 Expected values in imagenet64 (200 epochs):
 - 32 couplings and block depth 2: 3.72 bpd
diff --git a/big_vision/models/proj/jet/jet.py b/big_vision/models/proj/jet/jet.py
index 6e6a6a0..8fd60fe 100644
--- a/big_vision/models/proj/jet/jet.py
+++ b/big_vision/models/proj/jet/jet.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Jet: A Modern Transformer-Based Normalizing Flow."""
+"""Jet: A Modern Transformer-Based Normalizing Flow.
+
+https://arxiv.org/abs/2412.15129
+"""
 
 import itertools
 from typing import Any, Sequence
@@ -29,7 +32,7 @@
 
 
 class DNN(nn.Module):
-  """Main non-invertible compute block, used inside coupling layers."""
+  """Main non-invertible compute block with a ViT used in coupling layers."""
 
   depth: int = 1
   emb_dim: int = 256
diff --git a/big_vision/trainers/proj/jet/train.py b/big_vision/trainers/proj/jet/train.py
index 812f9ff..cc287d9 100644
--- a/big_vision/trainers/proj/jet/train.py
+++ b/big_vision/trainers/proj/jet/train.py
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Training loop example.
-
-This is a basic variant of a training loop, good starting point for fancy ones.
-"""
+"""Training loop for Jet."""
 
 # pylint: disable=consider-using-from-import
 # pylint: disable=logging-fstring-interpolation
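
Note on the DNN docstring change above: the block it describes is the non-invertible network used inside a coupling layer of the flow. Below is a minimal illustrative sketch of that affine-coupling pattern, not the code touched by this diff; the class name CouplingSketch, the single-Dense stand-in for the ViT block, and all shapes are hypothetical and only meant to show why a non-invertible net can sit inside an invertible layer.

import jax
import jax.numpy as jnp
import flax.linen as nn


class CouplingSketch(nn.Module):
  """Toy affine coupling: a non-invertible net predicts scale/shift (sketch only)."""
  emb_dim: int = 256

  @nn.compact
  def __call__(self, x):
    # Split features: x1 conditions the transform, x2 gets transformed.
    x1, x2 = jnp.split(x, 2, axis=-1)
    # Stand-in for the repo's DNN/ViT block; any non-invertible net works here.
    h = nn.gelu(nn.Dense(self.emb_dim)(x1))
    # Predict per-feature log-scale and shift from the conditioning half.
    log_scale, shift = jnp.split(nn.Dense(2 * x2.shape[-1])(h), 2, axis=-1)
    y2 = x2 * jnp.exp(log_scale) + shift
    # Given x1 the transform is invertible, and log|det J| is the sum of log-scales.
    logdet = jnp.sum(log_scale, axis=-1)
    return jnp.concatenate([x1, y2], axis=-1), logdet


model = CouplingSketch()
x = jax.random.normal(jax.random.PRNGKey(0), (4, 64))
params = model.init(jax.random.PRNGKey(1), x)
y, logdet = model.apply(params, x)  # y has the same shape as x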