Skip to content

Commit

Permalink
feat: collect total token usage (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeizhou-ap authored Sep 19, 2024
1 parent ed8bbbf commit 8139c74
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 7 deletions.
5 changes: 5 additions & 0 deletions src/exchange/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from exchange.moderators.truncate import ContextTruncate
from exchange.providers import Provider, Usage
from exchange.tool import Tool
from exchange.token_usage_collector import _token_usage_collector


def validate_tool_output(output: str) -> None:
Expand Down Expand Up @@ -86,6 +87,7 @@ def generate(self) -> Message:
# `rewrite` above.
# self.moderator.rewrite(self)

_token_usage_collector.collect(self.model, usage)
return message

def reply(self, max_tool_use: int = 128) -> Message:
Expand Down Expand Up @@ -327,3 +329,6 @@ def is_allowed_to_call_llm(self) -> bool:
# Some models will have different requirements than others, so it may be better for
# this to be a required method of the provider instead.
return len(self.messages) > 0 and self.messages[-1].role == "user"

def get_token_usage(self) -> Dict[str, Usage]:
    """Return the collected token usage, summed per model.

    The figures are read from the shared ``_token_usage_collector`` imported
    by this module; no state of this instance is consulted.
    """
    usage_by_model = _token_usage_collector.get_token_usage_group_by_model()
    return usage_by_model
2 changes: 1 addition & 1 deletion src/exchange/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from exchange.tool import Tool


@define
@define(hash=True)
class Usage:
input_tokens: int = field(factory=None)
output_tokens: int = field(default=None)
Expand Down
27 changes: 27 additions & 0 deletions src/exchange/token_usage_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from collections import defaultdict
from typing import Dict

from exchange.providers.base import Usage


class _TokenUsageCollector:
    """Record the token usage of individual calls and report totals per model.

    Each ``collect`` call appends a ``(model, usage)`` pair to ``usage_data``;
    the pairs are only summed when totals are requested.
    """

    def __init__(self) -> None:
        # Raw (model, usage) pairs, in the order they were collected.
        self.usage_data = []

    def collect(self, model: str, usage: Usage) -> None:
        """Store the usage reported for one call made with ``model``."""
        self.usage_data.append((model, usage))

    def get_token_usage_group_by_model(self) -> Dict[str, Usage]:
        """Sum every collected usage per model.

        A ``usage`` that is ``None``, or an individual counter that is
        ``None``, adds nothing to the totals. The model is still listed in
        the result, with zeros for whatever could not be added.

        Returns:
            A plain ``dict`` mapping each collected model name to a new
            ``Usage`` holding the summed counters.
        """
        usage_group_by_model = defaultdict(
            lambda: Usage(input_tokens=0, output_tokens=0, total_tokens=0)
        )
        for model, usage in self.usage_data:
            # Touch the entry before the None check so that a model whose
            # usage was never reported still shows up with zeroed totals.
            usage_by_model = usage_group_by_model[model]
            if usage is None:
                continue
            if usage.input_tokens is not None:
                usage_by_model.input_tokens += usage.input_tokens
            if usage.output_tokens is not None:
                usage_by_model.output_tokens += usage.output_tokens
            if usage.total_tokens is not None:
                usage_by_model.total_tokens += usage.total_tokens
        # Return a regular dict: with the defaultdict itself, a caller's
        # lookup of an unseen model would silently insert and return a
        # zeroed Usage instead of raising KeyError.
        return dict(usage_group_by_model)


_token_usage_collector = _TokenUsageCollector()
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from exchange.providers.base import Usage


@pytest.fixture
def dummy_tool():
    """Fixture that returns a zero-argument function for tests to wrap as a tool."""

    # NOTE(review): the inner function's name and docstring are presumably
    # read when it is passed to Tool.from_function -- keep both unchanged;
    # confirm against Tool.
    def _dummy_tool() -> str:
        """An example tool"""
        return "dummy response"

    return _dummy_tool


@pytest.fixture
def usage_factory():
    """Fixture that returns a builder for ``Usage`` objects.

    The builder defaults to 100 input, 200 output and 300 total tokens; any
    of the three counters can be overridden per call.
    """

    def _build(input_tokens=100, output_tokens=200, total_tokens=300):
        counters = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
        }
        return Usage(**counters)

    return _build
33 changes: 33 additions & 0 deletions tests/test_exchange_collect_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from unittest.mock import MagicMock
from exchange.exchange import Exchange
from exchange.message import Message
from exchange.moderators.passive import PassiveModerator
from exchange.providers.base import Provider
from exchange.tool import Tool
from exchange.token_usage_collector import _TokenUsageCollector

# Model name given to the Exchange under test and expected in the collector call.
MODEL_NAME = "test-model"


def create_exchange(mock_provider, dummy_tool):
    """Build an Exchange around ``mock_provider`` with one tool and no messages."""
    tools = (Tool.from_function(dummy_tool),)
    moderator = PassiveModerator()
    return Exchange(
        provider=mock_provider,
        model=MODEL_NAME,
        system="test-system",
        tools=tools,
        messages=[],
        moderator=moderator,
    )


def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch):
    """generate() hands the provider's usage to the collector exactly once."""
    provider = MagicMock(spec=Provider)
    collector = MagicMock(spec=_TokenUsageCollector)
    expected_usage = usage_factory()
    provider.complete.return_value = (Message.assistant("msg"), expected_usage)
    exchange = create_exchange(provider, dummy_tool)

    # Swap the collector referenced by exchange.exchange for the mock so the
    # call made during generate() can be inspected.
    monkeypatch.setattr("exchange.exchange._token_usage_collector", collector)
    exchange.generate()

    collector.collect.assert_called_once_with(MODEL_NAME, expected_usage)
7 changes: 1 addition & 6 deletions tests/test_exchange_frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,14 @@
from exchange.tool import Tool


# Zero-argument stand-in tool function for this test module.
# NOTE(review): its name and docstring are presumably read when it is wrapped
# as a Tool -- keep both unchanged; confirm against Tool.
def dummy_tool() -> str:
    """An example tool"""
    return "dummy response"


class MockProvider(Provider):
    """Provider stub that gives the same canned answer to every request."""

    def complete(self, model, system, messages, tools=None):
        """Return a fixed assistant message and a Usage built from ``{"total_tokens": 35}``.

        All arguments are ignored.
        """
        reply = Message(role="assistant", content=[Text(text="This is a mock response.")])
        usage = Usage.from_dict({"total_tokens": 35})
        return reply, usage


def test_exchange_immutable():
def test_exchange_immutable(dummy_tool):
# Create an instance of Exchange
provider = MockProvider()
# intentionally setting a list instead of tuple on tools, it should be converted
Expand Down
24 changes: 24 additions & 0 deletions tests/test_token_usage_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from exchange.token_usage_collector import _TokenUsageCollector


def test_collect(usage_factory):
    """Usage collected for the same model is summed; models are kept apart."""
    collected = [
        ("model1", (100, 1000, 1100)),
        ("model1", (200, 2000, 2200)),
        ("model2", (400, 4000, 4400)),
        ("model3", (500, 5000, 5500)),
        ("model3", (600, 6000, 6600)),
    ]
    usage_collector = _TokenUsageCollector()
    for model, counts in collected:
        usage_collector.collect(model, usage_factory(*counts))

    expected = {
        "model1": usage_factory(300, 3000, 3300),
        "model2": usage_factory(400, 4000, 4400),
        "model3": usage_factory(1100, 11000, 12100),
    }
    assert usage_collector.get_token_usage_group_by_model() == expected


def test_collect_with_non_input_or_output_token(usage_factory):
    """Counters that are None are skipped; the corresponding totals stay at zero."""
    usage_collector = _TokenUsageCollector()
    for counts in ((100, None, None), (None, 2000, None)):
        usage_collector.collect("model1", usage_factory(*counts))

    expected = {"model1": usage_factory(100, 2000, 0)}
    assert usage_collector.get_token_usage_group_by_model() == expected

0 comments on commit 8139c74

Please sign in to comment.