Fix Syntax and Indentation errors #112

Open · wants to merge 2 commits into main
Changes from all commits
4 changes: 2 additions & 2 deletions big_vision/configs/proj/uvim/vqvae_nyu_depth.py
@@ -36,7 +36,7 @@ def get_config(arg='res=512,patch_size=16'):
   config.task = 'proj.uvim.depth_task'

   config.input = {}
-  config.input.data = dict(name='nyu_depth_v2', split='train)
+  config.input.data = dict(name='nyu_depth_v2', split='train')

   config.input.batch_size = 1024
   config.input.shuffle_buffer_size = 25_000
@@ -141,4 +141,4 @@ def get_config(arg='res=512,patch_size=16'):
   config.evals.val.data.split = 'validation[:16]'
   config.evals.val.log_steps = 20

-  return config
\ No newline at end of file
+  return config
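
The one-character fix above is worth spelling out: split='train) is an unterminated string literal, so the unpatched config module cannot even be parsed, let alone imported. A minimal standalone reproduction (hypothetical snippet, not code from the repo):

import ast

broken = "config.input.data = dict(name='nyu_depth_v2', split='train)"
fixed = "config.input.data = dict(name='nyu_depth_v2', split='train')"

try:
    ast.parse(broken)
except SyntaxError as err:
    # Unterminated string literal: the module is unimportable as-is.
    print("unpatched line fails to parse:", err.msg)

ast.parse(fixed)  # the patched line parses cleanly
print("patched line parses")
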
324 changes: 162 additions & 162 deletions big_vision/evaluators/proj/image_text/discriminative_classifier_test.py
@@ -68,169 +68,169 @@ def z(x):


class DiscriminativeClassifierTest(tf.test.TestCase):

  def test_prepare_datasets(self):

    def generator():
      yield {
          "image": tf.ones([5, 5, 3], tf.float32),
          "label": 1,
      }
      yield {
          "image": tf.ones([4, 4, 3], tf.float32),
          "label": 2,
      }

    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
            "label": tf.TensorSpec(shape=[], dtype=tf.int64),
        })
    class_names = [
        "class1,class1a",
        "class2",
    ]
    prompt_templates = [
        "test {}",
        "test {} test",
    ]
    ds_img, ds_txt = discriminative_classifier.prepare_datasets(
        ds,
        class_names,
        prompt_templates=prompt_templates,
        pp_img="resize(2)",
        pp_txt="copy_from(labels='texts')",
    )

    it_img = iter(ds_img)
    batch = next(it_img)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
    batch = next(it_img)
    self.assertAllEqual(2, batch["label"])
    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])

    it_txt = iter(ds_txt)
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1 test", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1a", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1a test", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual("test class2", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual("test class2 test", batch["labels"])

  def test_average_embeddings(self):
    self.assertAllEqual(jnp.array([
        [2.], [4.], [8.],
    ]), discriminative_classifier._average_embeddings(
        embeddings=jnp.array([
            1., 3., 3., 1.,  # label1
            8., 0.,  # label2
            32., 0., 0., 0.,  # label3
        ])[..., None],
        labels=jnp.array([
            0, 0,  # label1
            0, 0,  # label1 (alias)
            1, 1,  # label2
            2, 2,  # label3
            2, 2,  # label3 (alias)
        ], jnp.int32),
        num_classes=3, normalize=False))
    self.assertAllEqual(
        jnp.array([
            [2**-.5, 2**-.5],
        ]),
        discriminative_classifier._average_embeddings(
            embeddings=jnp.array([[2., 2.]]),
            labels=jnp.array([0], jnp.int32),
            num_classes=1,
            normalize=True))

  @mock.patch("big_vision.evaluators.proj."
              "image_text.prompt_engineering.get_class_names")
  @mock.patch("big_vision.evaluators.proj."
              "image_text.prompt_engineering.get_prompt_templates")
  @mock.patch("big_vision.evaluators.proj."
              "image_text.discriminative_classifier._get_dataset_info")
  def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock,
                    get_class_names_mock):
    per_device_batch_size = 10  # Make sure we have some unfiltered examples.
    global_batch_size = per_device_batch_size * jax.device_count()
    per_host_num_examples = int(
        np.ceil(global_batch_size / jax.process_count()))
    splits = {
        "test":
            tfds.core.SplitInfo(
                name="test", shard_lengths=[per_host_num_examples], num_bytes=0)
    }

    model = _Model()
    params = model.init(jax.random.PRNGKey(0), None, None)["params"]

    prompt_templates = [
        "test prompt 1 {}",
        "test prompt 2 {}",
    ]
    class_names = [
        f"test_class_{i}" for i in range(10)
    ]

    get_prompt_templates_mock.return_value = prompt_templates
    get_class_names_mock.return_value = class_names
    get_dataset_info_mock.return_value.splits = splits

    def pre_filter_fn(features):
      return features["label"] < 5  # matches `texts %= 5` above

    dataset_name = "cifar10_test"
    with tfds.testing.mock_data(num_examples=per_host_num_examples):
      evaluator = discriminative_classifier.Evaluator(
          lambda p, b: model.apply({"params": p},
                                   b.get("image", None),
                                   b.get("labels", None)),
          dataset_names=[dataset_name],
          prompt_templates="test_prompts",
          batch_size=global_batch_size,
          devices=jax.devices(),
          pp_img="copy_from(image='label')",
          pp_txt="copy_from(labels='label')",
          dataset_overrides={
              dataset_name: {
                  "dataset_name": "cifar10",
                  "class_names": "test_classes",
                  "pre_filter_fn": pre_filter_fn,
              }
          },
          first_class_name_only=True,
      )
      results = evaluator.evaluate(
          params,
          dataset_name,
          return_embeddings=True)
      metrics = dict(evaluator.run(params))

      # Assert all examples were processed.
      self.assertLen(results["texts"]["embedding"],
                     len(class_names) * len(prompt_templates))
      self.assertLen(results["texts"]["average_embedding"], len(class_names))
      self.assertAllEqual(
          sorted(results["texts"]["label"]),
          [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
      # Note that above model makes perfect predictions by design.
      self.assertEqual(1.0, results["accuracy"])
      self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"])


if __name__ == "__main__":
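
Two things changed in the hunk above: the original test_evaluate ended the Evaluator(...) call at first_class_name_only=True, without the closing parenthesis that this PR restores, so the following results = assignment landed inside the argument list, and the test methods were re-indented into the test class. A minimal sketch of that second failure mode, assuming the methods had been dedented to module level (hypothetical names, not repo code):

import unittest


class BrokenSuite(unittest.TestCase):
    pass


def test_never_runs(self):  # module level: unittest does not collect this
    self.fail("never executed")


class FixedSuite(unittest.TestCase):

    def test_runs(self):  # indented into the class: collected and executed
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()  # runs FixedSuite.test_runs; test_never_runs is ignored
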
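
As a cross-check of the expected values in test_average_embeddings, here is a small NumPy sketch of the per-class mean with optional L2 normalization that the assertions imply; it illustrates the assumed semantics, not the actual _average_embeddings implementation:

import numpy as np


def average_embeddings(embeddings, labels, num_classes, normalize=False):
    # Mean embedding per class id, as the test's expected values imply.
    out = np.stack([embeddings[labels == c].mean(axis=0)
                    for c in range(num_classes)])
    if normalize:
        out /= np.linalg.norm(out, axis=-1, keepdims=True)
    return out


emb = np.array([1., 3., 3., 1., 8., 0., 32., 0., 0., 0.])[:, None]
lab = np.array([0, 0, 0, 0, 1, 1, 2, 2, 2, 2])
print(average_embeddings(emb, lab, num_classes=3))
# [[2.], [4.], [8.]], matching the first assertAllEqual above.

print(average_embeddings(np.array([[2., 2.]]), np.array([0]),
                         num_classes=1, normalize=True))
# [[0.70710678, 0.70710678]], i.e. [2**-.5, 2**-.5].
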