Skip to content

Commit

Permalink
Fix ut (#85)
Browse files Browse the repository at this point in the history
Signed-off-by: Jael Gu <[email protected]>
  • Loading branch information
jaelgu authored Oct 23, 2023
1 parent 0f4893c commit cbfde4d
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 37 deletions.
1 change: 0 additions & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ jobs:
pip install coverage
pip install pytest
pip install -r requirements.txt
pip install -r test_requirements.txt
- name: Install test dependency
shell: bash
working-directory: tests
Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

################## LLM ##################
LLM_OPTION = os.getenv('LLM_OPTION', 'openai') # select your LLM service
LANGUAGE = os.getenv('LANGUAGE', 'en') # options: en, zh
LANGUAGE = os.getenv('DOC_LANGUAGE', 'en') # options: en, zh
CHAT_CONFIG = {
'openai': {
'openai_model': 'gpt-3.5-turbo',
Expand Down
26 changes: 23 additions & 3 deletions src_langchain/llm/ernie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from config import CHAT_CONFIG # pylint: disable=C0413
from typing import Any, List, Dict
from typing import Any, List, Dict, Optional
import os
import sys

Expand All @@ -9,6 +8,9 @@

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from config import CHAT_CONFIG # pylint: disable=C0413


CHAT_CONFIG = CHAT_CONFIG['ernie']
llm_kwargs = CHAT_CONFIG.get('llm_kwargs', {})

Expand All @@ -20,7 +22,25 @@ class ChatLLM(BaseChatModel):
eb_access_token: str = CHAT_CONFIG['eb_access_token'] or os.getenv('EB_ACCESS_TOKEN')
llm_kwargs: dict = llm_kwargs

def _generate(self, messages: List[BaseMessage]) -> ChatResult:
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult: # pylint: disable=W0613
import erniebot # pylint: disable=C0415
erniebot.api_type = self.eb_api_type
erniebot.access_token = self.eb_access_token

message_dicts = self._create_message_dicts(messages)

response = erniebot.ChatCompletion.create(
model=self.model_name,
messages=message_dicts,
**self.llm_kwargs
)
return self._create_chat_result(response)

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None # pylint: disable=W0613
) -> ChatResult:
import erniebot # pylint: disable=C0415
erniebot.api_type = self.eb_api_type
erniebot.access_token = self.eb_access_token
Expand Down
15 changes: 0 additions & 15 deletions test_requirements.txt

This file was deleted.

3 changes: 2 additions & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ milvus
transformers
dashscope
zhipuai
sentence-transformers
sentence-transformers
erniebot
2 changes: 1 addition & 1 deletion tests/unit_tests/src_langchain/llm/test_ernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest.mock import patch
from langchain.schema import HumanMessage, AIMessage

sys.path.append(os.path.join(os.path.dirname(__file__), '../../../../..'))
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))


class TestERNIE(unittest.TestCase):
Expand Down
44 changes: 29 additions & 15 deletions tests/unit_tests/src_towhee/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def create_pipelines(llm_src):

class TestPipelines(unittest.TestCase):
project = 'akcio_ut'
data_src = 'https://towhee.io'
data_src = 'akcio_ut.txt'
question = 'test question'

@classmethod
def setUpClass(cls):
with open(cls.data_src, 'w+', encoding='utf-8') as tmp_f:
tmp_f.write('This is test content.')
milvus_server.cleanup()
milvus_server.start()

Expand All @@ -84,7 +86,7 @@ def test_openai(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -121,7 +123,7 @@ def test_chatglm(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand All @@ -137,13 +139,24 @@ def test_chatglm(self):

def test_ernie(self):

class MockRequest:
def json(self):
return {'result': MOCK_ANSWER}

with patch('requests.request') as mock_llm:

mock_llm.return_value = MockRequest()
from erniebot.response import EBResponse

with patch('erniebot.ChatCompletion.create') as mock_post:
mock_res = EBResponse(code=200,
body={'id': 'as-0000000000', 'object': 'chat.completion', 'created': 11111111,
'result': MOCK_ANSWER,
'usage': {'prompt_tokens': 1, 'completion_tokens': 13, 'total_tokens': 14},
'need_clear_history': False, 'is_truncated': False},
headers={'Connection': 'keep-alive',
'Content-Security-Policy': 'frame-ancestors https://*.baidu.com/',
'Content-Type': 'application/json', 'Date': 'Mon, 23 Oct 2023 03:30:53 GMT',
'Server': 'nginx', 'Statement': 'AI-generated',
'Vary': 'Origin, Access-Control-Request-Method, Access-Control-Request-Headers',
'X-Frame-Options': 'allow-from https://*.baidu.com/',
'X-Request-Id': '0' * 32,
'Transfer-Encoding': 'chunked'}
)
mock_post.return_value = mock_res

pipelines = create_pipelines('ernie')

Expand All @@ -160,7 +173,7 @@ def json(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -205,7 +218,7 @@ def output(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -244,7 +257,7 @@ def json(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -285,7 +298,7 @@ def iter_lines(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -323,7 +336,7 @@ def __call__(self, *args, **kwargs):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 261
assert token_count == 5
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand All @@ -339,6 +352,7 @@ def __call__(self, *args, **kwargs):

@classmethod
def tearDownClass(cls):
os.remove(cls.data_src)
milvus_server.stop()
milvus_server.cleanup()

Expand Down

0 comments on commit cbfde4d

Please sign in to comment.