-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: collect total token usage (#32)
- Loading branch information
1 parent
ed8bbbf
commit 8139c74
Showing
7 changed files
with
111 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from collections import defaultdict | ||
from typing import Dict | ||
|
||
from exchange.providers.base import Usage | ||
|
||
|
||
class _TokenUsageCollector: | ||
def __init__(self) -> None: | ||
self.usage_data = [] | ||
|
||
def collect(self, model: str, usage: Usage) -> None: | ||
self.usage_data.append((model, usage)) | ||
|
||
def get_token_usage_group_by_model(self) -> Dict[str, Usage]: | ||
usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0)) | ||
for model, usage in self.usage_data: | ||
usage_by_model = usage_group_by_model[model] | ||
if usage is not None and usage.input_tokens is not None: | ||
usage_by_model.input_tokens += usage.input_tokens | ||
if usage is not None and usage.output_tokens is not None: | ||
usage_by_model.output_tokens += usage.output_tokens | ||
if usage is not None and usage.total_tokens is not None: | ||
usage_by_model.total_tokens += usage.total_tokens | ||
return usage_group_by_model | ||
|
||
|
||
# Module-level collector instance, created once when this module is imported.
_token_usage_collector = _TokenUsageCollector()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
|
||
from exchange.providers.base import Usage | ||
|
||
|
||
@pytest.fixture
def dummy_tool():
    """Provide a trivial zero-argument callable for tests that need a tool."""

    # NOTE(review): the inner function's name and docstring are presumably read
    # by whatever wraps it as a tool (e.g. Tool.from_function) — confirm before
    # renaming or rewording either.
    def _dummy_tool() -> str:
        """An example tool"""
        return "dummy response"

    return _dummy_tool
|
||
|
||
@pytest.fixture
def usage_factory():
    """Provide a builder that creates ``Usage`` objects for tests.

    The builder accepts the three token counts positionally or by keyword and
    falls back to 100 / 200 / 300 when a count is not supplied.
    """

    def _build_usage(input_tokens=100, output_tokens=200, total_tokens=300):
        # Pass the counts by keyword so they cannot be mixed up.
        token_counts = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
        }
        return Usage(**token_counts)

    return _build_usage
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from unittest.mock import MagicMock | ||
from exchange.exchange import Exchange | ||
from exchange.message import Message | ||
from exchange.moderators.passive import PassiveModerator | ||
from exchange.providers.base import Provider | ||
from exchange.tool import Tool | ||
from exchange.token_usage_collector import _TokenUsageCollector | ||
|
||
# Model name given to the Exchange under test and expected in the collect() call.
MODEL_NAME = "test-model"
|
||
|
||
def create_exchange(mock_provider, dummy_tool):
    """Build an ``Exchange`` for tests with one tool and no prior messages.

    Args:
        mock_provider: Provider object handed to the exchange.
        dummy_tool: Plain function that is wrapped into the exchange's only tool.

    Returns:
        An ``Exchange`` using ``MODEL_NAME`` and a ``PassiveModerator``.
    """
    single_tool = Tool.from_function(dummy_tool)
    exchange_kwargs = {
        "provider": mock_provider,
        "model": MODEL_NAME,
        "system": "test-system",
        "tools": (single_tool,),
        "messages": [],
        "moderator": PassiveModerator(),
    }
    return Exchange(**exchange_kwargs)
|
||
|
||
def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch):
    """``generate`` hands the provider's usage to the collector exactly once."""
    expected_usage = usage_factory()

    # The provider answers every completion with a fixed message and usage.
    provider = MagicMock(spec=Provider)
    provider.complete.return_value = (Message.assistant("msg"), expected_usage)

    collector = MagicMock(spec=_TokenUsageCollector)
    exchange = create_exchange(provider, dummy_tool)

    # Replace the collector that exchange.exchange refers to, then run.
    monkeypatch.setattr("exchange.exchange._token_usage_collector", collector)
    exchange.generate()

    collector.collect.assert_called_once_with(MODEL_NAME, expected_usage)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from exchange.token_usage_collector import _TokenUsageCollector | ||
|
||
|
||
def test_collect(usage_factory):
    """Usages collected for the same model are summed field by field."""
    recorded = [
        ("model1", (100, 1000, 1100)),
        ("model1", (200, 2000, 2200)),
        ("model2", (400, 4000, 4400)),
        ("model3", (500, 5000, 5500)),
        ("model3", (600, 6000, 6600)),
    ]
    collector = _TokenUsageCollector()
    for model, counts in recorded:
        collector.collect(model, usage_factory(*counts))

    expected = {
        "model1": usage_factory(300, 3000, 3300),
        "model2": usage_factory(400, 4000, 4400),
        "model3": usage_factory(1100, 11000, 12100),
    }
    assert collector.get_token_usage_group_by_model() == expected
|
||
|
||
def test_collect_with_non_input_or_output_token(usage_factory):
    """Token counts that are ``None`` are ignored when totals are summed."""
    collector = _TokenUsageCollector()
    # Each usage carries only one real count; the other fields are None.
    partial_counts = [(100, None, None), (None, 2000, None)]
    for counts in partial_counts:
        collector.collect("model1", usage_factory(*counts))

    expected = {"model1": usage_factory(100, 2000, 0)}
    assert collector.get_token_usage_group_by_model() == expected