Skip to content

Commit

Permalink
Use _TokenUsageCollector as a global variable to track token usage
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeizhou-ap committed Sep 18, 2024
1 parent fa04d35 commit 15e6449
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 61 deletions.
11 changes: 6 additions & 5 deletions src/exchange/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from exchange.moderators.truncate import ContextTruncate
from exchange.providers import Provider, Usage
from exchange.tool import Tool
from exchange.token_usage_collector import TokenUsageCollector
from exchange.token_usage_collector import _TokenUsageCollector, TokenUsage

_token_usage_collector: _TokenUsageCollector = _TokenUsageCollector()

def validate_tool_output(output: str) -> None:
"""Validate tool output for the given model"""
Expand Down Expand Up @@ -44,7 +45,6 @@ class Exchange:
tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
messages: List[Message] = field(factory=list)
checkpoint_data: CheckpointData = field(factory=CheckpointData)
token_usage_collector: TokenUsageCollector = field(factory=TokenUsageCollector)

@property
def _toolmap(self) -> Mapping[str, Tool]:
Expand Down Expand Up @@ -88,9 +88,7 @@ def generate(self) -> Message:
# `rewrite` above.
# self.moderator.rewrite(self)

total_tokens = usage.total_tokens if usage.total_tokens is not None else 0
if self.token_usage_collector and total_tokens > 0:
self.token_usage_collector.collect(self.model, usage)
_token_usage_collector.collect(self.model, usage)
return message

def reply(self, max_tool_use: int = 128) -> Message:
Expand Down Expand Up @@ -332,3 +330,6 @@ def is_allowed_to_call_llm(self) -> bool:
# Some models will have different requirements than others, so it may be better for
# this to be a required method of the provider instead.
return len(self.messages) > 0 and self.messages[-1].role == "user"

def get_token_usage(self) -> List[TokenUsage]:
return _token_usage_collector.get_token_usage_group_by_model()
29 changes: 20 additions & 9 deletions src/exchange/token_usage_collector.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
from collections import defaultdict
from dataclasses import dataclass
import queue
from typing import Dict
from typing import List

from exchange.providers.base import Usage

@dataclass
class TokenUsage:
model: str
input_tokens: int
output_tokens: int

class TokenUsageCollector:
class _TokenUsageCollector:
def __init__(self) -> None:
# Use a thread-safe queue to store usage data,
# since it may be collected from multiple threads.
self.usage_data_queue = queue.Queue()

def collect(self, model: str, usage: Usage) -> None:
self.usage_data_queue.put((model, usage.total_tokens))
self.usage_data_queue.put((model, usage.input_tokens, usage.output_tokens))

def get_token_count_group_by_model(self) -> Dict[str, int]:
def get_token_usage_group_by_model(self) -> List[TokenUsage]:
all_usage_data = list(self.usage_data_queue.queue)
token_count_group_by_model = defaultdict(lambda: 0)
for model, total_tokens in all_usage_data:
if total_tokens is not None:
token_count_group_by_model[model] += total_tokens
return token_count_group_by_model
token_count_group_by_model = defaultdict(lambda: [0, 0])
for model, input_tokens, output_tokens in all_usage_data:
if input_tokens is not None:
token_count_group_by_model[model][0] += input_tokens
if output_tokens is not None:
token_count_group_by_model[model][1] += output_tokens
token_usage_list = [
TokenUsage(model, input_tokens, output_tokens)
for model, (input_tokens, output_tokens) in token_count_group_by_model.items()]
return token_usage_list
33 changes: 6 additions & 27 deletions tests/test_exchange_collect_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,31 @@
from exchange.moderators.passive import PassiveModerator
from exchange.providers.base import Provider
from exchange.tool import Tool
from exchange.token_usage_collector import TokenUsageCollector
from exchange.token_usage_collector import _TokenUsageCollector

MODEL_NAME = "test-model"


def create_exchange(mock_provider, mock_usage_collector, dummy_tool):
def create_exchange(mock_provider, dummy_tool):
return Exchange(
provider=mock_provider,
model=MODEL_NAME,
system="test-system",
tools=(Tool.from_function(dummy_tool),),
messages=[],
moderator=PassiveModerator(),
token_usage_collector=mock_usage_collector,
)


def test_exchange_generate_collect_usage(usage_factory, dummy_tool):
def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch):
mock_provider = MagicMock(spec=Provider)
mock_usage_collector = MagicMock(spec=TokenUsageCollector)
mock_usage_collector = MagicMock(spec=_TokenUsageCollector)
usage = usage_factory()
mock_provider.complete.return_value = (Message.assistant("msg"), usage)
exchange = create_exchange(mock_provider, mock_usage_collector, dummy_tool)
exchange = create_exchange(mock_provider, dummy_tool)

monkeypatch.setattr('exchange.exchange._token_usage_collector', mock_usage_collector)
exchange.generate()

mock_usage_collector.collect.assert_called_once_with(MODEL_NAME, usage)


def test_exchange_generate_not_collect_usage_when_total_tokens_is_none(usage_factory, dummy_tool):
mock_provider = MagicMock(spec=Provider)
mock_usage_collector = MagicMock(spec=TokenUsageCollector)
mock_provider.complete.return_value = (Message.assistant("msg"), usage_factory(total_tokens=None))
exchange = create_exchange(mock_provider, mock_usage_collector, dummy_tool)

exchange.generate()

mock_usage_collector.collect.assert_not_called()


def test_exchange_generate_not_collect_usage_when_total_tokens_is_0(usage_factory, dummy_tool):
mock_provider = MagicMock(spec=Provider)
mock_usage_collector = MagicMock(spec=TokenUsageCollector)
mock_provider.complete.return_value = (Message.assistant("msg"), usage_factory(total_tokens=0))
exchange = create_exchange(mock_provider, mock_usage_collector, dummy_tool)

exchange.generate()

mock_usage_collector.collect.assert_not_called()
39 changes: 19 additions & 20 deletions tests/test_token_usage_collector.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
from exchange.token_usage_collector import TokenUsageCollector

from exchange.token_usage_collector import _TokenUsageCollector, TokenUsage

def test_collect(usage_factory):
usage_collector = TokenUsageCollector()
usage_collector.collect("model1", usage_factory(total_tokens=100))
usage_collector.collect("model1", usage_factory(total_tokens=200))
usage_collector.collect("model2", usage_factory(total_tokens=400))
usage_collector.collect("model3", usage_factory(total_tokens=500))
usage_collector.collect("model3", usage_factory(total_tokens=600))
assert usage_collector.get_token_count_group_by_model() == {
"model1": 300,
"model2": 400,
"model3": 1100,
}
usage_collector = _TokenUsageCollector()
usage_collector.collect("model1", usage_factory(input_tokens=100, output_tokens=1000))
usage_collector.collect("model1", usage_factory(input_tokens=200, output_tokens=2000))
usage_collector.collect("model2", usage_factory(input_tokens=400, output_tokens=4000))
usage_collector.collect("model3", usage_factory(input_tokens=500, output_tokens=5000))
usage_collector.collect("model3", usage_factory(input_tokens=600, output_tokens=6000))
assert usage_collector.get_token_usage_group_by_model() == [
TokenUsage(model="model1", input_tokens=300, output_tokens=3000),
TokenUsage(model="model2", input_tokens=400, output_tokens=4000),
TokenUsage(model="model3", input_tokens=1100, output_tokens=11000),
]


def test_collect_with_non_total_token(usage_factory):
usage_collector = TokenUsageCollector()
usage_collector.collect("model1", usage_factory(total_tokens=100))
usage_collector.collect("model1", usage_factory(total_tokens=None))
assert usage_collector.get_token_count_group_by_model() == {
"model1": 100,
}
def test_collect_with_non_input_or_output_token(usage_factory):
usage_collector = _TokenUsageCollector()
usage_collector.collect("model1", usage_factory(input_tokens=100, output_tokens=None))
usage_collector.collect("model1", usage_factory(input_tokens=None, output_tokens=2000))
assert usage_collector.get_token_usage_group_by_model() == [
TokenUsage(model="model1", input_tokens=100, output_tokens=2000),
]

0 comments on commit 15e6449

Please sign in to comment.