Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mock client for fast prototyping code #176

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


from ell.lmp.simple import simple
from ell.lmp.simple import mock
from ell.lmp.tool import tool
from ell.lmp.complex import complex
from ell.types.message import system, user, assistant, Message, ContentBlock
Expand Down
1 change: 1 addition & 0 deletions src/ell/lmp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ell.lmp.simple import simple
from ell.lmp.simple import mock
from ell.lmp.complex import complex
11 changes: 10 additions & 1 deletion src/ell/lmp/simple.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from functools import wraps
from typing import Any, Optional
from typing import Any, Optional, Callable

from ell.lmp.complex import complex
from ell.providers.mock import MockAIClient


def mock(model: str, client: Optional[Any] = None, exempt_from_tracking=False, mock_func:Callable[..., Any]=None, **api_params):
"""Mock decortoar should accept everything passed to simple"""
if mock_func:
api_params['mock_func'] = mock_func

return simple(model, client=MockAIClient(), exempt_from_tracking=exempt_from_tracking, **api_params)


def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params):
Expand Down
1 change: 1 addition & 0 deletions src/ell/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ell.providers.openai
import ell.providers.groq
import ell.providers.anthropic
import ell.providers.mock
# import ell.providers.mistral
# import ell.providers.cohere
# import ell.providers.gemini
Expand Down
62 changes: 62 additions & 0 deletions src/ell/providers/mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random
import string

from typing import Optional, Dict, Any, List, Type, Tuple
from ell.provider import Provider, EllCallParams, Metadata
from ell.types import Message
from ell.types.message import LMP
from ell.configurator import config, register_provider
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast


class MockAIClient:

def __init__(self, **kwargs):
self.api_key = "mock"

def chat_completions_create(self, **kwargs):
return None


class MockAIProvider(Provider):
dangerous_disable_validation = True

def provider_call_function(
self, client: MockAIClient, api_call_params: Optional[Dict[str, Any]] = None
) -> Callable[..., Any]:
return client.chat_completions_create

def translate_to_provider(self, ell_call: EllCallParams) -> Dict[str, Any]:
return ell_call.api_params.copy()

def default_mock_func(self) -> Tuple[List[Message], Metadata]:
results = []
random_str = "".join(
random.choices(
string.ascii_letters + string.digits, k=random.randint(1, 40)
)
)
results.append(
Message(
role=("user"),
content="mock_" + random_str,
)
)
return results, Metadata

def translate_from_provider(
self,
_provider_response: Any,
_ell_call: EllCallParams,
provider_call_params: Dict[str, Any],
_origin_id: Optional[str] = None,
_logger: Optional[Callable[..., None]] = None,
) -> Tuple[List[Message], Metadata]:
if "mock_func" in provider_call_params:
mock_func = provider_call_params["mock_func"]
return [Message(role=("user"), content=mock_func())], Metadata
else:
return self.default_mock_func()


register_provider(MockAIProvider(), MockAIClient)