Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OllamaLLM and OllamaEmbeddings classes #231

Merged
merged 26 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e9712a9
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 15, 2024
b52c45e
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 16, 2024
84c1780
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 17, 2024
47d4782
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 21, 2024
bc7a2f9
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 22, 2024
a945284
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 22, 2024
4e13c23
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 23, 2024
5367bed
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 24, 2024
21d1223
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 25, 2024
3329cd7
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 25, 2024
d8f6364
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Oct 28, 2024
4cec2f3
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 4, 2024
4445b49
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 5, 2024
939b18c
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 18, 2024
1104519
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 22, 2024
1893b85
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 25, 2024
6e4ebda
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Nov 28, 2024
8db7f01
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Dec 9, 2024
d601268
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Dec 10, 2024
3b00587
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python
stellasia Dec 11, 2024
e6d939d
Add OllamaLLM and OllamaEmbeddings classes using the ollama python cl…
stellasia Dec 11, 2024
d3acad9
Try removing import
stellasia Dec 11, 2024
fa44e57
:(
stellasia Dec 11, 2024
63d0ad6
Add tests + reformat import in ollama embeddings for consistency with…
stellasia Dec 11, 2024
35e2d39
Merge branch 'main' of https://github.com/neo4j/neo4j-graphrag-python…
stellasia Dec 12, 2024
41b051e
Fix after merge
stellasia Dec 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
## Added
- Integrated json-repair package to handle and repair invalid JSON generated by LLMs.
- Introduced InvalidJSONError exception for handling cases where JSON repair fails.
- Added `OllamaLLM` and `OllamaEmbeddings` classes to make Ollama support more explicit. Implementations using the `OpenAILLM` and `OpenAIEmbeddings` classes will still work.

## Changed
- Updated LLM prompts to include stricter instructions for generating valid JSON.
Expand Down
12 changes: 12 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ AzureOpenAIEmbeddings
.. autoclass:: neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings
:members:

OllamaEmbeddings
================

.. autoclass:: neo4j_graphrag.embeddings.ollama.OllamaEmbeddings
:members:

VertexAIEmbeddings
==================

Expand Down Expand Up @@ -257,6 +263,12 @@ AzureOpenAILLM
:members:
:undoc-members: get_messages, client_class, async_client_class

OllamaLLM
---------

.. autoclass:: neo4j_graphrag.llm.ollama_llm.OllamaLLM
:members:


VertexAILLM
-----------
Expand Down
34 changes: 5 additions & 29 deletions docs/source/user_guide_rag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,13 @@ See :ref:`coherellm`.
Using a Local Model via Ollama
-------------------------------

Similarly to the official OpenAI Python client, the `OpenAILLM` can be
used with Ollama. Assuming Ollama is running on the default address `127.0.0.1:11434`,
Assuming Ollama is running on the default address `127.0.0.1:11434`,
it can be queried using the following:

.. code:: python

from neo4j_graphrag.llm import OpenAILLM
llm = OpenAILLM(api_key="ollama", base_url="http://127.0.0.1:11434/v1", model_name="orca-mini")
from neo4j_graphrag.llm import OllamaLLM
llm = OllamaLLM(model_name="orca-mini")
llm.invoke("say something")


Expand Down Expand Up @@ -428,6 +427,7 @@ Currently, this package supports the following embedders:
- :ref:`mistralaiembeddings`
- :ref:`cohereembeddings`
- :ref:`azureopenaiembeddings`
- :ref:`ollamaembeddings`

The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:

Expand All @@ -438,31 +438,7 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente
embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2") # Note: this is the default model


If another embedder is desired, a custom embedder can be created. For example, consider
the following implementation of an embedder that wraps the `OllamaEmbedding` model from LlamaIndex:

.. code:: python

from llama_index.embeddings.ollama import OllamaEmbedding
from neo4j_graphrag.embeddings.base import Embedder

class OllamaEmbedder(Embedder):
def __init__(self, ollama_embedding):
self.embedder = ollama_embedding

def embed_query(self, text: str) -> list[float]:
embedding = self.embedder.get_text_embedding_batch(
[text], show_progress=True
)
return embedding[0]

ollama_embedding = OllamaEmbedding(
model_name="llama3",
base_url="http://localhost:11434",
ollama_additional_kwargs={"mirostat": 0},
)
embedder = OllamaEmbedder(ollama_embedding)
vector = embedder.embed_query("some text")
If another embedder is desired, a custom embedder can be created, using the `Embedder` interface.


Other Vector Retriever Configuration
Expand Down
11 changes: 3 additions & 8 deletions examples/customize/embeddings/ollama_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
"""This example demonstrate how to embed a text into a vector
using OpenAI models and API.
using a local model served by Ollama.
"""

from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.embeddings import OllamaEmbeddings

# not used but needs to be provided
api_key = "ollama"

embeder = OpenAIEmbeddings(
base_url="http://localhost:11434/v1",
api_key=api_key,
embeder = OllamaEmbeddings(
model="<model_name>",
)
res = embeder.embed_query("my question")
Expand Down
11 changes: 5 additions & 6 deletions examples/customize/llms/ollama_llm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from neo4j_graphrag.llm import LLMResponse, OpenAILLM
"""This example demonstrate how to invoke an LLM using a local model
served by Ollama.
"""

# not used but needs to be provided
api_key = "ollama"
from neo4j_graphrag.llm import LLMResponse, OllamaLLM

llm = OpenAILLM(
base_url="http://localhost:11434/v1",
llm = OllamaLLM(
model_name="<model_name>",
api_key=api_key,
)
res: LLMResponse = llm.invoke("What is the additive color model?")
print(res.content)
40 changes: 28 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ openai = {version = "^1.51.1", optional = true }
anthropic = { version = "^0.36.0", optional = true}
sentence-transformers = {version = "^3.0.0", optional = true }
json-repair = "^0.30.2"
ollama = {version = "^0.4.4", optional = true}

[tool.poetry.group.dev.dependencies]
urllib3 = "<2"
Expand All @@ -68,6 +69,7 @@ pinecone = ["pinecone-client"]
google = ["google-cloud-aiplatform"]
cohere = ["cohere"]
anthropic = ["anthropic"]
ollama = ["ollama"]
openai = ["openai"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
Expand Down
2 changes: 2 additions & 0 deletions src/neo4j_graphrag/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from .base import Embedder
from .cohere import CohereEmbeddings
from .mistral import MistralAIEmbeddings
from .ollama import OllamaEmbeddings
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from .sentence_transformers import SentenceTransformerEmbeddings
from .vertexai import VertexAIEmbeddings

__all__ = [
"Embedder",
"SentenceTransformerEmbeddings",
"OllamaEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"VertexAIEmbeddings",
Expand Down
3 changes: 1 addition & 2 deletions src/neo4j_graphrag/embeddings/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
**kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
"""
embeddings_batch_response = self.mistral_client.embeddings.create(
model=self.model,
inputs=[text],
model=self.model, inputs=[text], **kwargs
)
if embeddings_batch_response is None or not embeddings_batch_response.data:
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
Expand Down
65 changes: 65 additions & 0 deletions src/neo4j_graphrag/embeddings/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import EmbeddingsGenerationError


class OllamaEmbeddings(Embedder):
"""
Ollama embeddings class.
This class uses the ollama Python client to generate vector embeddings for text data.

Args:
model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed".
"""

def __init__(self, model: str, **kwargs: Any) -> None:
try:
import ollama
except ImportError:
raise ImportError(
"Could not import ollama python client. "
"Please install it with `pip install ollama`."
)
self.model = model
self.client = ollama.Client(**kwargs)

def embed_query(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate embeddings for a given query using an Ollama text embedding model.

Args:
text (str): The text to generate an embedding for.
**kwargs (Any): Additional keyword arguments to pass to the Ollama client.
"""
embeddings_response = self.client.embed(
model=self.model,
input=text,
**kwargs,
)

if embeddings_response is None or embeddings_response.embeddings is None:
raise EmbeddingsGenerationError("Failed to retrieve embeddings.")

embedding = embeddings_response.embeddings
if not isinstance(embedding, list):
raise EmbeddingsGenerationError("Embedding is not a list of floats.")

return embedding
2 changes: 2 additions & 0 deletions src/neo4j_graphrag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .base import LLMInterface
from .cohere_llm import CohereLLM
from .mistralai_llm import MistralAILLM
from .ollama_llm import OllamaLLM
from .openai_llm import AzureOpenAILLM, OpenAILLM
from .types import LLMResponse
from .vertexai_llm import VertexAILLM
Expand All @@ -25,6 +26,7 @@
"CohereLLM",
"LLMResponse",
"LLMInterface",
"OllamaLLM",
"OpenAILLM",
"VertexAILLM",
"AzureOpenAILLM",
Expand Down
Loading
Loading