Skip to content

Commit

Permalink
Merge pull request #20 from tanaygodse/digest-issue
Browse files Browse the repository at this point in the history
['digest'] issue and ['args'] issue fix
  • Loading branch information
qati authored Sep 30, 2024
2 parents 229a0c0 + 6e8c494 commit 5295a12
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
8 changes: 4 additions & 4 deletions ai_engine_sdk/api_models/agents_json_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class TaskSelectionMessage(AgentJsonMessage):
text: str
options: Dict[str, TaskOption]


def get_options_keys(self) -> list[TaskOption]:
return [option for option in self.options]

Expand All @@ -50,7 +49,6 @@ class DataRequestMessage(AgentJsonMessage):
class ConfirmationMessage(AgentJsonMessage):
type: Literal[AgentJsonMessageTypes.CONFIRMATION] = AgentJsonMessageTypes.CONFIRMATION
text: str
model: str
payload: Dict[str, Any]


Expand All @@ -70,14 +68,16 @@ def is_agent_json_confirmation_message(message_type: str) -> bool:

def is_task_selection_message(message_type: str) -> bool:
union_of_type = TaskSelectionTypes
allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)]
allowed_values = [literal for lit in get_args(
union_of_type) for literal in get_args(lit)]
return message_type.upper() in allowed_values


def is_data_request_message(message_type: str) -> bool:
union_of_type = DataRequestTypes
if get_origin(union_of_type) is Union:
allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)]
allowed_values = [literal for lit in get_args(
union_of_type) for literal in get_args(lit)]
elif get_origin(union_of_type) is Literal:
allowed_values = get_args(union_of_type)

Expand Down
10 changes: 4 additions & 6 deletions ai_engine_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
'id': message['message_id'],
'timestamp': message['timestamp'],
'text': agent_json['text'],
'options':indexed_task_options
'options': indexed_task_options
})
)
elif is_api_context_json(message_type=agent_json_type, agent_json_text=agent_json['text']):
Expand All @@ -289,8 +289,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
'id': message['message_id'],
'timestamp': message['timestamp'],
'text': agent_json['text'],
'model': agent_json['context_json']['digest'],
'payload': agent_json['context_json']['args'],
'payload': agent_json['context_json'],
})
)
elif is_data_request_message(message_type=agent_json_type):
Expand Down Expand Up @@ -352,7 +351,7 @@ async def delete(self):
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}"
)

async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
async def execute_function(self, function_ids: list[str], objective: str, context: str | None = None):
await self._submit_message(
payload=ApiUserMessageExecuteFunctions.model_validate({
"functions": function_ids,
Expand All @@ -367,7 +366,6 @@ def __init__(self, api_key: str, options: Optional[dict] = None):
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
self._api_key = api_key


####
# Function groups
####
Expand Down Expand Up @@ -579,4 +577,4 @@ async def share_function_group(
payload=payload
)
logger.debug(f"FG successfully shared: {function_group_id} with {target_user_email}")
return raw_response
return raw_response

0 comments on commit 5295a12

Please sign in to comment.