From 10c3e2a4da236e915bfe39d15aa86dfb92030fef Mon Sep 17 00:00:00 2001 From: zhangqiang Date: Fri, 27 Sep 2024 16:47:48 +0800 Subject: [PATCH] add mock as decorator --- src/ell/__init__.py | 1 + src/ell/lmp/__init__.py | 1 + src/ell/lmp/simple.py | 11 ++++++- src/ell/providers/__init__.py | 1 + src/ell/providers/mock.py | 62 +++++++++++++++++++++++++++++++++++ 5 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 src/ell/providers/mock.py diff --git a/src/ell/__init__.py b/src/ell/__init__.py index e7bd022a..cf1f7c35 100644 --- a/src/ell/__init__.py +++ b/src/ell/__init__.py @@ -5,6 +5,7 @@ from ell.lmp.simple import simple +from ell.lmp.simple import mock from ell.lmp.tool import tool from ell.lmp.complex import complex from ell.types.message import system, user, assistant, Message, ContentBlock diff --git a/src/ell/lmp/__init__.py b/src/ell/lmp/__init__.py index 222958bc..f5c641f7 100644 --- a/src/ell/lmp/__init__.py +++ b/src/ell/lmp/__init__.py @@ -1,2 +1,3 @@ from ell.lmp.simple import simple +from ell.lmp.simple import mock from ell.lmp.complex import complex \ No newline at end of file diff --git a/src/ell/lmp/simple.py b/src/ell/lmp/simple.py index 29320db1..20921ad6 100644 --- a/src/ell/lmp/simple.py +++ b/src/ell/lmp/simple.py @@ -1,7 +1,16 @@ from functools import wraps -from typing import Any, Optional +from typing import Any, Optional, Callable from ell.lmp.complex import complex +from ell.providers.mock import MockAIClient + + +def mock(model: str, client: Optional[Any] = None, exempt_from_tracking=False, mock_func:Callable[..., Any]=None, **api_params): + """Mock decortoar should accept everything passed to simple""" + if mock_func: + api_params['mock_func'] = mock_func + + return simple(model, client=MockAIClient(), exempt_from_tracking=exempt_from_tracking, **api_params) def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params): diff --git a/src/ell/providers/__init__.py b/src/ell/providers/__init__.py index 0f24e7e2..34b3da86 100644 --- a/src/ell/providers/__init__.py +++ b/src/ell/providers/__init__.py @@ -1,6 +1,7 @@ import ell.providers.openai import ell.providers.groq import ell.providers.anthropic +import ell.providers.mock # import ell.providers.mistral # import ell.providers.cohere # import ell.providers.gemini diff --git a/src/ell/providers/mock.py b/src/ell/providers/mock.py new file mode 100644 index 00000000..5d7acb4b --- /dev/null +++ b/src/ell/providers/mock.py @@ -0,0 +1,62 @@ +import random +import string + +from typing import Optional, Dict, Any, List, Type, Tuple +from ell.provider import Provider, EllCallParams, Metadata +from ell.types import Message +from ell.types.message import LMP +from ell.configurator import config, register_provider +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast + + +class MockAIClient: + + def __init__(self, **kwargs): + self.api_key = "mock" + + def chat_completions_create(self, **kwargs): + return None + + +class MockAIProvider(Provider): + dangerous_disable_validation = True + + def provider_call_function( + self, client: MockAIClient, api_call_params: Optional[Dict[str, Any]] = None + ) -> Callable[..., Any]: + return client.chat_completions_create + + def translate_to_provider(self, ell_call: EllCallParams) -> Dict[str, Any]: + return ell_call.api_params.copy() + + def default_mock_func(self) -> Tuple[List[Message], Metadata]: + results = [] + random_str = "".join( + random.choices( + string.ascii_letters + string.digits, k=random.randint(1, 40) + ) + ) + results.append( + Message( + role=("user"), + content="mock_" + random_str, + ) + ) + return results, Metadata + + def translate_from_provider( + self, + _provider_response: Any, + _ell_call: EllCallParams, + provider_call_params: Dict[str, Any], + _origin_id: Optional[str] = None, + _logger: Optional[Callable[..., None]] = None, + ) -> Tuple[List[Message], Metadata]: + if "mock_func" in provider_call_params: + mock_func = provider_call_params["mock_func"] + return [Message(role=("user"), content=mock_func())], Metadata + else: + return self.default_mock_func() + + +register_provider(MockAIProvider(), MockAIClient)