feat: add the rag example @LongxingTan (#15)
LongxingTan authored Mar 16, 2024
1 parent a67c731 commit 52360d7
Showing 7 changed files with 130 additions and 12 deletions.
49 changes: 43 additions & 6 deletions examples/rag_langchain.py
@@ -1,14 +1,26 @@
import torch
import transformers
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.retrievers import ContextualCompressionRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

-from retrievals import AutoModelForEmbedding, RerankModel
-from retrievals.tools import LangchainReranker, RagFeature
from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker, RagFeature

-embed_model = AutoModelForEmbedding(model_name_or_path='')
-rerank_model = LangchainReranker(model_name_or_path='', top_n=5, device='cuda')

class CFG:
    retrieval_model = 'BAAI/bge-large-zh'
    rerank_model = ''
    llm_model = 'Qwen/Qwen-7B-Chat'


embed_model = LangchainEmbedding(model_name_or_path=CFG.retrieval_model)
rerank_model = LangchainReranker(model_name_or_path=CFG.rerank_model, top_n=5, device='cuda')


documents = PyPDFLoader("llama.pdf").load()
@@ -19,5 +31,30 @@
    search_type="similarity", search_kwargs={"score_threshold": 0.3, "k": 10}
)

-compression_retriever = ContextualCompressionRetriever(base_compressor=rerank_model, base_retriever=retriever)
-response = compression_retriever.get_relevant_documents("What is Llama 2?")
# compression_retriever = ContextualCompressionRetriever(base_compressor=rerank_model, base_retriever=retriever)
# response = compression_retriever.get_relevant_documents("What is Llama 2?")


tokenizer = AutoTokenizer.from_pretrained(CFG.llm_model, trust_remote_code=True)
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
model = AutoModelForCausalLM.from_pretrained(
    CFG.llm_model, device_map='auto', load_in_4bit=True, max_memory=max_memory, trust_remote_code=True, fp16=True
)
model = model.eval()
model.generation_config = GenerationConfig.from_pretrained(CFG.llm_model, trust_remote_code=True)

query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)

llm = HuggingFacePipeline(pipeline=query_pipeline)

qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, verbose=True)

qa.run('What are your thoughts after reading this article?')
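
The reranking stage is left commented out in the example. A minimal sketch of wiring it back in, assuming CFG.rerank_model is filled with a valid reranker checkpoint (the calls mirror the commented-out lines above):

# Sketch only: re-enables the reranking step the example comments out.
# Assumes CFG.rerank_model points at a valid reranker checkpoint.
from langchain.retrievers import ContextualCompressionRetriever

compression_retriever = ContextualCompressionRetriever(
    base_compressor=rerank_model,  # the LangchainReranker built above
    base_retriever=retriever,
)
reranked_docs = compression_retriever.get_relevant_documents("What is Llama 2?")

# The compressed retriever can then back the QA chain in place of the plain one:
qa = RetrievalQA.from_chain_type(
    llm=llm, chain_type="stuff", retriever=compression_retriever, verbose=True
)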
32 changes: 28 additions & 4 deletions src/retrievals/models/embedding_auto.py
@@ -1,5 +1,5 @@
import logging
-from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union

import faiss
import numpy as np
@@ -60,6 +60,9 @@ class AutoModelForEmbedding(nn.Module):
        from the Hugging Face Hub with that name.
    """

    encode_kwargs: Dict[str, Any] = dict()
    show_progress: bool = False

    def __init__(
        self,
        model_name_or_path: str,
@@ -184,7 +187,17 @@ def forward_from_loader(self, inputs):
        return embeddings

    def forward_from_text(self, texts):
-        return self.forward_from_loader(texts)
        batch_dict = self.tokenizer(
            texts,
            max_length=self.max_length,
            return_attention_mask=False,
            padding=False,
            truncation=True,
        )
        # append EOS so every sequence ends with a terminator before padding
        batch_dict["input_ids"] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"]]
        batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors="pt")
        batch_dict.pop("token_type_ids", None)  # not every tokenizer emits token_type_ids
        return self.forward_from_loader(batch_dict)

    def encode(
        self,
@@ -197,7 +210,7 @@ def encode(
        device: str = None,
        normalize_embeddings: bool = False,
    ):
-        if isinstance(inputs, DataLoader):
        if isinstance(inputs, (BatchEncoding, Dict)):
            return self.encode_from_loader(
                loader=inputs,
                batch_size=batch_size,
@@ -208,7 +221,7 @@ def encode(
                device=device,
                normalize_embeddings=normalize_embeddings,
            )
-        elif isinstance(inputs, (str, Iterable)):
        elif isinstance(inputs, (str, List, Tuple)):
            return self.encode_from_text(
                sentences=inputs,
                batch_size=batch_size,
@@ -219,6 +232,17 @@ def encode(
                device=device,
                normalize_embeddings=normalize_embeddings,
            )
        else:
            raise ValueError(f'Unsupported input type for encode: {type(inputs)}')

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model."""
        embeddings = self.encode(texts, show_progress_bar=self.show_progress, **self.encode_kwargs)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model."""
        return self.embed_documents([text])[0]

    def encode_from_loader(
        self,
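
The new embed_documents / embed_query pair makes AutoModelForEmbedding satisfy the interface LangChain expects from an embedding backend. A minimal usage sketch, assuming a sentence-embedding checkpoint such as BAAI/bge-large-zh and the constructor's pooling defaults:

# Sketch only: exercises the new embed_documents / embed_query methods.
from retrievals import AutoModelForEmbedding

model = AutoModelForEmbedding(model_name_or_path='BAAI/bge-large-zh')
doc_vectors = model.embed_documents(["Llama 2 is an open LLM.", "FAISS is a vector index."])
query_vector = model.embed_query("What is Llama 2?")
print(len(doc_vectors), len(query_vector))  # 2 document vectors; one query vector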
46 changes: 46 additions & 0 deletions src/retrievals/models/rag.py
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union

from transformers import AutoModel


class RAG(object):
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, Path],
        n_gpu: int = -1,
        verbose: int = 1,
        index_root: Optional[str] = None,
    ):
        instance = cls()
        instance.model = AutoModel.from_pretrained(model_name_or_path)
        return instance

    @classmethod
    def from_index(cls, index_path: Union[str, Path], n_gpu: int = -1, verbose: int = 1):
        instance = cls()
        index_path = Path(index_path)
        instance.model = None  # placeholder: restore the model from metadata stored with the index

        return instance

    def add_to_index(self):
        return

    def encode(self):
        return

    def index(self):
        return

    def search(self):
        return


class Generator(object):
    def __init__(self):
        pass
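
RAG is committed as a skeleton: every method body is a stub. The call pattern implied by the signatures, purely illustrative until the stubs are filled in (model name and paths are assumptions):

# Illustrative only: the stub methods above do not return results yet.
rag = RAG.from_pretrained('BAAI/bge-large-zh', n_gpu=1, index_root='./indexes')
rag.index()         # would build a vector index over a corpus
rag.add_to_index()  # would extend an existing index
rag.search()        # would retrieve passages for a query

# Or restore a previously built index:
rag = RAG.from_index('./indexes/my_index', n_gpu=1)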
2 changes: 0 additions & 2 deletions src/retrievals/tools/__init__.py
@@ -1,2 +0,0 @@
-from src.retrievals.tools.langchain import LangchainReranker, RagFeature
-from src.retrievals.tools.llama_index import LlamaIndexReranker
3 changes: 3 additions & 0 deletions src/retrievals/tools/corpus_processor.py
@@ -0,0 +1,3 @@
class CorpusProcessor(object):
    def __init__(self):
        pass
7 changes: 7 additions & 0 deletions src/retrievals/tools/langchain.py
@@ -9,11 +9,18 @@
    MarkdownTextSplitter,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Extra, root_validator

from src.retrievals.models.embedding_auto import AutoModelForEmbedding
from src.retrievals.models.rerank import RerankModel


class LangchainEmbedding(AutoModelForEmbedding, Embeddings):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class LangchainReranker(BaseDocumentCompressor):
    class Config:
        """Configuration for this pydantic object."""
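
Because LangchainEmbedding inherits both AutoModelForEmbedding and LangChain's Embeddings base class, it can be handed directly to a vector store. A sketch mirroring the example file; the chunk_size/chunk_overlap values are assumptions, not from the commit:

# Sketch only: wires LangchainEmbedding into a FAISS store.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy

from retrievals.tools.langchain import LangchainEmbedding

embed_model = LangchainEmbedding(model_name_or_path='BAAI/bge-large-zh')
docs = PyPDFLoader("llama.pdf").load()
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
vector_store = FAISS.from_documents(chunks, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
retriever = vector_store.as_retriever(
    search_type="similarity", search_kwargs={"score_threshold": 0.3, "k": 10}
)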
3 changes: 3 additions & 0 deletions tests/test_models/test_embedding_auto.py
@@ -149,6 +149,9 @@ def test_encode_from_text(self):
        assert emb.shape == (3, 384)
        # assert abs(np.sum(emb) - 32.14627) < 0.001

    def test_forward_from_text(self):
        pass


class PairwiseModelTest(TestCase, ModelTesterMixin):
    def setUp(self) -> None:
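
test_forward_from_text is committed as a placeholder. One possible body, assuming the suite's setUp exposes the fixture embedding model as self.model the way test_encode_from_text uses it (names are assumptions, not from the commit):

# Hypothetical body for the placeholder test; self.model is assumed from setUp.
def test_forward_from_text(self):
    texts = ['hello world', 'how are you', 'greetings']
    emb = self.model.forward_from_text(texts)
    assert emb.shape[0] == 3  # one embedding row per input text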
