generated from rajrajhans/phoenix-elixir-nix-starter
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Nx Servings: custom implemented + via bumblebee
- Loading branch information
1 parent
d7bb830
commit 7d68136
Showing
10 changed files
with
443 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
defmodule MediaSearchDemo.Clip.Servings.Bumblebee.Text do
  @moduledoc """
  Bumblebee-backed serving for CLIP text embeddings.

  The model is Bumblebee's `Bumblebee.Text.ClipText` with an extra
  `"text_projection"` dense layer appended. The parameters of that layer are
  copied from the full CLIP checkpoint identified by `Constants.clip_hf_model/0`.
  """

  alias MediaSearchDemo.Clip.Servings.Constants

  @doc """
  Builds the text embedding serving.

  Loads the tokenizer and the projection-augmented text model, then hands both
  to `Bumblebee.Text.TextEmbedding.text_embedding/3`.

  Raises `MatchError` if the tokenizer or a model fails to load, and
  `ArgumentError` if `:clip_embedding_dimension` is not configured for
  `:media_search_demo`.
  """
  @spec get_serving :: Nx.Serving.t()
  def get_serving() do
    model_info = clip_text_embeddings_model()
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, Constants.clip_hf_model()})

    Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
      # The model below wraps its output as `%{embedding: tensor}`.
      output_attribute: :embedding,
      output_pool: nil,
      sequence_length: Constants.sequence_length(),
      embedding_processor: nil
    )
  end

  @doc """
  Runs `text` through the serving registered as
  `MediaSearchDemo.Clip.Bumblebee.TextServing`.

  The input is submitted as a single-element batch (`[text]`) via
  `Nx.Serving.batched_run/2`, and that call's result is returned unchanged.
  """
  def run_embeddings(text) do
    serving = MediaSearchDemo.Clip.Bumblebee.TextServing

    Nx.Serving.batched_run(
      serving,
      [text]
    )
  end

  # Builds the model info map (`:model`, `:params`, `:spec`) for CLIP text
  # embeddings: the same model as Bumblebee's `Text.ClipText`, but including
  # the final projection layer.
  defp clip_text_embeddings_model() do
    repository = {:hf, Constants.clip_hf_model()}

    # The full (multimodal) checkpoint is loaded only for its parameters, which
    # contain the "text_projection" weights.
    {:ok, %{model: _multimodal_model, params: multimodal_params, spec: _multimodal_spec}} =
      Bumblebee.load_model(repository, architecture: :base)

    {:ok, %{model: text_model, params: text_params, spec: text_spec}} =
      Bumblebee.load_model(repository,
        module: Bumblebee.Text.ClipText,
        architecture: :base
      )

    # Fail here with a clear error when the config key is missing, instead of
    # passing `nil` on to `Axon.dense/3`.
    dimension = Application.fetch_env!(:media_search_demo, :clip_embedding_dimension)

    text_model_with_projection_head =
      text_model
      |> Axon.nx(& &1.pooled_state)
      |> Axon.dense(dimension, use_bias: false, name: "text_projection")
      # temporary workaround until Bumblebee bug is fixed
      |> Axon.nx(fn x -> %{embedding: x} end)

    # extract the text projection layer's params from the multimodal model's params
    text_projection_params = multimodal_params["text_projection"]

    # text projection layer params that will be needed for the "text_projection" layer we added
    text_params_with_text_projection =
      put_in(text_params["text_projection"], text_projection_params)

    %{
      model: text_model_with_projection_head,
      params: text_params_with_text_projection,
      spec: text_spec
    }
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
defmodule MediaSearchDemo.Clip.Servings.Bumblebee.Vision do
  @moduledoc """
  Bumblebee-backed serving for CLIP image embeddings.

  The model is Bumblebee's `Bumblebee.Vision.ClipVision` with an extra
  `"visual_projection"` dense layer appended. The parameters of that layer are
  copied from the full CLIP checkpoint identified by `Constants.clip_hf_model/0`.
  """

  alias MediaSearchDemo.Clip.Servings.Constants

  @doc """
  Builds the image embedding serving.

  Loads the featurizer and the projection-augmented vision model, then hands
  both to `Bumblebee.Vision.ImageEmbedding.image_embedding/3`.

  Raises `MatchError` if the featurizer or a model fails to load, and
  `ArgumentError` if `:clip_embedding_dimension` is not configured for
  `:media_search_demo`.
  """
  @spec get_serving :: Nx.Serving.t()
  def get_serving() do
    model_info = clip_vision_embeddings_model()
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, Constants.clip_hf_model()})

    # NOTE(review): confirm that `:output_pool` is an accepted option of
    # `image_embedding/3` in the pinned Bumblebee version.
    Bumblebee.Vision.ImageEmbedding.image_embedding(model_info, featurizer,
      output_attribute: nil,
      output_pool: nil,
      embedding_processor: nil
    )
  end

  @doc """
  Runs `image` through the serving registered as
  `MediaSearchDemo.Clip.VisionServing`.

  The input is submitted as a single-element batch (`[image]`) via
  `Nx.Serving.batched_run/2`, and that call's result is returned unchanged.
  """
  def run_embeddings(image) do
    serving = MediaSearchDemo.Clip.VisionServing

    Nx.Serving.batched_run(
      serving,
      [image]
    )
  end

  # Builds the model info map (`:model`, `:params`, `:spec`) for CLIP vision
  # embeddings: the same model as Bumblebee's `Vision.ClipVision`, but including
  # the final projection layer.
  defp clip_vision_embeddings_model() do
    repository = {:hf, Constants.clip_hf_model()}

    # The full (multimodal) checkpoint is loaded only for its parameters, which
    # contain the "visual_projection" weights.
    {:ok, %{model: _multimodal_model, params: multimodal_params, spec: _multimodal_spec}} =
      Bumblebee.load_model(repository, architecture: :base)

    {:ok, %{model: vision_model, params: vision_params, spec: vision_spec}} =
      Bumblebee.load_model(repository,
        module: Bumblebee.Vision.ClipVision,
        architecture: :base
      )

    # Fail here with a clear error when the config key is missing, instead of
    # passing `nil` on to `Axon.dense/3`.
    dimension = Application.fetch_env!(:media_search_demo, :clip_embedding_dimension)

    vision_model_with_projection_head =
      vision_model
      |> Axon.nx(& &1.pooled_state)
      |> Axon.dense(dimension, use_bias: false, name: "visual_projection")

    # extract the visual projection layer's params from the multimodal model's params
    visual_projection_params = multimodal_params["visual_projection"]

    # visual projection layer params that will be needed for the "visual_projection" layer we added
    params_with_visual_projection =
      put_in(vision_params["visual_projection"], visual_projection_params)

    # The spec goes under `:spec` (not `:vision_spec`), the same key the text
    # serving module uses for its model info map.
    %{
      model: vision_model_with_projection_head,
      params: params_with_visual_projection,
      spec: vision_spec
    }
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
defmodule MediaSearchDemo.Clip.Servings.Constants do
  @moduledoc """
  Fixed values shared by the CLIP servings.
  """

  @clip_hf_model "openai/clip-vit-base-patch32"
  @sequence_length 42
  @clip_text_batch_size 10
  @clip_text_batch_timeout 100
  @clip_vision_batch_size 10
  @clip_vision_batch_timeout 100

  @doc "Repository id of the CLIP checkpoint, used in `{:hf, id}` tuples."
  @spec clip_hf_model() :: String.t()
  def clip_hf_model() do
    @clip_hf_model
  end

  @doc "Value passed as the `:sequence_length` option of the text serving."
  @spec sequence_length() :: pos_integer()
  def sequence_length() do
    @sequence_length
  end

  @doc "Batch size for the text serving."
  @spec clip_text_batch_size() :: pos_integer()
  def clip_text_batch_size() do
    @clip_text_batch_size
  end

  # NOTE(review): the unit is presumably milliseconds; confirm at the call site.
  @doc "Batch timeout for the text serving."
  @spec clip_text_batch_timeout() :: pos_integer()
  def clip_text_batch_timeout() do
    @clip_text_batch_timeout
  end

  @doc "Batch size for the vision serving."
  @spec clip_vision_batch_size() :: pos_integer()
  def clip_vision_batch_size() do
    @clip_vision_batch_size
  end

  # NOTE(review): the unit is presumably milliseconds; confirm at the call site.
  @doc "Batch timeout for the vision serving."
  @spec clip_vision_batch_timeout() :: pos_integer()
  def clip_vision_batch_timeout() do
    @clip_vision_batch_timeout
  end
end
Oops, something went wrong.