Skip to content

Commit

Permalink
[HuggingFace] Make hugging face embeddings work in docker run with mu…
Browse files Browse the repository at this point in the history
…ltiple pipelines (#721)
  • Loading branch information
eolivelli authored Nov 16, 2023
1 parent 54d8aef commit 3f3a84e
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 73 deletions.
15 changes: 1 addition & 14 deletions examples/applications/ollama-chatbot/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ In this example we are using [HerdDB](ps://github.com/diennea/herddb) as a vecto
but you can use any Vector databases.

As LLM we are using [Ollama](https://ollama.ai), that is a service that runs on your machine.
We are using OpenAI to compute the embeddings of the texts.
We are using Hugging Face to compute the embeddings of the texts.

## Install Ollama

Expand All @@ -27,19 +27,6 @@ If you want to use another model export this variable before starting the applic
export OLLAMA_MODEL=llama2:13b
```

## Configure you OpenAI API Key

At the moment it the embeddings computed by Ollama models are not performing well, so we are using OpenAI to compute them.

Export to the ENV the access key to OpenAI

```
export OPEN_AI_ACCESS_KEY=...
```

The default [secrets file](../../secrets/secrets.yaml) reads from the ENV. Check out the file to learn more about
the default settings, you can change them by exporting other ENV variables.

## Deploy the LangStream application in docker

The default docker runner starts Minio, Kafka and HerdDB, so you can run the application locally.
Expand Down
5 changes: 3 additions & 2 deletions examples/applications/ollama-chatbot/chatbot.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ pipeline:
- name: "compute-embeddings"
type: "compute-ai-embeddings"
configuration:
ai-service: "openai"
model: "${secrets.open-ai.embeddings-model}"
ai-service: "huggingface"
model: "${secrets.hugging-face.embeddings-model}" # This is the id of the model
model-url: "${secrets.hugging-face.embeddings-model-url}" # This is the URL of the repository containing the model
embeddings-field: "value.question_embeddings"
text: "{{ value.question }}"
flush-interval: 0
Expand Down
10 changes: 4 additions & 6 deletions examples/applications/ollama-chatbot/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ configuration:
id: "ollama"
configuration:
url: "${secrets.ollama.url}"
- type: "open-ai-configuration"
name: "OpenAI Azure configuration"
id: "openai"
- type: "hugging-face-configuration"
name: "Hugging Face AI configuration"
id: "huggingface"
configuration:
url: "${secrets.open-ai.url}"
access-key: "${secrets.open-ai.access-key}"
provider: "${secrets.open-ai.provider}"
provider: "local"
dependencies:
- name: "HerdDB.org JDBC Driver"
url: "https://repo1.maven.org/maven2/org/herddb/herddb-jdbc/0.28.0/herddb-jdbc-0.28.0-thin.jar"
Expand Down
9 changes: 5 additions & 4 deletions examples/applications/ollama-chatbot/crawler.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ pipeline:
type: "text-splitter"
configuration:
splitter_type: "RecursiveCharacterTextSplitter"
chunk_size: 400
chunk_size: 200
separators: ["\n\n", "\n", " ", ""]
keep_separator: false
chunk_overlap: 100
chunk_overlap: 20
length_function: "cl100k_base"
- name: "Convert to structured data"
type: "document-to-json"
Expand All @@ -76,8 +76,9 @@ pipeline:
id: "step1"
type: "compute-ai-embeddings"
configuration:
ai-service: "openai"
model: "${secrets.open-ai.embeddings-model}"
ai-service: "huggingface"
model: "${secrets.hugging-face.embeddings-model}" # This is the id of the model
model-url: "${secrets.hugging-face.embeddings-model-url}" # This is the URL of the repository containing the model
embeddings-field: "value.embeddings_vector"
text: "{{ value.text }}"
batch-size: 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,38 @@
package com.datastax.oss.streaming.ai.embeddings;

import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.pytorch.jni.LibUtils;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.lang.reflect.ParameterizedType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.ReentrantLock;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public abstract class AbstractHuggingFaceEmbeddingService<IN, OUT>
implements EmbeddingsService, AutoCloseable {

static {
log.info("Loading libtorch");
LibUtils.loadLibrary();
Engine.getEngine("PyTorch");
}

/**
* comma-separated list of allowed url prefixes, like
* file://,s3://,djl://,https://models.datastax.com/
Expand Down Expand Up @@ -90,6 +101,8 @@ public static class HuggingFaceConfig {
// http://djl.ai/docs/development/inference_performance_optimization.html#multithreading-support
ZooModel<IN, OUT> model;

private static final ReentrantLock localModelLock = new ReentrantLock();

private static final ThreadLocal<Predictor<?, ?>> predictorThreadLocal = new ThreadLocal<>();
private static final ConcurrentLinkedQueue<Predictor<?, ?>> predictorList =
new ConcurrentLinkedQueue<>();
Expand All @@ -98,7 +111,8 @@ public AbstractHuggingFaceEmbeddingService(HuggingFaceConfig conf)
throws IOException,
ModelNotFoundException,
MalformedModelException,
IllegalAccessException {
IllegalAccessException,
InterruptedException {
Objects.requireNonNull(conf);
Objects.requireNonNull(conf.modelName);

Expand Down Expand Up @@ -142,7 +156,19 @@ public AbstractHuggingFaceEmbeddingService(HuggingFaceConfig conf)

Criteria<IN, OUT> criteria = builder.build();

model = criteria.loadModel();
localModelLock.lockInterruptibly();
try {
model = criteria.loadModel();
} catch (ai.djl.engine.EngineException error) {
log.info("Classloader information: {}", getClass().getClassLoader());
log.info(
"Classloader information for Criteria: {}",
criteria.getClass().getClassLoader());
log.info("Context classloader: {}", Thread.currentThread().getContextClassLoader());
throw error;
} finally {
localModelLock.unlock();
}
}

private void checkIfUrlIsAllowed(String modelUrl) throws IllegalAccessException {
Expand All @@ -162,7 +188,22 @@ public List<OUT> compute(List<IN> texts) throws TranslateException {
predictorList.add(predictor);
}

return predictor.batchPredict(texts);
List<OUT> result = new ArrayList<>(texts.size());
// we are doing one text at a time, but we could do more
// batchPredict wants texts with the same size
for (IN in : texts) {
try {
result.add(predictor.predict(in));
} catch (TranslateException error) {
Throwable cause = error.getCause();
if (cause instanceof IllegalArgumentException err) {
throw new TranslateException(
"Illegal input, maybe the number of tokens is too high", error);
}
throw error;
}
}
return result;
}

abstract List<IN> convertInput(List<String> texts);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public HuggingFaceEmbeddingService(HuggingFaceConfig conf)
throws IOException,
ModelNotFoundException,
MalformedModelException,
IllegalAccessException {
IllegalAccessException,
InterruptedException {
super(conf);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,31 @@
*/
package com.datastax.oss.streaming.ai.embeddings;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.List;
import org.junit.jupiter.api.Disabled;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;

// disabled, just for experiments/usage demo
@Slf4j
public class HuggingFaceEmbeddingServiceTest {

@Disabled
public void testMain() throws Exception {
@Test
public void testEmbeddings() throws Exception {
AbstractHuggingFaceEmbeddingService.HuggingFaceConfig conf =
AbstractHuggingFaceEmbeddingService.HuggingFaceConfig.builder()
.engine("PyTorch")
.modelUrl(
"file:///Users/andreyyegorov/src/djl/model/nlp/text_embedding/ai/djl/huggingface/pytorch/sentence-transformers/all-MiniLM-L6-v2/0.0.1/all-MiniLM-L6-v2.zip")
.modelName("multilingual-e5-small")
.modelUrl("djl://ai.djl.huggingface.pytorch/intfloat/multilingual-e5-small")
.build();

try (EmbeddingsService service = new HuggingFaceEmbeddingService(conf)) {
List<List<Double>> result =
service.computeEmbeddings(List.of("hello world", "stranger things")).get();
result.forEach(System.out::println);

List<List<Double>> lists =
service.computeEmbeddings(List.of("Hello", "my friend")).get();
assertEquals(2, lists.size());
assertEquals(List.of(384), List.of(lists.get(0).size()));
assertEquals(List.of(384), List.of(lists.get(1).size()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,20 @@ public synchronized void unpack() throws Exception {
file.extractTo(dest);
directory = dest;
}

interface ClassloaderBuilder {
URLClassLoader apply(PackageMetadata metadata) throws Exception;
}

public synchronized URLClassLoader buildClassloader(ClassloaderBuilder builder)
throws Exception {
if (this.classLoader != null) {
return classLoader;
}
unpack();
this.classLoader = builder.apply(this);
return classLoader;
}
}

public synchronized void scan() throws Exception {
Expand Down Expand Up @@ -394,7 +408,6 @@ public List<? extends ClassLoader> getAllClassloaders() throws Exception {

classloaders = new ArrayList<>();
for (PackageMetadata metadata : packages.values()) {
metadata.unpack();
URLClassLoader result =
createClassloaderForPackage(customLibClasspath, metadata, parentClassloader);
classloaders.add(result);
Expand Down Expand Up @@ -422,46 +435,50 @@ public ClassLoader getSystemClassloader() {
}

private static URLClassLoader createClassloaderForPackage(
List<URL> customLibClasspath, PackageMetadata metadata, ClassLoader parentClassloader)
List<URL> customLibClasspath,
PackageMetadata packageMetadata,
ClassLoader parentClassloader)
throws Exception {

if (metadata.classLoader != null) {
return metadata.classLoader;
}

metadata.unpack();

log.debug("Creating classloader for package {}", metadata.name);
List<URL> urls = new ArrayList<>();

log.debug("Adding agents code {}", metadata.directory);
urls.add(metadata.directory.toFile().toURI().toURL());

Path metaInfDirectory = metadata.directory.resolve("META-INF");
if (Files.isDirectory(metaInfDirectory)) {

Path dependencies = metaInfDirectory.resolve("bundled-dependencies");
if (Files.isDirectory(dependencies)) {
try (DirectoryStream<Path> allFiles = Files.newDirectoryStream(dependencies)) {
for (Path file : allFiles) {
if (file.getFileName().toString().endsWith(".jar")) {
urls.add(file.toUri().toURL());
return packageMetadata.buildClassloader(
(metadata) -> {
log.info(
"Creating classloader for package {} id {}",
metadata.name,
System.identityHashCode(metadata));
List<URL> urls = new ArrayList<>();

log.debug("Adding agents code {}", metadata.directory);
urls.add(metadata.directory.toFile().toURI().toURL());

Path metaInfDirectory = metadata.directory.resolve("META-INF");
if (Files.isDirectory(metaInfDirectory)) {

Path dependencies = metaInfDirectory.resolve("bundled-dependencies");
if (Files.isDirectory(dependencies)) {
try (DirectoryStream<Path> allFiles =
Files.newDirectoryStream(dependencies)) {
for (Path file : allFiles) {
if (file.getFileName().toString().endsWith(".jar")) {
urls.add(file.toUri().toURL());
}
}
}
}
}
}
}
}

URLClassLoader result = new NarFileClassLoader(metadata.name, urls, parentClassloader);
URLClassLoader result =
new NarFileClassLoader(metadata.name, urls, parentClassloader);

if (!customLibClasspath.isEmpty()) {
result =
new NarFileClassLoader(
metadata.name + "+custom-lib", customLibClasspath, result);
}
if (!customLibClasspath.isEmpty()) {
result =
new NarFileClassLoader(
metadata.name + "+custom-lib", customLibClasspath, result);
}

metadata.classLoader = result;
metadata.classLoader = result;

return result;
return result;
});
}
}

0 comments on commit 3f3a84e

Please sign in to comment.