Text2Vec gives incorrect similarity results when loading BAAI_bge-large-zh-v1.5 #168

Open
sslalxh opened this issue Apr 1, 2024 · 2 comments


sslalxh commented Apr 1, 2024

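When asking a question, please provide as much of the following information as possible:

Basic information

  • Operating system: macOS 14.2.1
  • Python version: 3.8.16
  • PyTorch version: torch 2.1.0
  • bert4torch version: 0.4.9.post2
  • Pretrained model loaded: bge-large-zh-v1.5

Core code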

# Local model directory, as printed in the command-line output
root_model_path = 'model/BAAI_bge-large-zh-v1.5'

sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]
print(root_model_path)

print('=========================================sentence transformer====================================')
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(root_model_path)
embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
# Embeddings are normalized, so the dot product is the cosine similarity
similarity = embeddings_1 @ embeddings_2.T
print(similarity)

print('=========================================bert4torch====================================')
from bert4torch.pipelines import Text2Vec

text2vec = Text2Vec(checkpoint_path=root_model_path, device='mps')
embeddings_1 = text2vec.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = text2vec.encode(sentences_2, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)

Command-line output

model/BAAI_bge-large-zh-v1.5
=========================================sentence transformer====================================
[[0.85533315 0.8520633 ]
 [0.87456286 0.8557935 ]]
=========================================bert4torch====================================
Loading checkpoint shards: 100%|██████████| 1/1 [00:01<00:00,  1.16s/it]
[[0.9882161  0.98629344]
 [0.9826301  0.9869268 ]]

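Why does loading the model with Text2Vec, encoding, and computing the similarity give results that differ from SentenceTransformer? I tried other sentences as well, and every similarity computed through Text2Vec is above 0.9. Any advice would be appreciated.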
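
For context on what the two printed matrices represent: with normalize_embeddings=True in sentence_transformers each embedding has unit length, so embeddings_1 @ embeddings_2.T is a matrix of cosine similarities. Below is a minimal plain-Python sketch of that computation. The 3-dimensional vectors are made up for illustration and are not real bge embeddings, and l2_normalize and similarity_matrix are illustrative helper names, not part of either library.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def similarity_matrix(rows_a, rows_b):
    """Plain-Python equivalent of `A @ B.T` for two lists of vectors."""
    return [[sum(x * y for x, y in zip(a, b)) for b in rows_b] for a in rows_a]


# Toy vectors, invented for this example (not model output).
emb_1 = [l2_normalize(v) for v in ([1.0, 0.0, 0.0], [1.0, 1.0, 0.0])]
emb_2 = [l2_normalize(v) for v in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])]

sim = similarity_matrix(emb_1, emb_2)
for row in sim:
    print([round(s, 4) for s in row])
# [1.0, 0.0]
# [0.7071, 0.7071]
```

Two encoders that load the same weights and pool the same way should produce nearly identical matrices here, so a gap as large as the one reported points to a difference in how the model is loaded or configured rather than to the similarity step itself.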

Owner

Tongjilibo commented Apr 3, 2024

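I just tried it with a file path specified, and both sides give the same result. The bert4torch_config.json file is at https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/bge-large-zh-v1.5. There is a problem, though: specifying model_name has a bug, which I need to fix.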
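
Since the local-path run relies on bert4torch_config.json, a quick check that the file is present in the model directory can rule out one possible cause. This is only a sketch: that Text2Vec reads the file from the checkpoint directory is an assumption here, and has_bert4torch_config is a hypothetical helper, not part of bert4torch.

```python
from pathlib import Path

# File name taken from the comment above.
CONFIG_NAME = "bert4torch_config.json"


def has_bert4torch_config(model_dir):
    """Return True if bert4torch_config.json exists directly inside model_dir."""
    return (Path(model_dir) / CONFIG_NAME).is_file()


if __name__ == "__main__":
    # Local directory as printed in the original report; adjust to your setup.
    model_dir = "model/BAAI_bge-large-zh-v1.5"
    if has_bert4torch_config(model_dir):
        print(f"{CONFIG_NAME} found in {model_dir}")
    else:
        print(f"{CONFIG_NAME} missing from {model_dir}")
```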

Tongjilibo added a commit that referenced this issue Apr 3, 2024
Owner

pip install git+https://github.com/Tongjilibo/torch4keras.git
pip install git+https://github.com/Tongjilibo/bert4torch.git@dev

I changed the logic; you can install the latest version with the pip commands above. The calling code is below.

import os
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
# root_model_path = 'E:\pretrain_ckpt\embedding\[email protected]'
root_model_path = 'BAAI/bge-large-zh-v1.5'
# root_model_path = '/data/pretrain_ckpt/embedding/BAAI--bge-large-zh-v1.5'

sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]

print('=========================================sentence transformer====================================')
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(root_model_path)
embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)


print('=========================================bert4torch====================================')
from bert4torch.pipelines import Text2Vec
text2vec = Text2Vec(checkpoint_path=root_model_path, device='cuda')
embeddings_1 = text2vec.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = text2vec.encode(sentences_2, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
# Output
=========================================sentence transformer====================================
[[0.85533345 0.8520633 ]
 [0.8745628  0.8557939 ]]
=========================================bert4torch====================================
Loading checkpoint shards: 100%|█████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.57it/s]
[[0.85533345 0.8520633 ]
 [0.8745628  0.8557937 ]]
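
To make "consistent" concrete, the matrices printed in this thread can be compared element by element. This is a small sketch using the values copied from the outputs above. The tolerance of 1e-5 is an assumed threshold for float32 results, not something either library specifies, and max_abs_diff is an illustrative helper.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-shaped matrices."""
    return max(
        abs(x - y)
        for row_a, row_b in zip(a, b)
        for x, y in zip(row_a, row_b)
    )


# Values copied from the original report (bert4torch 0.4.9.post2).
st_before = [[0.85533315, 0.8520633], [0.87456286, 0.8557935]]
b4t_before = [[0.9882161, 0.98629344], [0.9826301, 0.9869268]]

# Values copied from the run after installing the dev branch.
st_after = [[0.85533345, 0.8520633], [0.8745628, 0.8557939]]
b4t_after = [[0.85533345, 0.8520633], [0.8745628, 0.8557937]]

TOLERANCE = 1e-5  # assumed threshold for float32 noise

print(max_abs_diff(st_before, b4t_before) < TOLERANCE)  # False
print(max_abs_diff(st_after, b4t_after) < TOLERANCE)    # True
```

The same check works on the arrays returned by encode if they are first converted with .tolist().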
