From edfa1cacd6880ce58902ba81ee8b868b43ed5bbe Mon Sep 17 00:00:00 2001
From: Laz4rz <62252332+Laz4rz@users.noreply.github.com>
Date: Fri, 7 Jun 2024 20:24:59 +0200
Subject: [PATCH 1/2] fix missing apostrophe in split string

---
 big_vision/configs/proj/uvim/vqvae_nyu_depth.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/big_vision/configs/proj/uvim/vqvae_nyu_depth.py b/big_vision/configs/proj/uvim/vqvae_nyu_depth.py
index d5ae293..75f030e 100644
--- a/big_vision/configs/proj/uvim/vqvae_nyu_depth.py
+++ b/big_vision/configs/proj/uvim/vqvae_nyu_depth.py
@@ -36,7 +36,7 @@ def get_config(arg='res=512,patch_size=16'):
   config.task = 'proj.uvim.depth_task'
 
   config.input = {}
-  config.input.data = dict(name='nyu_depth_v2', split='train)
+  config.input.data = dict(name='nyu_depth_v2', split='train')
   config.input.batch_size = 1024
   config.input.shuffle_buffer_size = 25_000
 
@@ -141,4 +141,4 @@ def get_config(arg='res=512,patch_size=16'):
   config.evals.val.data.split = 'validation[:16]'
   config.evals.val.log_steps = 20
 
-  return config
\ No newline at end of file
+  return config
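
A quick way to sanity-check this one-character fix, assuming a big_vision
checkout with its dependencies importable (the module path mirrors the
patched file; the expected values come from the diff above). With the
unterminated string 'train), the module does not even parse:

    # Raises SyntaxError before the patch; returns a config after it.
    import importlib

    cfg = importlib.import_module(
        "big_vision.configs.proj.uvim.vqvae_nyu_depth").get_config()
    assert cfg.input.data["split"] == "train"
    assert cfg.input.batch_size == 1024
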
From 0d92b798f880c51ada20a85deab067ada9916045 Mon Sep 17 00:00:00 2001
From: Laz4rz <62252332+Laz4rz@users.noreply.github.com>
Date: Fri, 7 Jun 2024 20:27:20 +0200
Subject: [PATCH 2/2] fix indentation for DiscriminativeClassifierTest class

---
 .../discriminative_classifier_test.py | 324 +++++++++---------
 1 file changed, 162 insertions(+), 162 deletions(-)

diff --git a/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py b/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py
index 06f4fe2..d58ae5c 100644
--- a/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py
+++ b/big_vision/evaluators/proj/image_text/discriminative_classifier_test.py
@@ -68,169 +68,169 @@ def z(x):
 class DiscriminativeClassifierTest(tf.test.TestCase):
-
-    def test_prepare_datasets(self):
-
-        def generator():
-            yield {
-                "image": tf.ones([5, 5, 3], tf.float32),
-                "label": 1,
-            }
-            yield {
-                "image": tf.ones([4, 4, 3], tf.float32),
-                "label": 2,
-            }
-
-        ds = tf.data.Dataset.from_generator(
-            generator,
-            output_signature={
-                "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
-                "label": tf.TensorSpec(shape=[], dtype=tf.int64),
-            })
-        class_names = [
-            "class1,class1a",
-            "class2",
-        ]
-        prompt_templates = [
-            "test {}",
-            "test {} test",
-        ]
-        ds_img, ds_txt = discriminative_classifier.prepare_datasets(
-            ds,
-            class_names,
-            prompt_templates=prompt_templates,
-            pp_img="resize(2)",
-            pp_txt="copy_from(labels='texts')",
-        )
-
-        it_img = iter(ds_img)
-        batch = next(it_img)
-        self.assertAllEqual(1, batch["label"])
-        self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
-        batch = next(it_img)
-        self.assertAllEqual(2, batch["label"])
-        self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
-
-        it_txt = iter(ds_txt)
-        batch = next(it_txt)
-        self.assertAllEqual(0, batch["label"])
-        self.assertAllEqual("test class1", batch["labels"])
-        batch = next(it_txt)
-        self.assertAllEqual(0, batch["label"])
-        self.assertAllEqual("test class1 test", batch["labels"])
-        batch = next(it_txt)
-        self.assertAllEqual(0, batch["label"])
-        self.assertAllEqual("test class1a", batch["labels"])
-        batch = next(it_txt)
-        self.assertAllEqual(0, batch["label"])
-        self.assertAllEqual("test class1a test", batch["labels"])
-        batch = next(it_txt)
-        self.assertAllEqual(1, batch["label"])
-        self.assertAllEqual("test class2", batch["labels"])
-        batch = next(it_txt)
-        self.assertAllEqual(1, batch["label"])
-        self.assertAllEqual("test class2 test", batch["labels"])
-
-    def test_average_embeddings(self):
-        self.assertAllEqual(jnp.array([
-            [2.], [4.], [8.],
-        ]), discriminative_classifier._average_embeddings(
-            embeddings=jnp.array([
-                1., 3., 3., 1.,  # label1
-                8., 0.,  # label2
-                32., 0., 0., 0.,  # label3
-            ])[..., None],
-            labels=jnp.array([
-                0, 0,  # label1
-                0, 0,  # label1 (alias)
-                1, 1,  # label2
-                2, 2,  # label3
-                2, 2,  # label3 (alias)
-            ], jnp.int32),
-            num_classes=3, normalize=False))
-        self.assertAllEqual(
-            jnp.array([
-                [2**-.5, 2**-.5],
-            ]),
-            discriminative_classifier._average_embeddings(
-                embeddings=jnp.array([[2., 2.]]),
-                labels=jnp.array([0], jnp.int32),
-                num_classes=1,
-                normalize=True))
-
-    @mock.patch("big_vision.evaluators.proj."
-                "image_text.prompt_engineering.get_class_names")
-    @mock.patch("big_vision.evaluators.proj."
-                "image_text.prompt_engineering.get_prompt_templates")
-    @mock.patch("big_vision.evaluators.proj."
-                "image_text.discriminative_classifier._get_dataset_info")
-    def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock,
-                      get_class_names_mock):
-        per_device_batch_size = 10  # Make sure we have some unfiltered examples.
-        global_batch_size = per_device_batch_size * jax.device_count()
-        per_host_num_examples = int(
-            np.ceil(global_batch_size / jax.process_count()))
-        splits = {
-            "test":
-                tfds.core.SplitInfo(
-                    name="test", shard_lengths=[per_host_num_examples], num_bytes=0)
-        }
-
-        model = _Model()
-        params = model.init(jax.random.PRNGKey(0), None, None)["params"]
-
-        prompt_templates = [
-            "test prompt 1 {}",
-            "test prompt 2 {}",
-        ]
-        class_names = [
-            f"test_class_{i}" for i in range(10)
-        ]
-
-        get_prompt_templates_mock.return_value = prompt_templates
-        get_class_names_mock.return_value = class_names
-        get_dataset_info_mock.return_value.splits = splits
-
-        def pre_filter_fn(features):
-            return features["label"] < 5  # matches `texts %= 5` above
-
-        dataset_name = "cifar10_test"
-        with tfds.testing.mock_data(num_examples=per_host_num_examples):
-            evaluator = discriminative_classifier.Evaluator(
-                lambda p, b: model.apply({"params": p},
-                                         b.get("image", None),
-                                         b.get("labels", None)),
-                dataset_names=[dataset_name],
-                prompt_templates="test_prompts",
-                batch_size=global_batch_size,
-                devices=jax.devices(),
-                pp_img="copy_from(image='label')",
-                pp_txt="copy_from(labels='label')",
-                dataset_overrides={
-                    dataset_name: {
-                        "dataset_name": "cifar10",
-                        "class_names": "test_classes",
-                        "pre_filter_fn": pre_filter_fn,
-                    }
-                },
-                first_class_name_only=True,
-            )
-            results = evaluator.evaluate(
-                params,
-                dataset_name,
-                return_embeddings=True)
-            metrics = dict(evaluator.run(params))
-
-        # Assert all examples were processed.
-        self.assertLen(results["texts"]["embedding"],
-                       len(class_names) * len(prompt_templates))
-        self.assertLen(results["texts"]["average_embedding"], len(class_names))
-        self.assertAllEqual(
-            sorted(results["texts"]["label"]),
-            [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
-        # Note that above model makes perfect predictions by design.
-        self.assertEqual(1.0, results["accuracy"])
-        self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"])
+
+  def test_prepare_datasets(self):
+
+    def generator():
+      yield {
+          "image": tf.ones([5, 5, 3], tf.float32),
+          "label": 1,
+      }
+      yield {
+          "image": tf.ones([4, 4, 3], tf.float32),
+          "label": 2,
+      }
+
+    ds = tf.data.Dataset.from_generator(
+        generator,
+        output_signature={
+            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
+            "label": tf.TensorSpec(shape=[], dtype=tf.int64),
+        })
+    class_names = [
+        "class1,class1a",
+        "class2",
+    ]
+    prompt_templates = [
+        "test {}",
+        "test {} test",
+    ]
+    ds_img, ds_txt = discriminative_classifier.prepare_datasets(
+        ds,
+        class_names,
+        prompt_templates=prompt_templates,
+        pp_img="resize(2)",
+        pp_txt="copy_from(labels='texts')",
+    )
+
+    it_img = iter(ds_img)
+    batch = next(it_img)
+    self.assertAllEqual(1, batch["label"])
+    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
+    batch = next(it_img)
+    self.assertAllEqual(2, batch["label"])
+    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
+
+    it_txt = iter(ds_txt)
+    batch = next(it_txt)
+    self.assertAllEqual(0, batch["label"])
+    self.assertAllEqual("test class1", batch["labels"])
+    batch = next(it_txt)
+    self.assertAllEqual(0, batch["label"])
+    self.assertAllEqual("test class1 test", batch["labels"])
+    batch = next(it_txt)
+    self.assertAllEqual(0, batch["label"])
+    self.assertAllEqual("test class1a", batch["labels"])
+    batch = next(it_txt)
+    self.assertAllEqual(0, batch["label"])
+    self.assertAllEqual("test class1a test", batch["labels"])
+    batch = next(it_txt)
+    self.assertAllEqual(1, batch["label"])
+    self.assertAllEqual("test class2", batch["labels"])
+    batch = next(it_txt)
+    self.assertAllEqual(1, batch["label"])
+    self.assertAllEqual("test class2 test", batch["labels"])
+
+  def test_average_embeddings(self):
+    self.assertAllEqual(jnp.array([
+        [2.], [4.], [8.],
+    ]), discriminative_classifier._average_embeddings(
+        embeddings=jnp.array([
+            1., 3., 3., 1.,  # label1
+            8., 0.,  # label2
+            32., 0., 0., 0.,  # label3
+        ])[..., None],
+        labels=jnp.array([
+            0, 0,  # label1
+            0, 0,  # label1 (alias)
+            1, 1,  # label2
+            2, 2,  # label3
+            2, 2,  # label3 (alias)
+        ], jnp.int32),
+        num_classes=3, normalize=False))
+    self.assertAllEqual(
+        jnp.array([
+            [2**-.5, 2**-.5],
+        ]),
+        discriminative_classifier._average_embeddings(
+            embeddings=jnp.array([[2., 2.]]),
+            labels=jnp.array([0], jnp.int32),
+            num_classes=1,
+            normalize=True))
+
+  @mock.patch("big_vision.evaluators.proj."
+              "image_text.prompt_engineering.get_class_names")
+  @mock.patch("big_vision.evaluators.proj."
+              "image_text.prompt_engineering.get_prompt_templates")
+  @mock.patch("big_vision.evaluators.proj."
+              "image_text.discriminative_classifier._get_dataset_info")
+  def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock,
+                    get_class_names_mock):
+    per_device_batch_size = 10  # Make sure we have some unfiltered examples.
+    global_batch_size = per_device_batch_size * jax.device_count()
+    per_host_num_examples = int(
+        np.ceil(global_batch_size / jax.process_count()))
+    splits = {
+        "test":
+            tfds.core.SplitInfo(
+                name="test", shard_lengths=[per_host_num_examples], num_bytes=0)
+    }
+
+    model = _Model()
+    params = model.init(jax.random.PRNGKey(0), None, None)["params"]
+
+    prompt_templates = [
+        "test prompt 1 {}",
+        "test prompt 2 {}",
+    ]
+    class_names = [
+        f"test_class_{i}" for i in range(10)
+    ]
+
+    get_prompt_templates_mock.return_value = prompt_templates
+    get_class_names_mock.return_value = class_names
+    get_dataset_info_mock.return_value.splits = splits
+
+    def pre_filter_fn(features):
+      return features["label"] < 5  # matches `texts %= 5` above
+
+    dataset_name = "cifar10_test"
+    with tfds.testing.mock_data(num_examples=per_host_num_examples):
+      evaluator = discriminative_classifier.Evaluator(
+          lambda p, b: model.apply({"params": p},
+                                   b.get("image", None),
+                                   b.get("labels", None)),
+          dataset_names=[dataset_name],
+          prompt_templates="test_prompts",
+          batch_size=global_batch_size,
+          devices=jax.devices(),
+          pp_img="copy_from(image='label')",
+          pp_txt="copy_from(labels='label')",
+          dataset_overrides={
+              dataset_name: {
+                  "dataset_name": "cifar10",
+                  "class_names": "test_classes",
+                  "pre_filter_fn": pre_filter_fn,
+              }
+          },
+          first_class_name_only=True,
+      )
+      results = evaluator.evaluate(
+          params,
+          dataset_name,
+          return_embeddings=True)
+      metrics = dict(evaluator.run(params))
+
+    # Assert all examples were processed.
+    self.assertLen(results["texts"]["embedding"],
+                   len(class_names) * len(prompt_templates))
+    self.assertLen(results["texts"]["average_embedding"], len(class_names))
+    self.assertAllEqual(
+        sorted(results["texts"]["label"]),
+        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
+    # Note that above model makes perfect predictions by design.
+    self.assertEqual(1.0, results["accuracy"])
+    self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"])
 
 
 if __name__ == "__main__":
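
The re-indentation above is whitespace-only, so runtime behavior should be
unchanged. A minimal sketch of a smoke check, assuming it is run from the
root of a big_vision checkout: parsing the patched module catches any
IndentationError (a SyntaxError subclass) introduced by the rewrite.

    # Whitespace-only edits still have to parse cleanly.
    import ast

    path = ("big_vision/evaluators/proj/image_text/"
            "discriminative_classifier_test.py")
    with open(path) as f:
      ast.parse(f.read(), filename=path)
    print("module parses cleanly")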