Commit f5e1498

index instability

Kav-K committed Nov 9, 2023
1 parent 3d52a9c

Showing 5 changed files with 47 additions and 29 deletions.
2 changes: 1 addition & 1 deletion gpt3discord.py
@@ -34,7 +34,7 @@
 from models.openai_model import Model
 
 
-__version__ = "12.1.6"
+__version__ = "12.1.7"
 
 
 PID_FILE = Path("bot.pid")
62 changes: 40 additions & 22 deletions models/index_model.py
@@ -9,6 +9,7 @@
 import aiohttp
 import discord
 import aiofiles
+import httpx
 import openai
 import tiktoken
 from functools import partial
@@ -84,7 +85,6 @@
 from models.openai_model import Models
 from models.check_model import UrlCheck
 from services.environment_service import EnvService
-
 SHORT_TO_LONG_CACHE = {}
 MAX_DEEP_COMPOSE_PRICE = EnvService.get_max_deep_compose_price()
 EpubReader = download_loader("EpubReader")
@@ -101,11 +101,21 @@
     text_splitter=TokenTextSplitter(chunk_size=1024, chunk_overlap=20)
 )
 callback_manager = CallbackManager([token_counter])
-service_context = ServiceContext.from_defaults(
+service_context_no_llm = ServiceContext.from_defaults(
     embed_model=embedding_model,
     callback_manager=callback_manager,
     node_parser=node_parser,
 )
+timeout = httpx.Timeout(1, read=1, write=1, connect=1)
+
+def get_service_context_with_llm(llm):
+    service_context = ServiceContext.from_defaults(
+        embed_model=embedding_model,
+        callback_manager=callback_manager,
+        node_parser=node_parser,
+        llm=llm,
+    )
+    return service_context
 
 
 def dummy_tool(**kwargs):
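
This hunk replaces the module's single `service_context` with an embedding-only `service_context_no_llm` plus a factory that binds a caller-supplied LLM. A minimal usage sketch follows; it is not part of the commit, and `documents`, the model name, and the import paths are assumptions based on the versions this repo pins:

# --- usage sketch (not part of the diff) ---
# Build the index with the shared LLM-free context, then attach a
# conversation-specific LLM only when constructing a query engine.
# `service_context_no_llm` and `get_service_context_with_llm` refer to the
# definitions above; `documents` is assumed to be already loaded.
from langchain.chat_models import ChatOpenAI
from llama_index import GPTVectorStoreIndex

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, max_retries=2)  # model name is illustrative

index = GPTVectorStoreIndex.from_documents(
    documents=documents,
    service_context=service_context_no_llm,  # embedding calls only, no LLM
)
engine = index.as_query_engine(
    response_mode="tree_summarize",
    service_context=get_service_context_with_llm(llm),  # LLM bound per request
)
# --- end sketch ---
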
@@ -400,21 +410,25 @@ async def index_chat_file(self, message: discord.Message, file: discord.Attachme
         # Assert that the filename is < 100 characters, if it is greater, truncate to the first 100 characters and keep the original ending
         if len(filename) > 100:
             filename = filename[:100] + filename[-4:]
+        openai.log = "debug"
 
+        print("Indexing")
         index: VectorStoreIndex = await self.loop.run_in_executor(
             None,
             partial(
                 self.index_file,
                 Path(temp_file.name),
-                service_context,
+                get_service_context_with_llm(self.index_chat_chains[message.channel.id].llm),
                 suffix,
             ),
         )
+        print("Done Indexing")
 
         summary = await index.as_query_engine(
             similarity_top_k=10,
-            child_branch_factor=6,
+            child_branch_factor=3,
             response_mode="tree_summarize",
+            service_context=get_service_context_with_llm(self.index_chat_chains[message.channel.id].llm)
         ).aquery(
             f"What is a summary or general idea of this data? Be detailed in your summary (e.g "
             f"extract key names, etc) but not too verbose. Your summary should be under a hundred words. "
@@ -427,7 +441,7 @@ async def index_chat_file(self, message: discord.Message, file: discord.Attachme
             f"is no available data if there are no available tools that are relevant."
         )
 
-        engine = self.get_query_engine(index, message, summary)
+        engine = self.get_query_engine(index, self.index_chat_chains[message.channel.id].llm)
 
         # Get rid of all special characters in the filename
         filename = "".join(
@@ -485,8 +499,7 @@ async def start_index_chat(self, ctx, model):
         preparation_message = await ctx.channel.send(
             embed=EmbedStatics.get_index_chat_preparation_message()
         )
-
-        llm = ChatOpenAI(model=model, temperature=0)
+        llm = ChatOpenAI(model=model, temperature=0, max_retries=2)
         llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=model))
 
         max_token_limit = 29000 if "gpt-4" in model else 7500
@@ -508,7 +521,7 @@ async def start_index_chat(self, ctx, model):
                 "to the data at the link by the time you respond. When using tools, the input should be "
                 "clearly created based on the request of the user. For example, if a user uploads an invoice "
                 "and asks how many usage hours of X was present in the invoice, a good query is 'X hours'. "
-                "Avoid using single word queries unless the request is very simple. You can query multiple times to break down complex requests and retrieve more information."
+                "Avoid using single word queries unless the request is very simple. You can query multiple times to break down complex requests and retrieve more information. When calling functions, no special characters are allowed in the function name, keep that in mind."
             ),
         }
 
@@ -786,7 +799,7 @@ async def set_file_index(
             partial(
                 self.index_file,
                 Path(temp_file.name),
-                service_context,
+                service_context_no_llm,
                 suffix,
             ),
         )
@@ -858,7 +871,7 @@ async def set_link_index_recurse(
             functools.partial(
                 GPTVectorStoreIndex,
                 documents=documents,
-                service_context=service_context,
+                service_context=service_context_no_llm,
                 use_async=True,
             ),
         )
@@ -902,16 +915,16 @@ async def set_link_index_recurse(
 
         await response.edit(embed=EmbedStatics.get_index_set_success_embed(price))
 
-    def get_query_engine(self, index, message, summary):
+    def get_query_engine(self, index, llm):
         retriever = VectorIndexRetriever(
-            index=index, similarity_top_k=6, service_context=service_context
+            index=index, similarity_top_k=6, service_context=get_service_context_with_llm(llm)
        )
 
         response_synthesizer = get_response_synthesizer(
             response_mode=ResponseMode.COMPACT_ACCUMULATE,
             use_async=True,
             refine_template=TEXT_QA_SYSTEM_PROMPT,
-            service_context=service_context,
+            service_context=get_service_context_with_llm(llm),
             verbose=True,
         )
 
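
`get_query_engine` now receives the LLM directly instead of the Discord message and summary. A call-site sketch, mirroring the call sites this commit adds elsewhere; `self`, `message`, and `index` come from the surrounding cog method and are not defined here:

# --- call-site sketch (mirrors the new signature; not a new call site) ---
llm = self.index_chat_chains[message.channel.id].llm  # per-channel chat LLM
engine = self.get_query_engine(index, llm)
answer = await engine.aquery("X hours")  # illustrative query, per the system prompt's advice
# --- end sketch ---
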
@@ -924,15 +937,17 @@ def get_query_engine(self, index, message, summary):
     async def index_link(self, link, summarize=False, index_chat_ctx=None):
         try:
             if await UrlCheck.check_youtube_link(link):
+                print("Indexing youtube transcript")
                 index = await self.loop.run_in_executor(
-                    None, partial(self.index_youtube_transcript, link, service_context)
+                    None, partial(self.index_youtube_transcript, link, service_context_no_llm)
                 )
+                print("Indexed youtube transcript")
             elif "github" in link:
                 index = await self.loop.run_in_executor(
-                    None, partial(self.index_github_repository, link, service_context)
+                    None, partial(self.index_github_repository, link, service_context_no_llm)
                 )
             else:
-                index = await self.index_webpage(link, service_context)
+                index = await self.index_webpage(link, service_context_no_llm)
         except Exception as e:
             if index_chat_ctx:
                 await index_chat_ctx.reply(
@@ -945,13 +960,16 @@ async def index_link(self, link, summarize=False, index_chat_ctx=None):
         summary = None
         if index_chat_ctx:
             try:
+                print("Getting transcript summary")
                 summary = await index.as_query_engine(
-                    response_mode="tree_summarize"
+                    response_mode="tree_summarize",
+                    service_context=get_service_context_with_llm(self.index_chat_chains[index_chat_ctx.channel.id].llm)
                 ).aquery(
                     "What is a summary or general idea of this document? Be detailed in your summary but not too verbose. Your summary should be under 50 words. This summary will be used in a vector index to retrieve information about certain data. So, at a high level, the summary should describe the document in such a way that a retriever would know to select it when asked questions about it. The link was {link}. Include the an easy identifier derived from the link at the end of the summary."
                 )
+                print("Got transcript summary")
 
-                engine = self.get_query_engine(index, index_chat_ctx, summary)
+                engine = self.get_query_engine(index, self.index_chat_chains[index_chat_ctx.channel.id].llm)
 
                 # Get rid of all special characters in the link, replace periods with _
                 link_cleaned = "".join(
@@ -1078,7 +1096,7 @@ async def set_discord_index(
             channel_ids=[channel.id], limit=message_limit, oldest_first=False
         )
         index = await self.loop.run_in_executor(
-            None, partial(self.index_discord, document, service_context)
+            None, partial(self.index_discord, document, service_context_no_llm)
         )
         try:
             price = await self.usage_service.get_price(
@@ -1255,7 +1273,7 @@ async def compose_indexes(self, user_id, indexes, name, deep_compose):
             partial(
                 GPTVectorStoreIndex.from_documents,
                 documents=documents,
-                service_context=service_context,
+                service_context=service_context_no_llm,
                 use_async=True,
             ),
         )
@@ -1300,7 +1318,7 @@ async def backup_discord(
         )
 
         index = await self.loop.run_in_executor(
-            None, partial(self.index_discord, document, service_context)
+            None, partial(self.index_discord, document, service_context_no_llm)
         )
         await self.usage_service.update_usage(
             token_counter.total_embedding_token_count, "embedding"
@@ -1362,7 +1380,7 @@ async def query(
                 response_mode,
                 nodes,
                 child_branch_factor,
-                service_context=service_context,
+                service_context=service_context_no_llm,
                 multistep=llm_predictor if multistep else None,
             ),
         )
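
The module-level `timeout = httpx.Timeout(1, read=1, write=1, connect=1)` added at the top of this file is not consumed in any of the hunks shown. For reference, a sketch of how an `httpx.Timeout` is typically wired into a client; this wiring is an assumption, not something the commit shows:

# --- sketch: typical httpx.Timeout wiring (assumed, not in this commit) ---
import httpx

timeout = httpx.Timeout(1, read=1, write=1, connect=1)  # ~1s cap on each phase

with httpx.Client(timeout=timeout) as client:
    resp = client.get("https://example.com")  # raises httpx.TimeoutException instead of hanging
# --- end sketch ---
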
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
 ]
 dependencies = [
     "Pillow==9.3.0",
-    "openai==1.1.0",
+    "openai==1.2.0",
     "yt-dlp==2023.3.4",
     "ffmpeg==1.4",
     "beautifulsoup4==4.11.2",
@@ -40,7 +40,7 @@ dependencies = [
     "sentencepiece==0.1.99",
     "protobuf==3.20.2",
     "python-pptx==0.6.21",
-    "langchain==0.0.332",
+    "langchain==0.0.333",
     "unidecode==1.3.6",
     "tqdm==4.64.1",
     "docx2txt==0.8",
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
 Pillow==9.3.0
-openai==1.1.0
+openai==1.2.0
 yt-dlp==2023.3.4
 ffmpeg==1.4
 py-cord==2.4.1
@@ -21,7 +21,7 @@ sentencepiece==0.1.99
 protobuf==3.20.2
 python-pptx==0.6.21
 sentence-transformers==2.2.2
-langchain==0.0.332
+langchain==0.0.333
 openai-whisper
 unidecode==1.3.6
 tqdm==4.64.1
4 changes: 2 additions & 2 deletions requirements_base.txt
@@ -1,5 +1,5 @@
 Pillow==9.3.0
-openai==1.1.0
+openai==1.2.0
 yt-dlp==2023.3.4
 ffmpeg==1.4
 py-cord==2.4.1
@@ -20,7 +20,7 @@ youtube_transcript_api==0.5.0
 sentencepiece==0.1.99
 protobuf==3.20.2
 python-pptx==0.6.21
-langchain==0.0.332
+langchain==0.0.333
 unidecode==1.3.6
 tqdm==4.64.1
 docx2txt==0.8
