-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Used _TokenUsageCollector as a global variable to track the token usage
- Loading branch information
1 parent
fa04d35
commit 15e6449
Showing
4 changed files
with
51 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,34 @@ | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
import queue | ||
from typing import Dict | ||
from typing import List | ||
|
||
from exchange.providers.base import Usage | ||
|
||
@dataclass
class TokenUsage:
    """Aggregated token counts reported for a single model.

    Produced by ``_TokenUsageCollector.get_token_usage_group_by_model``.
    """

    # Model identifier the counts were collected for.
    model: str
    # Total tokens sent to the model.
    input_tokens: int
    # Total tokens produced by the model.
    output_tokens: int
class _TokenUsageCollector:
    """Collects per-request token usage and aggregates it by model.

    Intended to be shared as a module-level singleton; ``collect`` may be
    called concurrently from multiple threads.
    """

    def __init__(self) -> None:
        # Thread-safe queue, as usage data may be collected from
        # multiple threads concurrently.
        self.usage_data_queue = queue.Queue()

    def collect(self, model: str, usage: Usage) -> None:
        """Record one usage sample (input/output token counts) for *model*."""
        self.usage_data_queue.put((model, usage.input_tokens, usage.output_tokens))

    def get_token_usage_group_by_model(self) -> List[TokenUsage]:
        """Return total input/output token counts summed per model.

        Counts that are ``None`` (provider did not report them) are treated
        as zero for that sample.
        """
        # Snapshot the underlying deque while holding the queue's mutex so a
        # concurrent collect() cannot mutate it mid-copy; Queue.queue is an
        # internal attribute and is only safe to read under Queue.mutex.
        with self.usage_data_queue.mutex:
            all_usage_data = list(self.usage_data_queue.queue)
        # model -> [input_total, output_total]
        token_count_group_by_model = defaultdict(lambda: [0, 0])
        for model, input_tokens, output_tokens in all_usage_data:
            if input_tokens is not None:
                token_count_group_by_model[model][0] += input_tokens
            if output_tokens is not None:
                token_count_group_by_model[model][1] += output_tokens
        return [
            TokenUsage(model, input_tokens, output_tokens)
            for model, (input_tokens, output_tokens) in token_count_group_by_model.items()
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,23 @@ | ||
from exchange.token_usage_collector import TokenUsageCollector | ||
|
||
from exchange.token_usage_collector import _TokenUsageCollector, TokenUsage | ||
|
||
def test_collect(usage_factory):
    """Usage samples are summed per model and returned as TokenUsage records."""
    usage_collector = _TokenUsageCollector()
    samples = [
        ("model1", 100, 1000),
        ("model1", 200, 2000),
        ("model2", 400, 4000),
        ("model3", 500, 5000),
        ("model3", 600, 6000),
    ]
    for model, input_tokens, output_tokens in samples:
        usage_collector.collect(
            model, usage_factory(input_tokens=input_tokens, output_tokens=output_tokens)
        )
    expected = [
        TokenUsage(model="model1", input_tokens=300, output_tokens=3000),
        TokenUsage(model="model2", input_tokens=400, output_tokens=4000),
        TokenUsage(model="model3", input_tokens=1100, output_tokens=11000),
    ]
    assert usage_collector.get_token_usage_group_by_model() == expected
def test_collect_with_non_input_or_output_token(usage_factory):
    """None input/output counts are treated as zero when aggregating."""
    collector = _TokenUsageCollector()
    collector.collect("model1", usage_factory(input_tokens=100, output_tokens=None))
    collector.collect("model1", usage_factory(input_tokens=None, output_tokens=2000))
    result = collector.get_token_usage_group_by_model()
    assert result == [TokenUsage(model="model1", input_tokens=100, output_tokens=2000)]