From c91a2ad8b8c263dfae0ad43bb5600e109c2e7387 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:21:40 -0700 Subject: [PATCH 1/6] feat: offset for summarizer moderator --- src/exchange/moderators/summarizer.py | 23 +++++++++++++++++++++-- src/exchange/moderators/truncate.py | 9 ++++++--- tests/test_summarizer.py | 16 +++++++--------- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/exchange/moderators/summarizer.py b/src/exchange/moderators/summarizer.py index 7e2dd55..b05e31a 100644 --- a/src/exchange/moderators/summarizer.py +++ b/src/exchange/moderators/summarizer.py @@ -3,19 +3,38 @@ from exchange import Message from exchange.checkpoint import CheckpointData from exchange.moderators import ContextTruncate, PassiveModerator +from exchange.moderators.truncate import MAX_TOKENS + +# this offset is used to prevent summarization from happening too frequently +SUMMARIZATION_OFFSET = 30000 class ContextSummarizer(ContextTruncate): + def __init__( + self, + model: str = "gpt-4o-mini", + max_tokens: int = MAX_TOKENS, + summarization_offset: int = SUMMARIZATION_OFFSET, + ) -> None: + super().__init__(model=model, max_tokens=max_tokens) + self.summarization_offset = summarization_offset + def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 """Summarize the context history up to the last few messages in the exchange""" self._update_system_prompt_token_count(exchange) - # TODO: use an offset for summarization + # note: we don't use the num_token_to_keep (defined below) here, because we only want to trigger + # summarization when we pass the max_token threshold. at that point, we will summarize more than + # we need to (using the offset), so we don't need to summarize on every ex.generate(...) call if exchange.checkpoint_data.total_token_count < self.max_tokens: return - messages_to_summarize = self._get_messages_to_remove(exchange) + assert self.summarization_offset < self.max_tokens + + num_tokens_to_keep = self.max_tokens - self.summarization_offset + + messages_to_summarize = self._get_messages_to_remove(exchange, max_tokens=num_tokens_to_keep) num_messages_to_remove = len(messages_to_summarize) # the llm will throw an error if the last message isn't a user message diff --git a/src/exchange/moderators/truncate.py b/src/exchange/moderators/truncate.py index 41115f6..f0c5059 100644 --- a/src/exchange/moderators/truncate.py +++ b/src/exchange/moderators/truncate.py @@ -62,15 +62,18 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None: exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count exchange.checkpoint_data.total_token_count += self.system_prompt_token_count - def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]: + def _get_messages_to_remove(self, exchange: Exchange, max_tokens: Optional[int] = None) -> List[Message]: + if not max_tokens: + max_tokens = self.max_tokens + # this keeps all the messages/checkpoints throwaway_exchange = exchange.replace( moderator=PassiveModerator(), ) - # get the messages that we want to remove + # get the messages that we want to summarize messages_to_remove = [] - while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens: + while throwaway_exchange.checkpoint_data.total_token_count > max_tokens: _, messages = throwaway_exchange.pop_first_checkpoint() messages_to_remove.extend(messages) diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index fa72819..ddac0a2 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -49,7 +49,7 @@ def exchange_instance(): @pytest.fixture def summarizer_instance(): - return ContextSummarizer(max_tokens=300) + return ContextSummarizer(max_tokens=300, summarization_offset=100) def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer): @@ -199,9 +199,7 @@ def conversation_exchange_instance(): provider=AnotherMockProvider(), model="test-model", system="test-system", - moderator=ContextSummarizer(max_tokens=300), - # TODO: make it work with an offset so we don't have to send off requests basically - # at every generate step + moderator=ContextSummarizer(max_tokens=300, summarization_offset=100), ) return ex @@ -215,11 +213,11 @@ def test_summarizer_generic_conversation(conversation_exchange_instance: Exchang if message.text != "Summary message here": i += 2 checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints - assert conversation_exchange_instance.checkpoint_data.total_token_count == 570 - assert len(checkpoints) == 10 - assert len(conversation_exchange_instance.messages) == 10 - assert checkpoints[0].start_index == 20 - assert checkpoints[0].end_index == 20 + assert conversation_exchange_instance.checkpoint_data.total_token_count == 412 + assert len(checkpoints) == 5 + assert len(conversation_exchange_instance.messages) == 5 + assert checkpoints[0].start_index == 25 + assert checkpoints[0].end_index == 25 assert checkpoints[-1].start_index == 29 assert checkpoints[-1].end_index == 29 assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20 From 911cd645498c5fe801b44046f103024f84190025 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:41:45 -0700 Subject: [PATCH 2/6] feat: always stay below token count --- src/exchange/exchange.py | 7 ++----- tests/test_summarizer.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index feece4f..008ee7a 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -81,11 +81,8 @@ def generate(self) -> Message: self.add(message) self.add_checkpoints_from_usage(usage) # this has to come after adding the response - # TODO: also call `rewrite` here, as this will make our - # messages *consistently* below the token limit. this currently - # is not the case because we could append a large message after calling - # `rewrite` above. - # self.moderator.rewrite(self) + # also call `rewrite` here, as this will make our messages are consistently below the token limit. + self.moderator.rewrite(self) _token_usage_collector.collect(self.model, usage) return message diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index ddac0a2..701a390 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -149,7 +149,6 @@ def complete(self, model, system, messages, tools): system_prompt_tokens = 100 input_token_count = system_prompt_tokens - message = self.sequence[self.current_index] if self.summarize_next: text = "Summary message here" self.summarize_next = False @@ -160,6 +159,7 @@ def complete(self, model, system, messages, tools): output_tokens=len(text) * 2, total_tokens=40 + len(text) * 2, ) + message = self.sequence[self.current_index] if len(messages) > 0 and type(messages[0].content[0]) is ToolResult: raise ValueError("ToolResult should not be the first message") @@ -213,13 +213,13 @@ def test_summarizer_generic_conversation(conversation_exchange_instance: Exchang if message.text != "Summary message here": i += 2 checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints - assert conversation_exchange_instance.checkpoint_data.total_token_count == 412 - assert len(checkpoints) == 5 - assert len(conversation_exchange_instance.messages) == 5 - assert checkpoints[0].start_index == 25 - assert checkpoints[0].end_index == 25 + assert conversation_exchange_instance.checkpoint_data.total_token_count == 148 + assert len(checkpoints) == 4 + assert len(conversation_exchange_instance.messages) == 4 + assert checkpoints[0].start_index == 26 + assert checkpoints[0].end_index == 26 assert checkpoints[-1].start_index == 29 assert checkpoints[-1].end_index == 29 - assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20 - assert conversation_exchange_instance.provider.summarized_count == 12 + assert conversation_exchange_instance.checkpoint_data.message_index_offset == 26 + assert conversation_exchange_instance.provider.summarized_count == 10 assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 From 3a4be82966c07222f5ad704ed64249f945827f05 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:42:53 -0700 Subject: [PATCH 3/6] fix: increase token limit since we will now consistently stay below it --- src/exchange/moderators/truncate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/exchange/moderators/truncate.py b/src/exchange/moderators/truncate.py index f0c5059..e80b8d2 100644 --- a/src/exchange/moderators/truncate.py +++ b/src/exchange/moderators/truncate.py @@ -14,7 +14,7 @@ # so once we get to this token size the token count will exceed this # by a little bit. # TODO: make this configurable for each provider -MAX_TOKENS = 100000 +MAX_TOKENS = 128000 class ContextTruncate(Moderator): From 148136e4a856be4028bc3ac6fe6e6f0d5aae6314 Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:46:54 -0700 Subject: [PATCH 4/6] feat: enable summarize moderator as the default moderator --- src/exchange/exchange.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index 008ee7a..1150d69 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -10,7 +10,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.moderators import Moderator -from exchange.moderators.truncate import ContextTruncate +from exchange.moderators.summarizer import ContextSummarizer from exchange.providers import Provider, Usage from exchange.tool import Tool from exchange.token_usage_collector import _token_usage_collector @@ -40,7 +40,7 @@ class Exchange: provider: Provider model: str system: str - moderator: Moderator = field(default=ContextTruncate()) + moderator: Moderator = field(default=ContextSummarizer()) tools: Tuple[Tool] = field(factory=tuple, converter=tuple) messages: List[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData) From 42933d3e6aee691a4d188a5fa0d7435e76f9705f Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 23 Sep 2024 12:52:23 -0700 Subject: [PATCH 5/6] chore: update check in `test_truncate` to ensure no regressions --- tests/test_truncate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_truncate.py b/tests/test_truncate.py index 3875303..24ae727 100644 --- a/tests/test_truncate.py +++ b/tests/test_truncate.py @@ -128,5 +128,5 @@ def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchan if message.text != "Summary message here": i += 2 # ensure the total token count is not anything exhorbitant - assert conversation_exchange_instance.checkpoint_data.total_token_count < 700 + assert conversation_exchange_instance.checkpoint_data.total_token_count < 500 assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 From 7a89d1eb1829ee07547869591745e018f8be1885 Mon Sep 17 00:00:00 2001 From: Mic Neale Date: Wed, 25 Sep 2024 10:14:11 +1000 Subject: [PATCH 6/6] tiny tweak to see if this helps --- src/exchange/exchange.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index 1150d69..008ee7a 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -10,7 +10,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.moderators import Moderator -from exchange.moderators.summarizer import ContextSummarizer +from exchange.moderators.truncate import ContextTruncate from exchange.providers import Provider, Usage from exchange.tool import Tool from exchange.token_usage_collector import _token_usage_collector @@ -40,7 +40,7 @@ class Exchange: provider: Provider model: str system: str - moderator: Moderator = field(default=ContextSummarizer()) + moderator: Moderator = field(default=ContextTruncate()) tools: Tuple[Tool] = field(factory=tuple, converter=tuple) messages: List[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData)