diff --git a/config.py b/config.py
index 2c3bf42..2f8e305 100644
--- a/config.py
+++ b/config.py
@@ -5,7 +5,7 @@

 ################## LLM ##################
 LLM_OPTION = os.getenv('LLM_OPTION', 'openai')  # select your LLM service
-LANGUAGE = 'en'  # options: en, zh
+LANGUAGE = os.getenv('LANGUAGE', 'en')  # options: en, zh
 CHAT_CONFIG = {
     'openai': {
         'openai_model': 'gpt-3.5-turbo',
@@ -57,7 +57,7 @@

 ################## Embedding ##################
 TEXTENCODER_CONFIG = {
-    'model': 'multi-qa-mpnet-base-cos-v1',
+    'model': f'BAAI/bge-base-{LANGUAGE}',
     'device': -1,  # -1 will use cpu
     'norm': True,
     'dim': 768
@@ -108,10 +108,17 @@

 ############### Rerank configs ##################
+if LANGUAGE == 'en' and INSERT_MODE == 'osschat-insert':
+    rerank_model = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
+elif LANGUAGE == 'zh' and INSERT_MODE == 'osschat-insert':
+    rerank_model = 'amberoad/bert-multilingual-passage-reranking-msmarco'
+else:
+    raise NotImplementedError
+
 RERANK_CONFIG = {
-    'rerank': True,
-    'rerank_model': 'cross-encoder/ms-marco-MiniLM-L-12-v2',
-    'threshold': 0.6,
+    'rerank': True,  # or False
+    'rerank_model': rerank_model,
+    'threshold': 0.0,
     'rerank_device': -1  # -1 will use cpu
 }

@@ -123,4 +130,4 @@
 QUESTIONGENERATOR_CONFIG = {
     'model_name': 'gpt-3.5-turbo',
     'temperature': 0,
-}
\ No newline at end of file
+}
diff --git a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py
index 082ed56..1f84e24 100644
--- a/tests/unit_tests/src_towhee/pipelines/test_pipelines.py
+++ b/tests/unit_tests/src_towhee/pipelines/test_pipelines.py
@@ -84,7 +84,7 @@ def test_openai(self):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
@@ -121,7 +121,7 @@ def test_chatglm(self):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
@@ -160,7 +160,7 @@ def json(self):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
@@ -205,7 +205,7 @@ def output(self):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
@@ -244,7 +244,7 @@ def json(self):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
@@ -285,7 +285,7 @@ def iter_lines(self):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
@@ -323,7 +323,7 @@ def __call__(self, *args, **kwargs):
         token_count = 0
         for x in res:
             token_count += x[0]['token_count']
-        assert token_count == 290
+        assert token_count == 261

         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num