Skip to content

Commit

Permalink
Fix PaliGemma intructions and big_vision deps.
Browse files Browse the repository at this point in the history
  • Loading branch information
andresusanopinto authored and andsteing committed Dec 6, 2024
1 parent 8e9b05b commit f347091
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 40 deletions.
5 changes: 2 additions & 3 deletions big_vision/configs/proj/paligemma/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,16 +255,15 @@ export KAGGLE_USERNAME=
export KAGGLE_KEY=
# See https://www.kaggle.com/models/google/paligemma-2 for a full list of models.
export MODEL_NAME=paligemma-2
export CKPT_FILE=paligemma2-3b-pt-224.npz.b16
export MODEL_NAME=paligemma2-3b-pt-224
mkdir ckpts/
cd ckpts/
# Store as a "vanity name" from models/proj/paligemma/paligemma.py
curl -L -u $KAGGLE_USERNAME:$KAGGLE_KEY\
-o pt_3b_224.bf16.npz \
https://www.kaggle.com/api/v1/models/google/paligemma-2/jax/$MODEL_NAME/1/download/$CKPT_FILE
https://www.kaggle.com/api/v1/models/google/paligemma-2/jax/$MODEL_NAME/1/download/$MODEL_NAME.b16.npz
```

As an example, we provide the `forkme.py` config that is based on the easily-adjustable jsonl data source:
Expand Down
38 changes: 1 addition & 37 deletions big_vision/datasets/sequence_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Dict, Optional, List, Union

from flax import traverse_util
import grain.tensorflow as tf_grain
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE
Expand Down Expand Up @@ -75,39 +74,4 @@ def pack_dataset(
Returns:
A `tf.data.Dataset`.
"""
def _maybe_join(k):
if isinstance(k, int):
k = (k,)
return FLATTEN_SEPARATOR.join(k)

if isinstance(key2length, int):
key2length = {_maybe_join(k): key2length for k in keys}
else:
key2length = dict(key2length) # Make new dict, we'll edit in-place.

def _add_fake_index(x):
x = dict(x)
x[tf_grain.INDEX] = -1
return x

def _flatten_dict(x):
return traverse_util.flatten_dict(x, sep=FLATTEN_SEPARATOR)

def _unflatten_dict(x):
return traverse_util.unflatten_dict(x, sep=FLATTEN_SEPARATOR)

def _remove_index(x):
x = dict(x)
x.pop(tf_grain.INDEX)
return x

pack_op = tf_grain.TfBatchAndPack(
batch_size=batch_size or 1,
sequence_lengths=_flatten_dict(key2length))

dataset = dataset.map(_add_fake_index, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(_flatten_dict, num_parallel_calls=AUTOTUNE)
dataset = pack_op.apply_to_dataset(dataset)
dataset = dataset.map(_unflatten_dict, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(_remove_index, num_parallel_calls=AUTOTUNE)
return dataset.unbatch()
raise ValueError("Not implemented in OSS yet.")

0 comments on commit f347091

Please sign in to comment.