From 15e64493ef1e707c108889ecaf26bc512e50c968 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Thu, 19 Sep 2024 08:58:02 +1000 Subject: [PATCH] used _TokenUsageCollector as a global variable to track the utoken usage --- src/exchange/exchange.py | 11 ++++---- src/exchange/token_usage_collector.py | 29 +++++++++++++------- tests/test_exchange_collect_usage.py | 33 +++++------------------ tests/test_token_usage_collector.py | 39 +++++++++++++-------------- 4 files changed, 51 insertions(+), 61 deletions(-) diff --git a/src/exchange/exchange.py b/src/exchange/exchange.py index 804ff02..c0c5032 100644 --- a/src/exchange/exchange.py +++ b/src/exchange/exchange.py @@ -13,8 +13,9 @@ from exchange.moderators.truncate import ContextTruncate from exchange.providers import Provider, Usage from exchange.tool import Tool -from exchange.token_usage_collector import TokenUsageCollector +from exchange.token_usage_collector import _TokenUsageCollector, TokenUsage +_token_usage_collector: _TokenUsageCollector = _TokenUsageCollector() def validate_tool_output(output: str) -> None: """Validate tool output for the given model""" @@ -44,7 +45,6 @@ class Exchange: tools: Tuple[Tool] = field(factory=tuple, converter=tuple) messages: List[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData) - token_usage_collector: TokenUsageCollector = field(factory=TokenUsageCollector) @property def _toolmap(self) -> Mapping[str, Tool]: @@ -88,9 +88,7 @@ def generate(self) -> Message: # `rewrite` above. # self.moderator.rewrite(self) - total_tokens = usage.total_tokens if usage.total_tokens is not None else 0 - if self.token_usage_collector and total_tokens > 0: - self.token_usage_collector.collect(self.model, usage) + _token_usage_collector.collect(self.model, usage) return message def reply(self, max_tool_use: int = 128) -> Message: @@ -332,3 +330,6 @@ def is_allowed_to_call_llm(self) -> bool: # Some models will have different requirements than others, so it may be better for # this to be a required method of the provider instead. return len(self.messages) > 0 and self.messages[-1].role == "user" + + def get_token_usage(self) -> List[TokenUsage]: + return _token_usage_collector.get_token_usage_group_by_model() diff --git a/src/exchange/token_usage_collector.py b/src/exchange/token_usage_collector.py index a0c855a..6cc4835 100644 --- a/src/exchange/token_usage_collector.py +++ b/src/exchange/token_usage_collector.py @@ -1,23 +1,34 @@ from collections import defaultdict +from dataclasses import dataclass import queue -from typing import Dict +from typing import List from exchange.providers.base import Usage +@dataclass +class TokenUsage: + model: str + input_tokens: int + output_tokens: int -class TokenUsageCollector: +class _TokenUsageCollector: def __init__(self) -> None: # use thread-safe queue to store usage data from multiple threads # as data may be collected from multiple threads self.usage_data_queue = queue.Queue() def collect(self, model: str, usage: Usage) -> None: - self.usage_data_queue.put((model, usage.total_tokens)) + self.usage_data_queue.put((model, usage.input_tokens, usage.output_tokens)) - def get_token_count_group_by_model(self) -> Dict[str, int]: + def get_token_usage_group_by_model(self) -> List[TokenUsage]: all_usage_data = list(self.usage_data_queue.queue) - token_count_group_by_model = defaultdict(lambda: 0) - for model, total_tokens in all_usage_data: - if total_tokens is not None: - token_count_group_by_model[model] += total_tokens - return token_count_group_by_model + token_count_group_by_model = defaultdict(lambda: [0, 0]) + for model, input_tokens, output_tokens in all_usage_data: + if input_tokens is not None: + token_count_group_by_model[model][0] += input_tokens + if output_tokens is not None: + token_count_group_by_model[model][1] += output_tokens + token_usage_list = [ + TokenUsage(model, input_tokens, output_tokens) + for model, (input_tokens, output_tokens) in token_count_group_by_model.items()] + return token_usage_list \ No newline at end of file diff --git a/tests/test_exchange_collect_usage.py b/tests/test_exchange_collect_usage.py index e645bde..4f172e3 100644 --- a/tests/test_exchange_collect_usage.py +++ b/tests/test_exchange_collect_usage.py @@ -4,12 +4,11 @@ from exchange.moderators.passive import PassiveModerator from exchange.providers.base import Provider from exchange.tool import Tool -from exchange.token_usage_collector import TokenUsageCollector +from exchange.token_usage_collector import _TokenUsageCollector MODEL_NAME = "test-model" - -def create_exchange(mock_provider, mock_usage_collector, dummy_tool): +def create_exchange(mock_provider, dummy_tool): return Exchange( provider=mock_provider, model=MODEL_NAME, @@ -17,39 +16,19 @@ def create_exchange(mock_provider, mock_usage_collector, dummy_tool): tools=(Tool.from_function(dummy_tool),), messages=[], moderator=PassiveModerator(), - token_usage_collector=mock_usage_collector, ) -def test_exchange_generate_collect_usage(usage_factory, dummy_tool): +def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch): mock_provider = MagicMock(spec=Provider) - mock_usage_collector = MagicMock(spec=TokenUsageCollector) + mock_usage_collector = MagicMock(spec=_TokenUsageCollector) usage = usage_factory() mock_provider.complete.return_value = (Message.assistant("msg"), usage) - exchange = create_exchange(mock_provider, mock_usage_collector, dummy_tool) + exchange = create_exchange(mock_provider, dummy_tool) + monkeypatch.setattr('exchange.exchange._token_usage_collector', mock_usage_collector) exchange.generate() mock_usage_collector.collect.assert_called_once_with(MODEL_NAME, usage) -def test_exchange_generate_not_collect_usage_when_total_tokens_is_none(usage_factory, dummy_tool): - mock_provider = MagicMock(spec=Provider) - mock_usage_collector = MagicMock(spec=TokenUsageCollector) - mock_provider.complete.return_value = (Message.assistant("msg"), usage_factory(total_tokens=None)) - exchange = create_exchange(mock_provider, mock_usage_collector, dummy_tool) - - exchange.generate() - - mock_usage_collector.collect.assert_not_called() - - -def test_exchange_generate_not_collect_usage_when_total_tokens_is_0(usage_factory, dummy_tool): - mock_provider = MagicMock(spec=Provider) - mock_usage_collector = MagicMock(spec=TokenUsageCollector) - mock_provider.complete.return_value = (Message.assistant("msg"), usage_factory(total_tokens=0)) - exchange = create_exchange(mock_provider, mock_usage_collector, dummy_tool) - - exchange.generate() - - mock_usage_collector.collect.assert_not_called() diff --git a/tests/test_token_usage_collector.py b/tests/test_token_usage_collector.py index b1aafd4..a9d35da 100644 --- a/tests/test_token_usage_collector.py +++ b/tests/test_token_usage_collector.py @@ -1,24 +1,23 @@ -from exchange.token_usage_collector import TokenUsageCollector - +from exchange.token_usage_collector import _TokenUsageCollector, TokenUsage def test_collect(usage_factory): - usage_collector = TokenUsageCollector() - usage_collector.collect("model1", usage_factory(total_tokens=100)) - usage_collector.collect("model1", usage_factory(total_tokens=200)) - usage_collector.collect("model2", usage_factory(total_tokens=400)) - usage_collector.collect("model3", usage_factory(total_tokens=500)) - usage_collector.collect("model3", usage_factory(total_tokens=600)) - assert usage_collector.get_token_count_group_by_model() == { - "model1": 300, - "model2": 400, - "model3": 1100, - } + usage_collector = _TokenUsageCollector() + usage_collector.collect("model1", usage_factory(input_tokens=100, output_tokens=1000)) + usage_collector.collect("model1", usage_factory(input_tokens=200, output_tokens=2000)) + usage_collector.collect("model2", usage_factory(input_tokens=400, output_tokens=4000)) + usage_collector.collect("model3", usage_factory(input_tokens=500, output_tokens=5000)) + usage_collector.collect("model3", usage_factory(input_tokens=600, output_tokens=6000)) + assert usage_collector.get_token_usage_group_by_model() == [ + TokenUsage(model="model1", input_tokens=300, output_tokens=3000), + TokenUsage(model="model2", input_tokens=400, output_tokens=4000), + TokenUsage(model="model3", input_tokens=1100, output_tokens=11000), + ] -def test_collect_with_non_total_token(usage_factory): - usage_collector = TokenUsageCollector() - usage_collector.collect("model1", usage_factory(total_tokens=100)) - usage_collector.collect("model1", usage_factory(total_tokens=None)) - assert usage_collector.get_token_count_group_by_model() == { - "model1": 100, - } +def test_collect_with_non_input_or_output_token(usage_factory): + usage_collector = _TokenUsageCollector() + usage_collector.collect("model1", usage_factory(input_tokens=100, output_tokens=None)) + usage_collector.collect("model1", usage_factory(input_tokens=None, output_tokens=2000)) + assert usage_collector.get_token_usage_group_by_model() == [ + TokenUsage(model="model1", input_tokens=100, output_tokens=2000), + ]