History summary #135

Merged
merged 2 commits on Jan 15, 2024
56 changes: 53 additions & 3 deletions api/src/stampy_chat/chat.py
@@ -64,13 +64,14 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
class PrefixedPrompt(BaseChatPromptTemplate):
"""A prompt that will prefix any messages with a system prompt, but only if messages provided."""

transformer: Callable[[Any], BaseMessage] = lambda i: i
messages_field: str
prompt: str # the system prompt to be used

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
history = kwargs[self.messages_field]
if history and self.prompt:
return [SystemMessage(content=self.prompt)] + history
return [SystemMessage(content=self.prompt)] + [self.transformer(i) for i in history]
return []
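As a rough usage sketch (my own illustration, not part of the PR): the new transformer hook lets raw history entries, which arrive from the frontend as role/content dicts, be converted into LangChain messages before the system prompt is prepended. The import path for PrefixedPrompt and the example history are assumptions.

from langchain.schema import ChatMessage
from stampy_chat.chat import PrefixedPrompt  # assumed import path

history_prompt = PrefixedPrompt(
    input_variables=["history"],
    messages_field="history",
    prompt="You are a helpful assistant...",  # placeholder system prompt
    transformer=lambda m: ChatMessage(**m),   # dict -> ChatMessage
)

messages = history_prompt.format_messages(
    history=[
        {"role": "user", "content": "What is AI alignment?"},
        {"role": "assistant", "content": "Getting AI systems to do what their operators intend."},
    ]
)
# messages[0] is the SystemMessage; the rest are the transformed history entries.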


@@ -132,6 +133,12 @@ def prune(self) -> None:
pruned_memory, self.moving_summary_buffer
)

def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""
Because of how wonderfully LangChain is written, this method was blowing up.
It's not needed, so it's getting the chop.
"""


class ModeratedChatPrompt(ChatPromptTemplate):
"""Wraps a prompt with an OpenAI moderation check which will raise an exception if fails."""
@@ -157,6 +164,49 @@ def get_model(**kwargs):
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, **kwargs)


class LLMInputsChain(LLMChain):

inputs: Dict[str, Any] = {}

def _call(self, inputs: Dict[str, Any], run_manager=None):
self.inputs = inputs
return super()._call(inputs, run_manager)

def _acall(self, inputs: Dict[str, Any], run_manager=None):
self.inputs = inputs
return super()._acall(inputs, run_manager)

def create_outputs(self, llm_result) -> List[Dict[str, Any]]:
result = super().create_outputs(llm_result)
return [dict(self.inputs, **r) for r in result]
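If I read the override correctly, create_outputs merges the chain's original inputs into each output dict, which is what lets 'query' survive into the next chain of the pipe. A minimal sketch with a fake LLM (the FakeListLLM import path and the exact result shape are assumptions on my part):

from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

demo = LLMInputsChain(
    llm=FakeListLLM(responses=["The user has been asking about AI alignment."]),
    prompt=PromptTemplate.from_template("Summarize: {history}"),
    output_key="history_summary",
)
result = demo({"history": "Q: what is alignment? A: ...", "query": "Why does it matter?"})
# result should keep the original 'history' and 'query' keys and add the new
# 'history_summary', rather than containing only the LLM output.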


def make_history_summary(settings):
model = get_model(
streaming=False,
max_tokens=settings.maxHistorySummaryTokens,
model=settings.completions
)
summary_prompt = PrefixedPrompt(
input_variables=['history'],
messages_field='history',
prompt=settings.history_summary_prompt,
transformer=lambda m: ChatMessage(**m)
)
return LLMInputsChain(
llm=model,
verbose=False,
output_key='history_summary',
prompt=ModeratedChatPrompt.from_messages([
summary_prompt,
ChatPromptTemplate.from_messages([
ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'),
]),
SystemMessage(content="Reply in one sentence only"),
]),
)
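For orientation, a hedged sketch of how this summary chain would be invoked: the input shape follows from PrefixedPrompt(messages_field='history') and the 'Q: {query}' template, but the exact call is my assumption and needs a Settings instance plus a valid OPENAI_API_KEY to actually run.

summary_chain = make_history_summary(settings)
out = summary_chain.invoke({
    "history": [
        {"role": "user", "content": "What is instrumental convergence?"},
        {"role": "assistant", "content": "Different goals tend to imply similar subgoals."},
    ],
    "query": "Does that apply to current models?",
})
# Thanks to LLMInputsChain, `out` keeps 'history' and 'query' and adds a
# one-sentence 'history_summary', which the downstream prompt interpolates
# as 'Q: {history_summary}: {query}'.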


def make_prompt(settings, chat_model, callbacks):
"""Create a proper prompt object will all the nessesery steps."""
# 1. Create the context prompt from items fetched from pinecone
@@ -176,7 +226,7 @@ def make_prompt(settings, chat_model, callbacks):
query_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=settings.question_prompt),
ChatMessagePromptTemplate.from_template(template='Q: {query}', role='user'),
ChatMessagePromptTemplate.from_template(template='Q: {history_summary}: {query}', role='user'),
]
)

@@ -247,7 +297,7 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
model=settings.completions
)

chain = LLMChain(
chain = make_history_summary(settings) | LLMChain(
llm=chat_model,
verbose=False,
prompt=make_prompt(settings, chat_model, callbacks),
2 changes: 1 addition & 1 deletion api/src/stampy_chat/followups.py
@@ -74,7 +74,7 @@ class Config:

@property
def input_keys(self) -> List[str]:
return ['query', 'text']
return ['query', 'text', 'history_summary']

@property
def output_keys(self) -> List[str]:
28 changes: 25 additions & 3 deletions api/src/stampy_chat/settings.py
@@ -9,7 +9,7 @@

SOURCE_PROMPT = (
"You are a helpful assistant knowledgeable about AI Alignment and Safety. "
"Please give a clear and coherent answer to the user's questions.(written after \"Q:\") "
"Please give a clear and coherent answer to the user's questions. (written after \"Q:\") "
"using the following sources. Each source is labeled with a letter. Feel free to "
"use the sources in any order, and try to use multiple sources in your answers.\n\n"
)
@@ -19,6 +19,13 @@
"These sources only apply to the last question. Any sources used in previous answers "
"are invalid."
)
HISTORY_SUMMARIZE_PROMPT = (
"You are a helpful assistant knowledgeable about AI Alignment and Safety. "
"Please summarize the following chat history (written after \"H:\") in one "
"sentence so as to put the current questions (written after \"Q:\") in context. "
"Please keep things as terse as possible."
"\nH:"
)

QUESTION_PROMPT = (
"In your answer, please cite any claims you make back to each source "
@@ -47,6 +54,7 @@
DEFAULT_PROMPTS = {
'context': SOURCE_PROMPT,
'history': HISTORY_PROMPT,
'history_summary': HISTORY_SUMMARIZE_PROMPT,
'question': QUESTION_PROMPT,
'modes': PROMPT_MODES,
}
@@ -72,8 +80,9 @@ def __init__(
topKBlocks=None,
maxNumTokens=None,
min_response_tokens=10,
tokensBuffer=50,
tokensBuffer=100,
maxHistory=10,
maxHistorySummaryTokens=200,
historyFraction=0.25,
contextFraction=0.5,
**_kwargs,
@@ -93,6 +102,9 @@
self.maxHistory = maxHistory
"""the max number of previous interactions to use as the history"""

self.maxHistorySummaryTokens = maxHistorySummaryTokens
"""the max number of tokens to be used on the history summary"""

self.historyFraction = historyFraction
"""the (approximate) fraction of num_tokens to use for history text before truncating"""

@@ -153,6 +165,10 @@ def context_prompt(self):
def history_prompt(self):
return self.prompts['history']

@property
def history_summary_prompt(self):
return self.prompts['history_summary']

@property
def mode_prompt(self):
return self.prompts['modes'].get(self.mode, '')
@@ -173,4 +189,10 @@

@property
def max_response_tokens(self):
return min(self.maxNumTokens - self.context_tokens - self.history_tokens, self.maxCompletionTokens)
available_tokens = (
self.maxNumTokens - self.maxHistorySummaryTokens -
self.context_tokens - len(self.encoder.encode(self.context_prompt)) -
self.history_tokens - len(self.encoder.encode(self.history_prompt)) -
len(self.encoder.encode(self.question_prompt))
)
return min(available_tokens, self.maxCompletionTokens)
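A back-of-the-envelope illustration of the new response budget with invented numbers (the context/history budgets, prompt overhead, and completion cap below are assumptions, not measurements):

# Illustrative arithmetic only — all values below are made up.
maxNumTokens = 4097             # e.g. gpt-3.5-turbo context window
maxHistorySummaryTokens = 200
context_tokens = 2000           # assumed context budget
history_tokens = 1000           # assumed history budget
prompt_overhead = 150           # assumed encoded size of the three prompts
maxCompletionTokens = 1024      # assumed per-model completion cap

available_tokens = (maxNumTokens - maxHistorySummaryTokens
                    - context_tokens - history_tokens - prompt_overhead)
max_response_tokens = min(available_tokens, maxCompletionTokens)  # min(747, 1024) == 747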
2 changes: 2 additions & 0 deletions web/src/components/chat.tsx
@@ -63,6 +63,8 @@ export const ChatResponse = ({
return <p>Loading: Sending query...</p>;
case "semantic":
return <p>Loading: Performing semantic search...</p>;
case "history":
return <p>Loading: Processing history...</p>;
case "context":
return <p>Loading: Creating context...</p>;
case "prompt":
16 changes: 16 additions & 0 deletions web/src/components/settings.tsx
@@ -122,6 +122,14 @@ export const ChatSettings = ({
max={settings.maxNumTokens}
updater={updateNum("tokensBuffer")}
/>
<NumberInput
field="maxHistorySummaryTokens"
value={settings.maxHistorySummaryTokens}
label="The max number of tokens to use for the history summary"
min="0"
max={settings.maxNumTokens}
updater={updateNum("maxHistorySummaryTokens")}
/>

<SectionHeader text="Prompt options" />
<NumberInput
@@ -178,6 +186,14 @@ export const ChatPrompts = ({

return (
<div className="chat-prompts mx-5 w-[400px] flex-none border-2 p-5 outline-black">
<details>
<summary>History summary prompt</summary>
<TextareaAutosize
className="border-gray w-full border px-1"
value={settings?.prompts?.history_summary}
onChange={updatePrompt("history_summary")}
/>
</details>
<details open>
<summary>Source prompt</summary>
<TextareaAutosize
7 changes: 7 additions & 0 deletions web/src/hooks/useSettings.ts
@@ -21,6 +21,12 @@ const DEFAULT_PROMPTS = {
'Before the question ("Q: "), there will be a history of previous questions and answers. ' +
"These sources only apply to the last question. any sources used in previous answers " +
"are invalid.",
history_summary:
"You are a helpful assistant knowledgeable about AI Alignment and Safety. " +
'Please summarize the following chat history (written after "H:") in one ' +
'sentence so as to put the current questions (written after "Q:") in context. ' +
"Please keep things as terse as possible." +
"\nH:",
question:
"In your answer, please cite any claims you make back to each source " +
"using the format: [a], [b], etc. If you use multiple sources to make a claim " +
@@ -131,6 +137,7 @@ const SETTINGS_PARSERS = {
maxNumTokens: withDefault(MODELS["gpt-3.5-turbo"]?.maxNumTokens),
tokensBuffer: withDefault(50), // the number of tokens to leave as a buffer when calculating remaining tokens
maxHistory: withDefault(10), // the max number of previous items to use as history
maxHistorySummaryTokens: withDefault(200), // the max number of tokens to use in the history summary
historyFraction: withDefault(0.25), // the (approximate) fraction of num_tokens to use for history text before truncating
contextFraction: withDefault(0.5), // the (approximate) fraction of num_tokens to use for context text before truncating
};