Commit

storages
skyline2006 committed Jun 18, 2024
1 parent 224e575 commit b254770
Showing 3 changed files with 94 additions and 26 deletions.
modelscope_agent/memory/memory_with_rag.py (0 additions, 10 deletions)
@@ -42,21 +42,12 @@ def _run(self,
              query: str = None,
              url: str = None,
              max_token: int = 4000,
-<<<<<<< HEAD
-             top_k: int = 3,
-=======
->>>>>>> origin
              **kwargs) -> Union[str, Iterator[str]]:
         if isinstance(url, str):
             url = [url]
         if url and len(url):
             self.store_knowledge.add(files=url)
         if query:
-<<<<<<< HEAD
-            summary_result = self.store_knowledge.run(query, files=url)
-            # limit length
-            return summary_result[0:max_token - 1]
-=======
             summary_result = self.store_knowledge.run(
                 query, files=url, **kwargs)
             # limit length
@@ -69,4 +60,3 @@ def _run(self,
             return concatenated_records
         else:
             return summary_result[0:max_token - 1]
->>>>>>> origin
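
Note: with the conflict markers gone, only the origin-side branch of _run remains, so extra keyword arguments (e.g. use_llm=False) are forwarded through **kwargs to store_knowledge.run. A minimal sketch of that call path; the import path and constructor arguments are assumptions based on tests/test_rag.py below, not confirmed by this diff:

    from modelscope_agent.memory import MemoryWithRag  # assumed import path

    # Hypothetical setup; the real constructor arguments may differ.
    memory = MemoryWithRag(urls=['./data/常见QA.pdf'])

    # use_llm=False travels through **kwargs into store_knowledge.run, so the
    # call returns concatenated retrieval records instead of an LLM summary.
    print(memory.run('高德天气API申请', max_token=4000, use_llm=False))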
modelscope_agent/rag/knowledge.py (94 additions, 13 deletions)
@@ -1,10 +1,13 @@
+import inspect
 import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Type, Union
 
 import fsspec
 from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
 from llama_index.core.base.base_retriever import BaseRetriever
+from llama_index.core.graph_stores.types import GraphStore
+from llama_index.core.indices.base import BaseIndex
 from llama_index.core.llama_pack.base import BaseLlamaPack
 from llama_index.core.llms.llm import LLM
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
@@ -13,7 +16,10 @@
 from llama_index.core.schema import (Document, MetadataMode, QueryBundle,
                                      TransformComponent)
 from llama_index.core.settings import Settings
-from llama_index.core.vector_stores.types import (MetadataFilter,
+from llama_index.core.storage.docstore.types import BaseDocumentStore
+from llama_index.core.storage.index_store.types import BaseIndexStore
+from llama_index.core.vector_stores.types import (BasePydanticVectorStore,
+                                                  MetadataFilter,
                                                   MetadataFilters)
 from llama_index.legacy.core.embeddings.base import BaseEmbedding
 from modelscope_agent.llm import get_chat_model
@@ -63,13 +69,30 @@ def __init__(self,
                  transformations: List[Type[TransformComponent]] = [],
                  post_processors: List[Type[BaseNodePostprocessor]] = [],
                  use_cache: bool = True,
+                 docstore: Union[BaseDocumentStore, Type[BaseDocumentStore],
+                                 None] = None,
+                 index_store: Union[BaseIndexStore, Type[BaseIndexStore],
+                                    None] = None,
+                 vector_store: Union[BasePydanticVectorStore,
+                                     Type[BasePydanticVectorStore],
+                                     None] = None,
+                 image_store: Union[BasePydanticVectorStore,
+                                    Type[BasePydanticVectorStore],
+                                    None] = None,
+                 graph_store: Union[GraphStore, Type[GraphStore], None] = None,
                  **kwargs) -> None:
         self.retriever_cls = retriever
         self.cache_dir = cache_dir
         # self.register_files(files) # TODO: file manager
         self.extra_readers = self.get_extra_readers(loaders)
         self.embed_model = self.get_emb_model(emb)
         Settings._embed_model = self.embed_model
+        docstore = self.get_storage(docstore)
+        index_store = self.get_storage(index_store)
+        vector_store = self.get_storage(vector_store)
+        image_store = self.get_storage(image_store)
+        graph_store = self.get_storage(graph_store)
+
         documents = None
         if not use_cache:
             documents = self.read(files)
@@ -84,12 +107,31 @@ def __init__(self,
             **kwargs)
 
         root_retriever = self.get_root_retriever(
-            documents, use_cache=use_cache, **kwargs)
+            documents,
+            use_cache=use_cache,
+            docstore=docstore,
+            index_store=index_store,
+            vector_store=vector_store,
+            image_store=image_store,
+            graph_store=graph_store,
+            **kwargs)
 
         self.query_engine = None
         if root_retriever:
             self.query_engine = self.get_query_engine(root_retriever, **kwargs)
 
+    def get_storage(
+        self, storage_or_cls: Union[BaseDocumentStore, Type[BaseDocumentStore],
+                                    BaseIndexStore, Type[BaseIndexStore],
+                                    BasePydanticVectorStore,
+                                    Type[BasePydanticVectorStore], GraphStore,
+                                    Type[GraphStore], None]
+    ) -> Union[BaseDocumentStore, BaseIndexStore, BasePydanticVectorStore,
+               GraphStore, None]:
+        if inspect.isclass(storage_or_cls):
+            return storage_or_cls()
+        return storage_or_cls
+
     def get_llm(self, llm: Union[LLM, BaseChatModel, Dict]) -> LLM:
         llama_index_llm = None
         if llm and isinstance(llm, BaseChatModel):
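
Note: the new get_storage helper accepts either a ready store instance or a bare store class; inspect.isclass decides, and a class is instantiated with its no-argument constructor. A hedged sketch of both call styles, assuming llama-index's SimpleDocumentStore and that BaseKnowledge's remaining arguments can be left at their defaults:

    from llama_index.core.storage.docstore import SimpleDocumentStore

    # Pass an instance: get_storage returns it unchanged.
    kb = BaseKnowledge('./data2', docstore=SimpleDocumentStore())

    # Or pass the class itself: get_storage calls SimpleDocumentStore() for you.
    kb = BaseKnowledge('./data2', docstore=SimpleDocumentStore)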
@@ -168,13 +210,16 @@ def get_postprocessors(
 
         return res
 
-    def get_root_retriever(self,
-                           documents: List[Document],
-                           use_cache: bool = True,
-                           **kwargs) -> BaseRetriever:
-
+    def get_index(self,
+                  documents: List[Document],
+                  use_cache: bool = True,
+                  docstore: Optional[BaseDocumentStore] = None,
+                  index_store: Optional[BaseIndexStore] = None,
+                  vector_store: Optional[BasePydanticVectorStore] = None,
+                  image_store: Optional[BasePydanticVectorStore] = None,
+                  graph_store: Optional[GraphStore] = None,
+                  **kwargs) -> BaseIndex:
         # indexing
         # chunk_size and other settings can be configured here
         Settings.chunk_size = 512
         index = None
         if use_cache:
@@ -184,6 +229,11 @@
             from llama_index.core import StorageContext, load_index_from_storage
             # rebuild storage context
             storage_context = StorageContext.from_defaults(
+                docstore=docstore,
+                index_store=index_store,
+                vector_store=vector_store,
+                image_store=image_store,
+                graph_store=graph_store,
                 persist_dir=self.cache_dir)
             # load index
 
@@ -209,6 +259,27 @@
 
         if self.cache_dir is not None:
             index.storage_context.persist(persist_dir=self.cache_dir)
+        return index
+
+    def get_root_retriever(
+            self,
+            documents: List[Document],
+            use_cache: bool = True,
+            docstore: Optional[BaseDocumentStore] = None,
+            index_store: Optional[BaseIndexStore] = None,
+            vector_store: Optional[BasePydanticVectorStore] = None,
+            image_store: Optional[BasePydanticVectorStore] = None,
+            graph_store: Optional[GraphStore] = None,
+            **kwargs) -> BaseRetriever:
+        index = self.get_index(
+            documents=documents,
+            use_cache=use_cache,
+            docstore=docstore,
+            index_store=index_store,
+            vector_store=vector_store,
+            image_store=image_store,
+            graph_store=graph_store,
+            **kwargs)
 
         # init retriever tool
         if self.retriever_cls:
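
Note: get_root_retriever is now a thin wrapper over the extracted get_index, so the injected stores take effect on both paths: a fresh build persists into them, while a cached load hands them to StorageContext.from_defaults before load_index_from_storage. A hedged round-trip sketch, reusing llm and the './data2' argument from the demo below:

    # Build from source files and persist (use_cache=False reads the files)...
    kb = BaseKnowledge('./data2', use_cache=False, llm=llm)
    # ...then reopen the same index from storage instead of re-reading files.
    kb = BaseKnowledge('./data2', use_cache=True, llm=llm)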
@@ -344,8 +415,18 @@ def add(self, files: List[str]):
     llm_config = {'model': 'qwen-max', 'model_server': 'dashscope'}
     llm = get_chat_model(**llm_config)
 
-    knowledge = BaseKnowledge('./data2', use_cache=False, llm=llm)
-
-    knowledge.add(['./data/常见QA.pdf'])
-    print(knowledge.run('高德天气API申请', files=['常见QA.pdf'], use_llm=False))
-
+    from llama_index.storage.docstore.mongodb import MongoDocumentStore
+    from llama_index.storage.index_store.mongodb import MongoIndexStore
+    MONGO_URI = 'mongodb://localhost'
+    knowledge = BaseKnowledge(
+        './data2',
+        use_cache=True,
+        llm=llm,
+        docstore=MongoDocumentStore.from_uri(MONGO_URI),
+        index_store=MongoIndexStore.from_uri(MONGO_URI))
+
+    # knowledge.add(['./data/常见QA.pdf'])
+    res = knowledge.run(
+        'Who decided to compile a book of interviews with startup founders?')
+    #res = knowledge.run('高德天气API申请', files=['常见QA.pdf'], use_llm=False)
+    print(res)
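
Note: the rewritten demo swaps the default local stores for MongoDB-backed ones. Running it presumably requires the optional llama-index-storage-docstore-mongodb and llama-index-storage-index-store-mongodb packages plus a reachable MongoDB instance. from_uri also accepts an optional db_name, useful for keeping several knowledge bases apart; 'msa_kb' below is a hypothetical name:

    from llama_index.storage.docstore.mongodb import MongoDocumentStore
    from llama_index.storage.index_store.mongodb import MongoIndexStore

    MONGO_URI = 'mongodb://localhost'
    # db_name is optional; 'msa_kb' is only an illustrative database name.
    docstore = MongoDocumentStore.from_uri(MONGO_URI, db_name='msa_kb')
    index_store = MongoIndexStore.from_uri(MONGO_URI, db_name='msa_kb')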
tests/test_rag.py (0 additions, 3 deletions)
@@ -78,8 +78,6 @@ def test_memory_with_rag_multi_modal():
     summary_str = memory.run('我想看rag的流程图')
     print(summary_str)
     assert 'rag.png' in summary_str
-<<<<<<< HEAD
-=======
 
 
 def test_memory_with_rag_no_use_llm():
@@ -92,4 +90,3 @@ def test_memory_with_rag_no_use_llm():
     print(summary_str)
     assert 'file_path' in summary_str
     assert 'git-lfs' in summary_str
->>>>>>> origin
