Skip to content

Commit

Permalink
Update config (#82)
Browse files · Browse the repository at this point in the history
Signed-off-by: Jael Gu <[email protected]>
Authored by jaelgu on Sep 26, 2023
1 parent: ce5c97f · commit: a7a4a21
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
19 changes: 13 additions & 6 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

################## LLM ##################
LLM_OPTION = os.getenv('LLM_OPTION', 'openai') # select your LLM service
LANGUAGE = 'en' # options: en, zh
LANGUAGE = os.getenv('LANGUAGE', 'en') # options: en, zh
CHAT_CONFIG = {
'openai': {
'openai_model': 'gpt-3.5-turbo',
Expand Down Expand Up @@ -57,7 +57,7 @@

################## Embedding ##################
TEXTENCODER_CONFIG = {
'model': 'multi-qa-mpnet-base-cos-v1',
'model': f'BAAI/bge-base-{LANGUAGE}',
'device': -1, # -1 will use cpu
'norm': True,
'dim': 768
Expand Down Expand Up @@ -108,10 +108,17 @@


############### Rerank configs ##################
if LANGUAGE == 'en' and INSERT_MODE == 'osschat-insert':
rerank_model = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
elif LANGUAGE == 'zh' and INSERT_MODE == 'osschat-insert':
rerank_model = 'amberoad/bert-multilingual-passage-reranking-msmarco'
else:
raise NotImplementedError

RERANK_CONFIG = {
'rerank': True,
'rerank_model': 'cross-encoder/ms-marco-MiniLM-L-12-v2',
'threshold': 0.6,
'rerank': True, # or False
'rerank_model': rerank_model,
'threshold': 0.0,
'rerank_device': -1 # -1 will use cpu
}

Expand All @@ -123,4 +130,4 @@
QUESTIONGENERATOR_CONFIG = {
'model_name': 'gpt-3.5-turbo',
'temperature': 0,
}
}
14 changes: 7 additions & 7 deletions tests/unit_tests/src_towhee/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_openai(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -121,7 +121,7 @@ def test_chatglm(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -160,7 +160,7 @@ def json(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -205,7 +205,7 @@ def output(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -244,7 +244,7 @@ def json(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -285,7 +285,7 @@ def iter_lines(self):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down Expand Up @@ -323,7 +323,7 @@ def __call__(self, *args, **kwargs):
token_count = 0
for x in res:
token_count += x[0]['token_count']
assert token_count == 290
assert token_count == 261
num = pipelines.count_entities(self.project)['vector store']
assert len(res) <= num

Expand Down

0 comments on commit a7a4a21

Please sign in to comment.