diff --git a/modelscope_agent/llm/base.py b/modelscope_agent/llm/base.py
index 39438023..c666e961 100644
--- a/modelscope_agent/llm/base.py
+++ b/modelscope_agent/llm/base.py
@@ -287,7 +287,7 @@ def get_max_length(self) -> int:
     def get_usage(self) -> Dict:
         return self.last_call_usage_info
 
-    def stat_last_call_token_info(self, response):
+    def stat_last_call_token_info_stream(self, response):
         try:
             self.last_call_usage_info = response.usage.dict()
             return response
@@ -296,3 +296,8 @@ def stat_last_call_token_info(self, response):
             if hasattr(chunk, 'usage') and chunk.usage is not None:
                 self.last_call_usage_info = chunk.usage.dict()
             yield chunk
+
+    def stat_last_call_token_info_no_stream(self, response):
+        if hasattr(response, 'usage'):
+            self.last_call_usage_info = response.usage.dict()
+        return response
diff --git a/modelscope_agent/llm/dashscope.py b/modelscope_agent/llm/dashscope.py
index 9ead4398..e966bb50 100644
--- a/modelscope_agent/llm/dashscope.py
+++ b/modelscope_agent/llm/dashscope.py
@@ -101,7 +101,7 @@ def _chat_stream(self,
             generation_input['seed'] = kwargs.get('seed')
 
         response = dashscope.Generation.call(**generation_input)
-        response = self.stat_last_call_token_info(response)
+        response = self.stat_last_call_token_info_stream(response)
         return stream_output(response, **kwargs)
 
     def _chat_no_stream(self,
@@ -120,7 +120,7 @@ def _chat_no_stream(self,
             top_p=top_p,
         )
         if response.status_code == HTTPStatus.OK:
-            self.stat_last_call_token_info(response)
+            self.stat_last_call_token_info_no_stream(response)
             return response.output.choices[0].message.content
         else:
             err = 'Error code: %s, error message: %s' % (
@@ -129,7 +129,25 @@ def _chat_no_stream(self,
             )
             return err
 
-    def stat_last_call_token_info(self, response):
+    def stat_last_call_token_info_no_stream(self, response):
+        try:
+            if response.usage is not None:
+                if not response.usage.get('total_tokens'):
+                    total_tokens = response.usage.input_tokens + response.usage.output_tokens
+                else:
+                    total_tokens = response.usage.total_tokens
+                self.last_call_usage_info = {
+                    'prompt_tokens': response.usage.input_tokens,
+                    'completion_tokens': response.usage.output_tokens,
+                    'total_tokens': total_tokens
+                }
+            else:
+                logger.warning('No usage info in response')
+        except AttributeError:
+            logger.warning('No usage info in response')
+        return response
+
+    def stat_last_call_token_info_stream(self, response):
         try:
             if response.usage is not None:
                 if not response.usage.get('total_tokens'):
diff --git a/modelscope_agent/llm/ollama.py b/modelscope_agent/llm/ollama.py
index e0e25c4d..a9e35586 100644
--- a/modelscope_agent/llm/ollama.py
+++ b/modelscope_agent/llm/ollama.py
@@ -39,7 +39,7 @@ def _chat_stream(self,
             f'stop: {str(stop)}, stream: True, args: {str(kwargs)}')
         stream = self.client.chat(
             model=self.model, messages=messages, stream=True)
-        stream = self.stat_last_call_token_info(stream)
+        stream = self.stat_last_call_token_info_stream(stream)
         for chunk in stream:
             tmp_content = chunk['message']['content']
             logger.info(f'call ollama success, output: {tmp_content}')
@@ -55,7 +55,7 @@ def _chat_no_stream(self,
             f'call ollama, model: {self.model}, messages: {str(messages)}, '
             f'stop: {str(stop)}, stream: False, args: {str(kwargs)}')
         response = self.client.chat(model=self.model, messages=messages)
-        self.stat_last_call_token_info(response)
+        self.stat_last_call_token_info_no_stream(response)
         final_content = response['message']['content']
         logger.info(f'call ollama success, output: {final_content}')
         return final_content
@@ -100,7 +100,22 @@ def chat(self,
         return super().chat(
             messages=messages, stop=stop, stream=stream, **kwargs)
 
-    def stat_last_call_token_info(self, response):
+    def stat_last_call_token_info_no_stream(self, response):
+        try:
+            self.last_call_usage_info = {
+                'prompt_tokens':
+                response.get('prompt_eval_count', -1),
+                'completion_tokens':
+                response.get('eval_count', -1),
+                'total_tokens':
+                response.get('prompt_eval_count') + response.get('eval_count')
+            }
+        except AttributeError:
+            logger.warning('No usage info in response')
+
+        return response
+
+    def stat_last_call_token_info_stream(self, response):
         try:
             self.last_call_usage_info = {
                 'prompt_tokens':
diff --git a/modelscope_agent/llm/openai.py b/modelscope_agent/llm/openai.py
index 67a795ad..89f757e5 100644
--- a/modelscope_agent/llm/openai.py
+++ b/modelscope_agent/llm/openai.py
@@ -46,7 +46,7 @@ def _chat_stream(self,
             stream=True,
             stream_options=stream_options,
             **kwargs)
-        response = self.stat_last_call_token_info(response)
+        response = self.stat_last_call_token_info_stream(response)
         # TODO: error handling
         for chunk in response:
             # sometimes delta.content is None by vllm, we should not yield None
@@ -72,7 +72,7 @@ def _chat_no_stream(self,
             stop=stop,
             stream=False,
             **kwargs)
-        self.stat_last_call_token_info(response)
+        self.stat_last_call_token_info_no_stream(response)
         logger.info(
             f'call openai api success, output: {response.choices[0].message.content}'
         )
@@ -171,7 +171,7 @@ def _chat_stream(self,
             f'stop: {str(stop)}, stream: True, args: {str(kwargs)}')
         response = self.client.chat.completions.create(
             model=self.model, messages=messages, stop=stop, stream=True)
-        response = self.stat_last_call_token_info(response)
+        response = self.stat_last_call_token_info_stream(response)
         # TODO: error handling
         for chunk in response:
             # sometimes delta.content is None by vllm, we should not yield None
diff --git a/tests/test_callback.py b/tests/test_callback.py
index b573fc51..46b95eaa 100644
--- a/tests/test_callback.py
+++ b/tests/test_callback.py
@@ -67,6 +67,7 @@ def test_tool_exec_run_state(mocker):
     assert callback.run_states[1][2].content == '{"test": "test_value"}'
 
 
+@pytest.mark.skipif(IS_FORKED_PR, reason='only run modelscope-agent main repo')
 def test_rag_run_state(mocker):
     callback = RunStateCallback()
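The hunks above only show the mechanical rename and the per-backend overrides. As a reviewer's aid, the following is a hypothetical, self-contained Python sketch (not part of the patch; `UsageTracker` is an invented name) of the stream / no-stream split that base.py now exposes: the no-stream helper reads usage off the single finished response, while the stream helper wraps the chunk iterator and records usage lazily as chunks are consumed. Plain dicts stand in for the pydantic usage objects whose `.dict()` the real base class calls.

from types import SimpleNamespace


class UsageTracker:
    """Minimal stand-in for the usage bookkeeping in modelscope_agent/llm/base.py."""

    def __init__(self):
        self.last_call_usage_info = {}

    def get_usage(self):
        return self.last_call_usage_info

    def stat_last_call_token_info_no_stream(self, response):
        # Non-streaming: the single response object carries the usage payload.
        if hasattr(response, 'usage'):
            self.last_call_usage_info = response.usage
        return response

    def stat_last_call_token_info_stream(self, response):
        # Streaming: usage (if present at all) arrives on individual chunks,
        # so wrap the iterator and record it while the caller drains it.
        for chunk in response:
            if getattr(chunk, 'usage', None) is not None:
                self.last_call_usage_info = chunk.usage
            yield chunk


if __name__ == '__main__':
    tracker = UsageTracker()

    # No-stream path: usage is read once from the finished response.
    resp = SimpleNamespace(
        usage={'prompt_tokens': 3, 'completion_tokens': 5, 'total_tokens': 8})
    tracker.stat_last_call_token_info_no_stream(resp)
    print(tracker.get_usage())  # prompt/completion/total tokens

    # Stream path: only the last chunk carries usage; it is captured while
    # the wrapped generator is iterated.
    chunks = [SimpleNamespace(usage=None),
              SimpleNamespace(usage={'total_tokens': 8})]
    for _ in tracker.stat_last_call_token_info_stream(iter(chunks)):
        pass
    print(tracker.get_usage())  # {'total_tokens': 8}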