Fix bug in _chat_no_stream to ensure usage_info is updated correctly (#549)

## Change Summary
### Problem Description
When calling the `OpenAi` class's `_chat_no_stream()` method, the
`usage_info` result is empty. The root cause is in the
`stat_last_call_token_info` method: because its body contains a `yield`,
Python treats it as a generator function, so calling it from
`_chat_no_stream()` only creates a generator object and never executes its
body. As a result, `usage_info` is never updated.
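
A minimal standalone sketch of the pitfall (hypothetical names, not the library code): a function whose body contains `yield` returns a generator when called, so none of its statements run until the generator is iterated.

```python
# Illustration of the bug pattern only -- names here are made up.
def record_usage(response, stats):
    stats['usage'] = response['usage']  # never runs on a bare call
    yield response                      # this yield makes it a generator function

stats = {}
record_usage({'usage': {'total_tokens': 15}}, stats)  # returns a generator, body skipped
print(stats)  # {}  -> same symptom as the empty usage_info

list(record_usage({'usage': {'total_tokens': 15}}, stats))  # iterating runs the body
print(stats)  # {'usage': {'total_tokens': 15}}
```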

### Specific Issue
When calling the LLM with `stream=False`, `usage_info` is empty:

```python
from modelscope_agent.llm import get_chat_model

msg = [{'role': 'user', 'content': 'hello'}]
llm = get_chat_model(**llm_config)
resp = llm.chat(messages=msg,
                max_tokens=1024,
                temperature=1.0,
                stream=False)
usage_info = llm.get_usage()
```
#### Actual Output
```python
>>> usage_info = {}
```
#### Expected Output
```python
>>> usage_info = {'prompt_tokens': 5, 'completion_tokens': 10, 'total_tokens': 15}
```

## Related issue number
## Checklist
* [x] The pull request title is a good summary of the changes - it will be used in the changelog
* [x] Unit tests for the changes exist
* [x] Ran `pre-commit install` and `pre-commit run --all-files` before committing, and the lint check passed
* [ ] Some cases need DASHSCOPE_TOKEN_API to pass the unit tests; I have at least **passed the unit tests locally**
* [ ] Documentation reflects the changes where applicable
* [x] My PR is ready to review, **please add a comment including the phrase "please review" to assign reviewers**
Zhikaiiii authored Jul 26, 2024
2 parents 6b7bb78 + a1e9630 commit 42a1307
Showing 5 changed files with 49 additions and 10 deletions.
7 changes: 6 additions & 1 deletion modelscope_agent/llm/base.py
```diff
@@ -287,7 +287,7 @@ def get_max_length(self) -> int:
     def get_usage(self) -> Dict:
         return self.last_call_usage_info
 
-    def stat_last_call_token_info(self, response):
+    def stat_last_call_token_info_stream(self, response):
         try:
             self.last_call_usage_info = response.usage.dict()
             return response
@@ -296,3 +296,8 @@ def stat_last_call_token_info(self, response):
                 if hasattr(chunk, 'usage') and chunk.usage is not None:
                     self.last_call_usage_info = chunk.usage.dict()
                 yield chunk
+
+    def stat_last_call_token_info_no_stream(self, response):
+        if hasattr(response, 'usage'):
+            self.last_call_usage_info = response.usage.dict()
+        return response
```
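
As a hedged illustration of the split above (the response and usage objects below are stand-ins, not real SDK types): the no-stream helper runs eagerly when called, while the stream helper is a generator that only records usage once its chunks are consumed.

```python
from types import SimpleNamespace


class Usage:
    """Stand-in for an SDK usage object exposing .dict()."""

    def __init__(self, **kwargs):
        self._data = kwargs

    def dict(self):
        return dict(self._data)


class DemoLLM:
    def __init__(self):
        self.last_call_usage_info = {}

    def stat_last_call_token_info_no_stream(self, response):
        # plain function: executes as soon as it is called
        if hasattr(response, 'usage'):
            self.last_call_usage_info = response.usage.dict()
        return response

    def stat_last_call_token_info_stream(self, response):
        # generator: usage is attached to one of the streamed chunks
        for chunk in response:
            if getattr(chunk, 'usage', None) is not None:
                self.last_call_usage_info = chunk.usage.dict()
            yield chunk


llm = DemoLLM()

# non-streaming call: usage is recorded immediately
resp = SimpleNamespace(
    usage=Usage(prompt_tokens=5, completion_tokens=10, total_tokens=15))
llm.stat_last_call_token_info_no_stream(resp)
print(llm.last_call_usage_info)
# {'prompt_tokens': 5, 'completion_tokens': 10, 'total_tokens': 15}

# streaming call: usage appears only after the chunks are iterated
chunks = [SimpleNamespace(usage=None), SimpleNamespace(usage=Usage(total_tokens=15))]
for _ in llm.stat_last_call_token_info_stream(chunks):
    pass
print(llm.last_call_usage_info)  # {'total_tokens': 15}
```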
24 changes: 21 additions & 3 deletions modelscope_agent/llm/dashscope.py
```diff
@@ -101,7 +101,7 @@ def _chat_stream(self,
             generation_input['seed'] = kwargs.get('seed')
 
         response = dashscope.Generation.call(**generation_input)
-        response = self.stat_last_call_token_info(response)
+        response = self.stat_last_call_token_info_stream(response)
         return stream_output(response, **kwargs)
 
     def _chat_no_stream(self,
@@ -120,7 +120,7 @@ def _chat_no_stream(self,
             top_p=top_p,
         )
         if response.status_code == HTTPStatus.OK:
-            self.stat_last_call_token_info(response)
+            self.stat_last_call_token_info_no_stream(response)
             return response.output.choices[0].message.content
         else:
             err = 'Error code: %s, error message: %s' % (
@@ -129,7 +129,25 @@ def _chat_no_stream(self,
             )
             return err
 
-    def stat_last_call_token_info(self, response):
+    def stat_last_call_token_info_no_stream(self, response):
+        try:
+            if response.usage is not None:
+                if not response.usage.get('total_tokens'):
+                    total_tokens = response.usage.input_tokens + response.usage.output_tokens
+                else:
+                    total_tokens = response.usage.total_tokens
+                self.last_call_usage_info = {
+                    'prompt_tokens': response.usage.input_tokens,
+                    'completion_tokens': response.usage.output_tokens,
+                    'total_tokens': total_tokens
+                }
+            else:
+                logger.warning('No usage info in response')
+        except AttributeError:
+            logger.warning('No usage info in response')
+        return response
+
+    def stat_last_call_token_info_stream(self, response):
         try:
             if response.usage is not None:
                 if not response.usage.get('total_tokens'):
```
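
A small hedged sketch of the normalization the DashScope helper performs, with a plain dict standing in for the real `response.usage` object (which, per the diff above, exposes `input_tokens`/`output_tokens` and may omit `total_tokens`):

```python
# Dummy usage payload; the real one comes from dashscope.Generation.call.
usage = {'input_tokens': 5, 'output_tokens': 10}  # no 'total_tokens' reported

total_tokens = usage.get('total_tokens') or (
    usage['input_tokens'] + usage['output_tokens'])

last_call_usage_info = {
    'prompt_tokens': usage['input_tokens'],
    'completion_tokens': usage['output_tokens'],
    'total_tokens': total_tokens,
}
print(last_call_usage_info)
# {'prompt_tokens': 5, 'completion_tokens': 10, 'total_tokens': 15}
```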
21 changes: 18 additions & 3 deletions modelscope_agent/llm/ollama.py
```diff
@@ -39,7 +39,7 @@ def _chat_stream(self,
             f'stop: {str(stop)}, stream: True, args: {str(kwargs)}')
         stream = self.client.chat(
             model=self.model, messages=messages, stream=True)
-        stream = self.stat_last_call_token_info(stream)
+        stream = self.stat_last_call_token_info_stream(stream)
         for chunk in stream:
             tmp_content = chunk['message']['content']
             logger.info(f'call ollama success, output: {tmp_content}')
@@ -55,7 +55,7 @@ def _chat_no_stream(self,
             f'call ollama, model: {self.model}, messages: {str(messages)}, '
             f'stop: {str(stop)}, stream: False, args: {str(kwargs)}')
         response = self.client.chat(model=self.model, messages=messages)
-        self.stat_last_call_token_info(response)
+        self.stat_last_call_token_info_no_stream(response)
         final_content = response['message']['content']
         logger.info(f'call ollama success, output: {final_content}')
         return final_content
@@ -100,7 +100,22 @@ def chat(self,
         return super().chat(
             messages=messages, stop=stop, stream=stream, **kwargs)
 
-    def stat_last_call_token_info(self, response):
+    def stat_last_call_token_info_no_stream(self, response):
+        try:
+            self.last_call_usage_info = {
+                'prompt_tokens':
+                response.get('prompt_eval_count', -1),
+                'completion_tokens':
+                response.get('eval_count', -1),
+                'total_tokens':
+                response.get('prompt_eval_count') + response.get('eval_count')
+            }
+        except AttributeError:
+            logger.warning('No usage info in response')
+
+        return response
+
+    def stat_last_call_token_info_stream(self, response):
         try:
             self.last_call_usage_info = {
                 'prompt_tokens':
```
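
Similarly, a hedged sketch of the Ollama field mapping, with a plain dict standing in for the non-streaming chat response; defaulting missing counters to 0 in the total is an assumption of this sketch, not necessarily the library's behaviour.

```python
# Dummy response; the real one comes from the Ollama client's chat() call
# with stream=False, which reports prompt_eval_count and eval_count.
response = {'prompt_eval_count': 5, 'eval_count': 10}

last_call_usage_info = {
    'prompt_tokens': response.get('prompt_eval_count', -1),
    'completion_tokens': response.get('eval_count', -1),
    'total_tokens':
    response.get('prompt_eval_count', 0) + response.get('eval_count', 0),
}
print(last_call_usage_info)
# {'prompt_tokens': 5, 'completion_tokens': 10, 'total_tokens': 15}
```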
6 changes: 3 additions & 3 deletions modelscope_agent/llm/openai.py
```diff
@@ -46,7 +46,7 @@ def _chat_stream(self,
             stream=True,
             stream_options=stream_options,
             **kwargs)
-        response = self.stat_last_call_token_info(response)
+        response = self.stat_last_call_token_info_stream(response)
         # TODO: error handling
         for chunk in response:
             # sometimes delta.content is None by vllm, we should not yield None
@@ -72,7 +72,7 @@ def _chat_no_stream(self,
             stop=stop,
             stream=False,
             **kwargs)
-        self.stat_last_call_token_info(response)
+        self.stat_last_call_token_info_no_stream(response)
         logger.info(
             f'call openai api success, output: {response.choices[0].message.content}'
         )
@@ -171,7 +171,7 @@ def _chat_stream(self,
             f'stop: {str(stop)}, stream: True, args: {str(kwargs)}')
         response = self.client.chat.completions.create(
             model=self.model, messages=messages, stop=stop, stream=True)
-        response = self.stat_last_call_token_info(response)
+        response = self.stat_last_call_token_info_stream(response)
         # TODO: error handling
         for chunk in response:
             # sometimes delta.content is None by vllm, we should not yield None
```
1 change: 1 addition & 0 deletions tests/test_callback.py
```diff
@@ -67,6 +67,7 @@ def test_tool_exec_run_state(mocker):
     assert callback.run_states[1][2].content == '{"test": "test_value"}'
 
 
+@pytest.mark.skipif(IS_FORKED_PR, reason='only run modelscope-agent main repo')
 def test_rag_run_state(mocker):
 
     callback = RunStateCallback()
```
