fix: improve ollama workflow from CI (#53)
Co-authored-by: Michael Neale <[email protected]>
Co-authored-by: Adrian Cole <[email protected]>
3 people authored Sep 25, 2024
1 parent 5b34bc5 commit 9feb7ec
Showing 2 changed files with 20 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/exchange/exchange.py
@@ -3,7 +3,7 @@
 from copy import deepcopy
 from typing import Any, Dict, List, Mapping, Tuple
 
-from attrs import define, evolve, field
+from attrs import define, evolve, field, Factory
 from tiktoken import get_encoding
 
 from exchange.checkpoint import Checkpoint, CheckpointData
@@ -44,6 +44,7 @@ class Exchange:
     tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
     messages: List[Message] = field(factory=list)
     checkpoint_data: CheckpointData = field(factory=CheckpointData)
+    generation_args: dict = field(default=Factory(dict))
 
     @property
     def _toolmap(self) -> Mapping[str, Tool]:
@@ -77,6 +78,7 @@ def generate(self) -> Message:
             self.system,
             messages=self.messages,
             tools=self.tools,
+            **self.generation_args,
         )
         self.add(message)
         self.add_checkpoints_from_usage(usage)  # this has to come after adding the response
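
Read end to end, the exchange.py change wires caller-supplied kwargs through to the provider: the new generation_args field defaults to an empty dict and is splatted into the provider's completion call. Below is a minimal runnable sketch of that pattern; FakeProvider and MiniExchange are illustrative stand-ins, not the real exchange API.

from attrs import define, field, Factory


@define
class FakeProvider:
    def complete(self, model: str, system: str, **kwargs) -> str:
        # A real provider would forward seed/temperature to its API client.
        return f"completed {model} with {kwargs}"


@define
class MiniExchange:
    provider: FakeProvider
    model: str
    system: str
    # Factory(dict) gives each instance its own empty dict, avoiding the
    # shared-mutable-default pitfall of a bare {} default.
    generation_args: dict = field(default=Factory(dict))

    def generate(self) -> str:
        return self.provider.complete(
            self.model,
            self.system,
            **self.generation_args,
        )


ex = MiniExchange(
    provider=FakeProvider(),
    model="llama3",
    system="You are a helpful assistant.",
    generation_args=dict(seed=3, temperature=0.1),
)
print(ex.generate())  # completed llama3 with {'seed': 3, 'temperature': 0.1}

Because the field defaults to an empty dict, existing callers that never pass generation_args see no behavior change.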
29 changes: 17 additions & 12 deletions tests/test_integration.py
@@ -10,23 +10,25 @@
 too_long_chars = "x" * (2**20 + 1)
 
 cases = [
-    (get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)),
-    (get_provider("openai"), "gpt-4o-mini"),
-    (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct"),
-    (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0"),
+    # Set seed and temperature for more determinism, to avoid flakes
+    (get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)),
+    (get_provider("openai"), "gpt-4o-mini", dict()),
+    (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()),
+    (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()),
 ]
 
 
 @pytest.mark.integration  # skipped in CI/CD
-@pytest.mark.parametrize("provider,model", cases)
-def test_simple(provider, model):
+@pytest.mark.parametrize("provider,model,kwargs", cases)
+def test_simple(provider, model, kwargs):
     provider = provider.from_env()
 
     ex = Exchange(
         provider=provider,
         model=model,
         moderator=ContextTruncate(model),
         system="You are a helpful assistant.",
+        generation_args=kwargs,
     )
 
     ex.add(Message.user("Who is the most famous wizard from the lord of the rings"))
@@ -38,8 +40,8 @@ def test_simple(provider, model):
 
 
 @pytest.mark.integration  # skipped in CI/CD
-@pytest.mark.parametrize("provider,model", cases)
-def test_tools(provider, model, tmp_path):
+@pytest.mark.parametrize("provider,model,kwargs", cases)
+def test_tools(provider, model, kwargs, tmp_path):
     provider = provider.from_env()
 
     def read_file(filename: str) -> str:
@@ -48,8 +50,8 @@ def read_file(filename: str) -> str:
 
         Args:
             filename (str): The path to the file, which can be relative or
-            absolute. If it is a plain filename, it is assumed to be in the
-            current working directory.
+                absolute. If it is a plain filename, it is assumed to be in the
+                current working directory.
 
         Returns:
             str: The contents of the file.
@@ -60,8 +62,10 @@ def read_file(filename: str) -> str:
     ex = Exchange(
         provider=provider,
         model=model,
+        moderator=ContextTruncate(model),
         system="You are a helpful assistant. Expect to need to read a file using read_file.",
         tools=(Tool.from_function(read_file),),
+        generation_args=kwargs,
     )
 
     ex.add(Message.user("What are the contents of this file? test.txt"))
@@ -72,8 +76,8 @@ def read_file(filename: str) -> str:
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("provider,model", cases)
-def test_tool_use_output_chars(provider, model):
+@pytest.mark.parametrize("provider,model,kwargs", cases)
+def test_tool_use_output_chars(provider, model, kwargs):
     provider = provider.from_env()
 
     def get_password() -> str:
@@ -86,6 +90,7 @@ def get_password() -> str:
         moderator=ContextTruncate(model),
         system="You are a helpful assistant. Expect to need to authenticate using get_password.",
         tools=(Tool.from_function(get_password),),
+        generation_args=kwargs,
     )
 
     ex.add(Message.user("Can you authenticate this session by responding with the password"))
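
The test changes follow the same shape: each parametrized case grew a third element, a per-provider kwargs dict, and the test signatures gained a matching kwargs parameter that is handed to Exchange as generation_args. Here is a reduced sketch of the pattern; the provider strings are placeholders, since the real suite builds providers with get_provider(...).

import pytest

cases = [
    # Only the Ollama case pins seed/temperature, to reduce flaky runs.
    ("ollama", "llama3", dict(seed=3, temperature=0.1)),
    ("openai", "gpt-4o-mini", dict()),
]


@pytest.mark.parametrize("provider,model,kwargs", cases)
def test_case_shape(provider, model, kwargs):
    # Every case supplies a dict (possibly empty) so the shared
    # provider,model,kwargs signature works uniformly across providers.
    assert isinstance(kwargs, dict)

Passing dict() for the hosted providers keeps their behavior unchanged while letting the Ollama case run more deterministically in CI.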
