diff --git a/ai_engine_sdk/api_models/agents_json_messages.py b/ai_engine_sdk/api_models/agents_json_messages.py index 552b8b3..733062f 100644 --- a/ai_engine_sdk/api_models/agents_json_messages.py +++ b/ai_engine_sdk/api_models/agents_json_messages.py @@ -37,7 +37,6 @@ class TaskSelectionMessage(AgentJsonMessage): text: str options: Dict[str, TaskOption] - def get_options_keys(self) -> list[TaskOption]: return [option for option in self.options] @@ -50,7 +49,6 @@ class DataRequestMessage(AgentJsonMessage): class ConfirmationMessage(AgentJsonMessage): type: Literal[AgentJsonMessageTypes.CONFIRMATION] = AgentJsonMessageTypes.CONFIRMATION text: str - model: str payload: Dict[str, Any] @@ -70,14 +68,16 @@ def is_agent_json_confirmation_message(message_type: str) -> bool: def is_task_selection_message(message_type: str) -> bool: union_of_type = TaskSelectionTypes - allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)] + allowed_values = [literal for lit in get_args( + union_of_type) for literal in get_args(lit)] return message_type.upper() in allowed_values def is_data_request_message(message_type: str) -> bool: union_of_type = DataRequestTypes if get_origin(union_of_type) is Union: - allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)] + allowed_values = [literal for lit in get_args( + union_of_type) for literal in get_args(lit)] elif get_origin(union_of_type) is Literal: allowed_values = get_args(union_of_type) diff --git a/ai_engine_sdk/client.py b/ai_engine_sdk/client.py index 309b28e..7c0279c 100644 --- a/ai_engine_sdk/client.py +++ b/ai_engine_sdk/client.py @@ -280,7 +280,7 @@ async def get_messages(self) -> List[ApiBaseMessage]: 'id': message['message_id'], 'timestamp': message['timestamp'], 'text': agent_json['text'], - 'options':indexed_task_options + 'options': indexed_task_options }) ) elif is_api_context_json(message_type=agent_json_type, agent_json_text=agent_json['text']): @@ -289,8 +289,7 @@ async def get_messages(self) -> List[ApiBaseMessage]: 'id': message['message_id'], 'timestamp': message['timestamp'], 'text': agent_json['text'], - 'model': agent_json['context_json']['digest'], - 'payload': agent_json['context_json']['args'], + 'payload': agent_json['context_json'], }) ) elif is_data_request_message(message_type=agent_json_type): @@ -352,7 +351,7 @@ async def delete(self): endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}" ) - async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None): + async def execute_function(self, function_ids: list[str], objective: str, context: str | None = None): await self._submit_message( payload=ApiUserMessageExecuteFunctions.model_validate({ "functions": function_ids, @@ -367,7 +366,6 @@ def __init__(self, api_key: str, options: Optional[dict] = None): self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url self._api_key = api_key - #### # Function groups #### @@ -579,4 +577,4 @@ async def share_function_group( payload=payload ) logger.debug(f"FG successfully shared: {function_group_id} with {target_user_email}") - return raw_response + return raw_response \ No newline at end of file