fix: Mic/alt fix todos #58

Closed
wants to merge 6 commits
7 changes: 2 additions & 5 deletions src/exchange/exchange.py
@@ -81,11 +81,8 @@ def generate(self) -> Message:
         self.add(message)
         self.add_checkpoints_from_usage(usage)  # this has to come after adding the response

-        # TODO: also call `rewrite` here, as this will make our
-        # messages *consistently* below the token limit. this currently
-        # is not the case because we could append a large message after calling
-        # `rewrite` above.
-        # self.moderator.rewrite(self)
+        # also call `rewrite` here to keep our messages consistently below the token limit
+        self.moderator.rewrite(self)

         _token_usage_collector.collect(self.model, usage)
         return message
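Why the ordering matters: `rewrite` previously ran only before the provider call, so a large response appended afterwards could leave the exchange over the token limit until the next `generate`. A minimal, self-contained sketch of the pattern with toy stand-ins (not the project's real Exchange/Moderator classes):

class ToyModerator:
    """Stand-in for ContextTruncate: drops oldest entries over the limit."""

    def __init__(self, max_tokens: int) -> None:
        self.max_tokens = max_tokens

    def rewrite(self, token_counts: list) -> None:
        # drop the oldest per-message token counts until under the limit
        while sum(token_counts) > self.max_tokens:
            token_counts.pop(0)


def generate(token_counts: list, moderator: ToyModerator, response_tokens: int) -> None:
    token_counts.append(response_tokens)
    # rewriting *after* the response is appended restores the invariant on
    # every call, instead of leaving a large response over the limit until
    # the next generate()
    moderator.rewrite(token_counts)


history = [40_000, 50_000]
generate(history, ToyModerator(max_tokens=100_000), response_tokens=30_000)
assert sum(history) <= 100_000  # holds immediately after generate()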
23 changes: 21 additions & 2 deletions src/exchange/moderators/summarizer.py
@@ -3,19 +3,38 @@
 from exchange import Message
 from exchange.checkpoint import CheckpointData
 from exchange.moderators import ContextTruncate, PassiveModerator
+from exchange.moderators.truncate import MAX_TOKENS
+
+# this offset is used to prevent summarization from happening too frequently
+SUMMARIZATION_OFFSET = 30000


 class ContextSummarizer(ContextTruncate):
+    def __init__(
+        self,
+        model: str = "gpt-4o-mini",
+        max_tokens: int = MAX_TOKENS,
+        summarization_offset: int = SUMMARIZATION_OFFSET,
+    ) -> None:
+        super().__init__(model=model, max_tokens=max_tokens)
+        self.summarization_offset = summarization_offset
+
     def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None:  # noqa: F821
         """Summarize the context history up to the last few messages in the exchange"""

         self._update_system_prompt_token_count(exchange)

-        # TODO: use an offset for summarization
+        # note: we don't use num_tokens_to_keep (defined below) here, because we only want to trigger
+        # summarization when we pass the max_tokens threshold. at that point, we will summarize more than
+        # we need to (using the offset), so we don't need to summarize on every ex.generate(...) call
         if exchange.checkpoint_data.total_token_count < self.max_tokens:
             return

-        messages_to_summarize = self._get_messages_to_remove(exchange)
+        assert self.summarization_offset < self.max_tokens
+
+        num_tokens_to_keep = self.max_tokens - self.summarization_offset
+
+        messages_to_summarize = self._get_messages_to_remove(exchange, max_tokens=num_tokens_to_keep)
         num_messages_to_remove = len(messages_to_summarize)

         # the llm will throw an error if the last message isn't a user message
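The offset turns summarization into a hysteresis loop: nothing happens until the total crosses `max_tokens`, and the summarizer then trims all the way down to `max_tokens - summarization_offset`, so several subsequent `generate` calls complete without another summarization request. A self-contained sketch of the two thresholds, using the test fixture's numbers (max_tokens=300, summarization_offset=100); the summary's own token cost is a placeholder:

from typing import List


def maybe_summarize(token_counts: List[int], max_tokens: int = 300,
                    summarization_offset: int = 100) -> List[int]:
    if sum(token_counts) < max_tokens:
        return token_counts  # below the trigger threshold: do nothing
    assert summarization_offset < max_tokens
    num_tokens_to_keep = max_tokens - summarization_offset
    while sum(token_counts) > num_tokens_to_keep:
        token_counts.pop(0)  # the real moderator folds these into a summary
    summary_tokens = 20  # placeholder cost of the generated summary message
    return [summary_tokens] + token_counts


history = [120, 120, 120]           # 360 tokens: over the 300-token trigger
history = maybe_summarize(history)  # trims down to the 200-token keep budget
assert sum(history) <= 200 + 20     # keep budget plus the summary itself
# the next call is a no-op until the total crosses 300 again
assert maybe_summarize(history) == history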
11 changes: 7 additions & 4 deletions src/exchange/moderators/truncate.py
@@ -14,7 +14,7 @@
 # so once we get to this token size the token count will exceed this
 # by a little bit.
 # TODO: make this configurable for each provider
-MAX_TOKENS = 100000
+MAX_TOKENS = 128000


 class ContextTruncate(Moderator):
@@ -62,15 +62,18 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None:
         exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
         exchange.checkpoint_data.total_token_count += self.system_prompt_token_count

-    def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
+    def _get_messages_to_remove(self, exchange: Exchange, max_tokens: Optional[int] = None) -> List[Message]:
+        if not max_tokens:
+            max_tokens = self.max_tokens
+
         # this keeps all the messages/checkpoints
         throwaway_exchange = exchange.replace(
             moderator=PassiveModerator(),
         )

-        # get the messages that we want to remove
+        # get the messages that we want to summarize
         messages_to_remove = []
-        while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens:
+        while throwaway_exchange.checkpoint_data.total_token_count > max_tokens:
             _, messages = throwaway_exchange.pop_first_checkpoint()
             messages_to_remove.extend(messages)
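`_get_messages_to_remove` works on a throwaway copy of the exchange (with a PassiveModerator, so popping checkpoints does not recursively trigger moderation) and pops whole checkpoints until the remaining total fits the budget. A sketch with a toy checkpoint type standing in for the real one; max_tokens=None falls back to the instance default, matching the new signature above:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Checkpoint:
    token_count: int
    messages: List[str]


def get_messages_to_remove(checkpoints: List[Checkpoint],
                           default_max_tokens: int,
                           max_tokens: Optional[int] = None) -> List[str]:
    if not max_tokens:
        max_tokens = default_max_tokens
    # work on a copy so the caller's checkpoints survive, the way the
    # throwaway exchange keeps the real one intact
    remaining = list(checkpoints)
    total = sum(c.token_count for c in remaining)
    messages_to_remove: List[str] = []
    while total > max_tokens and remaining:
        first = remaining.pop(0)  # analogous to pop_first_checkpoint()
        total -= first.token_count
        messages_to_remove.extend(first.messages)
    return messages_to_remove


cps = [Checkpoint(150, ["a", "b"]), Checkpoint(150, ["c"]), Checkpoint(100, ["d"])]
assert get_messages_to_remove(cps, default_max_tokens=300) == ["a", "b"]
# passing a smaller budget removes more, mirroring num_tokens_to_keep above
assert get_messages_to_remove(cps, default_max_tokens=300, max_tokens=200) == ["a", "b", "c"]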
22 changes: 10 additions & 12 deletions tests/test_summarizer.py
@@ -49,7 +49,7 @@ def exchange_instance():

 @pytest.fixture
 def summarizer_instance():
-    return ContextSummarizer(max_tokens=300)
+    return ContextSummarizer(max_tokens=300, summarization_offset=100)


 def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer):
@@ -149,7 +149,6 @@ def complete(self, model, system, messages, tools):
         system_prompt_tokens = 100
         input_token_count = system_prompt_tokens

-        message = self.sequence[self.current_index]
         if self.summarize_next:
             text = "Summary message here"
             self.summarize_next = False
@@ -160,6 +159,7 @@
                 output_tokens=len(text) * 2,
                 total_tokens=40 + len(text) * 2,
             )
+        message = self.sequence[self.current_index]

         if len(messages) > 0 and type(messages[0].content[0]) is ToolResult:
             raise ValueError("ToolResult should not be the first message")
@@ -199,9 +199,7 @@ def conversation_exchange_instance():
         provider=AnotherMockProvider(),
         model="test-model",
         system="test-system",
-        moderator=ContextSummarizer(max_tokens=300),
-        # TODO: make it work with an offset so we don't have to send off requests basically
-        # at every generate step
+        moderator=ContextSummarizer(max_tokens=300, summarization_offset=100),
     )
     return ex

@@ -215,13 +213,13 @@ def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange):
         if message.text != "Summary message here":
             i += 2
     checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints
-    assert conversation_exchange_instance.checkpoint_data.total_token_count == 570
-    assert len(checkpoints) == 10
-    assert len(conversation_exchange_instance.messages) == 10
-    assert checkpoints[0].start_index == 20
-    assert checkpoints[0].end_index == 20
+    assert conversation_exchange_instance.checkpoint_data.total_token_count == 148
+    assert len(checkpoints) == 4
+    assert len(conversation_exchange_instance.messages) == 4
+    assert checkpoints[0].start_index == 26
+    assert checkpoints[0].end_index == 26
     assert checkpoints[-1].start_index == 29
     assert checkpoints[-1].end_index == 29
-    assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20
-    assert conversation_exchange_instance.provider.summarized_count == 12
+    assert conversation_exchange_instance.checkpoint_data.message_index_offset == 26
+    assert conversation_exchange_instance.provider.summarized_count == 10
     assert conversation_exchange_instance.moderator.system_prompt_token_count == 100
2 changes: 1 addition & 1 deletion tests/test_truncate.py
@@ -128,5 +128,5 @@ def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange):
         if message.text != "Summary message here":
             i += 2
     # ensure the total token count is not anything exorbitant
-    assert conversation_exchange_instance.checkpoint_data.total_token_count < 700
+    assert conversation_exchange_instance.checkpoint_data.total_token_count < 500
     assert conversation_exchange_instance.moderator.system_prompt_token_count == 100