From 91472de3c800e17eb9773b0c57e7d0c5a13f31db Mon Sep 17 00:00:00 2001 From: Mandlin Sarah Date: Mon, 2 Sep 2024 19:19:56 -0700 Subject: [PATCH] Refactor batch processing into a separate function --- app/calculate_coco_features.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/app/calculate_coco_features.py b/app/calculate_coco_features.py index 168e8503e..4114268d4 100644 --- a/app/calculate_coco_features.py +++ b/app/calculate_coco_features.py @@ -1,9 +1,7 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" +# Copyright (c) 2022, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause from PIL import Image import requests @@ -18,7 +16,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def load_demo_image(): img_url = ( "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg" @@ -34,6 +31,17 @@ def read_img(filepath): return raw_image +def process_batch(images_in_batch, filepaths_in_batch, feature_extractor, caption, path2feat): + images_in_batch = torch.cat(images_in_batch, dim=0).to(device) + with torch.no_grad(): + image_features = feature_extractor( + images_in_batch, caption, mode="image", normalized=True + )[:, 0] + + for filepath, image_feat in zip(filepaths_in_batch, image_features): + path2feat[os.path.basename(filepath)] = image_feat.detach().cpu() + return path2feat + # model model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" feature_extractor = BlipFeatureExtractor(pretrained=model_url) @@ -62,14 +70,7 @@ def read_img(filepath): for i, filename in enumerate(filepaths): if i % bsz == 0 and i > 0: - images_in_batch = torch.cat(images_in_batch, dim=0).to(device) - with torch.no_grad(): - image_features = feature_extractor( - images_in_batch, caption, mode="image", normalized=True - )[:, 0] - - for filepath, image_feat in zip(filepaths_in_batch, image_features): - path2feat[os.path.basename(filepath)] = image_feat.detach().cpu() + path2feat = process_batch(images_in_batch, filepaths_in_batch, feature_extractor, caption, path2feat) images_in_batch = [] filepaths_in_batch = [] @@ -84,4 +85,7 @@ def read_img(filepath): images_in_batch.append(image) filepaths_in_batch.append(filepath) +path2feat = process_batch(images_in_batch, filepaths_in_batch, feature_extractor, caption, path2feat) # process remaining images + torch.save(path2feat, "path2feat_coco_train2014.pth") +