Skip to content

Commit

Permalink
feat: add Nx Servings: custom implemented + via bumblebee
Browse files Browse the repository at this point in the history
  • Loading branch information
rajrajhans committed Oct 8, 2023
1 parent d7bb830 commit 7d68136
Show file tree
Hide file tree
Showing 10 changed files with 443 additions and 148 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Bumblebee Media Search

- A demo application that uses the [CLIP model](https://openai.com/research/clip) for natural language media search (searching images with text, and searching related images with an image).
- Presented at [ElixirConf Africa 2023 for the talk "Natural Language Media Search with Elixir , Bumblebee and Nx."](https://elixirconf.africa/sessions/natural-language-media-search-with-elixir-bumblebee-and-nx). Slides can be found [here](https://assets.rajrajhans.com/bumblebee-media-search/slides_raj_rajhans_elixir_conf_africa_2023.pdf)
- Built using [Phoenix Framework](https://github.com/phoenixframework/phoenix), [Bumblebee](https://github.com/elixir-nx/bumblebee), [Axon](https://github.com/elixir-nx/axon), [Nx](https://github.com/elixir-nx/nx) and [HNSWLib](https://github.com/elixir-nx/hnswlib).

## Sneak Peek: Searching for Images with Text
Expand All @@ -16,6 +15,13 @@
| --------------------------------------------------------------------- | ---------------------------------------------------------------------- |
| ![ Searching Images with an Image 3 ](./docs/search-with-image-3.png) | ![ Searching Images with an Image 4 ](./docs/search-with-image-4.jpeg) |

## Nx Servings

- This uses Nx Servings for serving the CLIP model. There are two sets of Nx Servings in the codebase:
1. [Nx Servings provided by Bumblebee for text & image embeddings](./lib/media_search_demo/clip/servings/bumblebee/): Uses the ready-made Nx Servings provided by the Bumblebee library.
2. [Hand-rolled Nx Servings for text & image embeddings](./lib/media_search_demo/clip/servings/custom/): Custom Nx Servings, written to show how to implement an Nx Serving from scratch.
- Both provide the same output and can be used interchangeably. However, if you're interested in learning how Nx Servings work and how to implement them, the hand-rolled Nx Serving files will be helpful.

## Installation

- Uses Nix for dependency management. [Install Nix](https://nixos.org/download.html) if you don't have it already.
Expand Down
2 changes: 1 addition & 1 deletion bin/run
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ run_compile() {

# Start the Phoenix server (`mix phx.server`) inside a named IEx session.
run_server() {
# with_dev_env is defined outside this view; presumably it prepares the dev environment — confirm
with_dev_env
# NOTE(review): the next two lines look like the before/after pair of a one-line change
# (node name and cookie renamed from `copilot` to `clip-media-search`); confirm only one is meant to remain.
iex --sname copilot --cookie copilot -S mix phx.server
iex --sname clip-media-search --cookie clip-media-search -S mix phx.server
}

run_iex() {
Expand Down
38 changes: 35 additions & 3 deletions lib/media_search_demo/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ defmodule MediaSearchDemo.Application do
@moduledoc false

use Application
alias MediaSearchDemo.Clip.Servings.Constants

@impl true
def start(_type, _args) do
# make sure the model is loaded before starting the app
Bumblebee.load_model({:hf, "openai/clip-vit-base-patch32"})
{:ok, _} = Bumblebee.load_model({:hf, Constants.clip_hf_model()})

children = [
# Start the Telemetry supervisor
Expand All @@ -21,8 +22,39 @@ defmodule MediaSearchDemo.Application do
{Finch, name: MediaSearchDemo.Finch},
# Start the Endpoint (http/https)
MediaSearchDemoWeb.Endpoint,
MediaSearchDemo.Clip.ModelAgent,
MediaSearchDemo.Clip.ClipIndexAgent
MediaSearchDemo.Clip.ClipIndexAgent,
## Hand Rolled Nx Serving for CLIP Text Embedding ->
{
Nx.Serving,
serving: MediaSearchDemo.Clip.Servings.Text.get_serving(),
name: MediaSearchDemo.Clip.TextServing,
batch_size: Constants.clip_text_batch_size(),
batch_timeout: Constants.clip_text_batch_timeout()
},
## Hand Rolled Nx Servings for CLIP Image Embedding ->
{
Nx.Serving,
serving: MediaSearchDemo.Clip.Servings.Vision.get_serving(),
name: MediaSearchDemo.Clip.VisionServing,
batch_size: Constants.clip_vision_batch_size(),
batch_timeout: Constants.clip_vision_batch_timeout()
},
## Bumblebee Nx Servings for CLIP Text Embedding ->
{
Nx.Serving,
serving: MediaSearchDemo.Clip.Servings.Bumblebee.Text.get_serving(),
name: MediaSearchDemo.Clip.Bumblebee.TextServing,
batch_size: Constants.clip_text_batch_size(),
batch_timeout: Constants.clip_text_batch_timeout()
},
## Bumblebee Nx Servings for CLIP Image Embedding ->
{
Nx.Serving,
serving: MediaSearchDemo.Clip.Servings.Bumblebee.Vision.get_serving(),
name: MediaSearchDemo.Clip.Bumblebee.VisionServing,
batch_size: Constants.clip_vision_batch_size(),
batch_timeout: Constants.clip_vision_batch_timeout()
}
]

# See https://hexdocs.pm/elixir/Supervisor.html
Expand Down
116 changes: 0 additions & 116 deletions lib/media_search_demo/clip/model_agent.ex

This file was deleted.

64 changes: 64 additions & 0 deletions lib/media_search_demo/clip/servings/bumblebee/text.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
defmodule MediaSearchDemo.Clip.Servings.Bumblebee.Text do
  @moduledoc """
  CLIP text embeddings served through the text-embedding serving that ships with Bumblebee.

  The model handed to Bumblebee is the CLIP text model with a `"text_projection"` dense
  layer appended; that layer's parameters are copied from the multimodal CLIP checkpoint.
  """

  alias MediaSearchDemo.Clip.Servings.Constants

  @doc """
  Builds the Bumblebee text-embedding serving for the projected CLIP text model.

  Loads the tokenizer for `Constants.clip_hf_model/0` and uses
  `Constants.sequence_length/0` as the serving's sequence length.
  """
  @spec get_serving :: Nx.Serving.t()
  def get_serving() do
    projected_model_info = clip_text_embeddings_model()
    {:ok, clip_tokenizer} = Bumblebee.load_tokenizer({:hf, Constants.clip_hf_model()})

    serving_opts = [
      output_attribute: :embedding,
      output_pool: nil,
      sequence_length: Constants.sequence_length(),
      embedding_processor: nil
    ]

    Bumblebee.Text.TextEmbedding.text_embedding(projected_model_info, clip_tokenizer, serving_opts)
  end

  @doc """
  Submits `text` as a single-element batch to the `MediaSearchDemo.Clip.Bumblebee.TextServing`
  process and returns whatever `Nx.Serving.batched_run/2` returns.
  """
  def run_embeddings(text) do
    Nx.Serving.batched_run(MediaSearchDemo.Clip.Bumblebee.TextServing, [text])
  end

  # CLIP model for text embeddings: same as Bumblebee's Text.ClipText,
  # but with the final projection layer appended.
  defp clip_text_embeddings_model() do
    # The multimodal checkpoint is loaded only for its projection weights.
    %{model: _full_model, params: full_params, spec: _full_spec} = load_clip!(architecture: :base)

    %{model: encoder, params: encoder_params, spec: encoder_spec} =
      load_clip!(module: Bumblebee.Text.ClipText, architecture: :base)

    embedding_size = Application.get_env(:media_search_demo, :clip_embedding_dimension)

    pooled = Axon.nx(encoder, fn output -> output.pooled_state end)
    projected = Axon.dense(pooled, embedding_size, use_bias: false, name: "text_projection")

    # temporary workaround until Bumblebee bug is fixed: expose the result under :embedding
    wrapped = Axon.nx(projected, fn embedding -> %{embedding: embedding} end)

    # The appended "text_projection" layer needs parameters under the same name;
    # take them from the multimodal model's params.
    projection_weights = full_params["text_projection"]

    %{
      model: wrapped,
      params: put_in(encoder_params, ["text_projection"], projection_weights),
      spec: encoder_spec
    }
  end

  # Loads the CLIP checkpoint named by Constants with the given load options,
  # crashing on anything other than `{:ok, model_info}`.
  defp load_clip!(load_opts) do
    {:ok, model_info} = Bumblebee.load_model({:hf, Constants.clip_hf_model()}, load_opts)
    model_info
  end
end
61 changes: 61 additions & 0 deletions lib/media_search_demo/clip/servings/bumblebee/vision.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
defmodule MediaSearchDemo.Clip.Servings.Bumblebee.Vision do
  @moduledoc """
  CLIP image embeddings served through the image-embedding serving that ships with Bumblebee.

  The model handed to Bumblebee is the CLIP vision model with a `"visual_projection"` dense
  layer appended; that layer's parameters are copied from the multimodal CLIP checkpoint.
  """

  alias MediaSearchDemo.Clip.Servings.Constants

  # Name the Bumblebee image serving is registered under in the application's supervision
  # tree. The hand-rolled serving uses MediaSearchDemo.Clip.VisionServing, which is a
  # different process.
  @serving_name MediaSearchDemo.Clip.Bumblebee.VisionServing

  @doc """
  Builds the Bumblebee image-embedding serving for the projected CLIP vision model.

  Loads the featurizer for `Constants.clip_hf_model/0` and passes it, together with the
  projected model, to `Bumblebee.Vision.ImageEmbedding.image_embedding/3`.
  """
  @spec get_serving :: Nx.Serving.t()
  def get_serving() do
    model_info = clip_vision_embeddings_model()
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, Constants.clip_hf_model()})

    Bumblebee.Vision.ImageEmbedding.image_embedding(model_info, featurizer,
      output_attribute: nil,
      output_pool: nil,
      embedding_processor: nil
    )
  end

  @doc """
  Submits `image` as a single-element batch to the Bumblebee image serving
  (`MediaSearchDemo.Clip.Bumblebee.VisionServing`) and returns whatever
  `Nx.Serving.batched_run/2` returns.
  """
  def run_embeddings(image) do
    # Must target the Bumblebee serving, not the hand-rolled one, so this module
    # actually exercises the serving built by get_serving/0.
    Nx.Serving.batched_run(@serving_name, [image])
  end

  # CLIP model for vision embeddings: same as Bumblebee's Vision.ClipVision,
  # but with the final projection layer appended.
  defp clip_vision_embeddings_model() do
    # The multimodal checkpoint is loaded only for its projection weights.
    {:ok, %{model: _multimodal_model, params: multimodal_params, spec: _multimodal_spec}} =
      Bumblebee.load_model({:hf, Constants.clip_hf_model()},
        architecture: :base
      )

    {:ok, %{model: vision_model, params: vision_params, spec: vision_spec}} =
      Bumblebee.load_model({:hf, Constants.clip_hf_model()},
        module: Bumblebee.Vision.ClipVision,
        architecture: :base
      )

    dimension = Application.get_env(:media_search_demo, :clip_embedding_dimension)

    vision_model_with_projection_head =
      vision_model
      |> Axon.nx(& &1.pooled_state)
      |> Axon.dense(dimension, use_bias: false, name: "visual_projection")

    # extract the visual projection layer's params from the multimodal model's params
    visual_projection_params = multimodal_params["visual_projection"]

    # visual projection layer params that will be needed for the "visual_projection" layer we added
    params_with_visual_projection =
      put_in(vision_params["visual_projection"], visual_projection_params)

    # Same model-info shape as the models loaded above (:model, :params, :spec);
    # the spec must live under :spec, not :vision_spec.
    %{
      model: vision_model_with_projection_head,
      params: params_with_visual_projection,
      spec: vision_spec
    }
  end
end
8 changes: 8 additions & 0 deletions lib/media_search_demo/clip/servings/constants.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
defmodule MediaSearchDemo.Clip.Servings.Constants do
  @moduledoc """
  Fixed values shared by the CLIP servings: the Hugging Face repository to load,
  the text sequence length, and the batch size / batch timeout for each serving.
  """

  @clip_hf_model "openai/clip-vit-base-patch32"
  @sequence_length 42
  @clip_text_batch_size 10
  @clip_text_batch_timeout 100
  @clip_vision_batch_size 10
  @clip_vision_batch_timeout 100

  @doc "Hugging Face repository of the CLIP checkpoint."
  @spec clip_hf_model() :: String.t()
  def clip_hf_model, do: @clip_hf_model

  @doc "Sequence length used for the text serving."
  @spec sequence_length() :: pos_integer()
  def sequence_length, do: @sequence_length

  @doc "Batch size for the text servings."
  @spec clip_text_batch_size() :: pos_integer()
  def clip_text_batch_size, do: @clip_text_batch_size

  @doc "Batch timeout for the text servings (presumably milliseconds, per Nx.Serving — confirm)."
  @spec clip_text_batch_timeout() :: pos_integer()
  def clip_text_batch_timeout, do: @clip_text_batch_timeout

  @doc "Batch size for the image servings."
  @spec clip_vision_batch_size() :: pos_integer()
  def clip_vision_batch_size, do: @clip_vision_batch_size

  @doc "Batch timeout for the image servings (presumably milliseconds, per Nx.Serving — confirm)."
  @spec clip_vision_batch_timeout() :: pos_integer()
  def clip_vision_batch_timeout, do: @clip_vision_batch_timeout
end
Loading

0 comments on commit 7d68136

Please sign in to comment.