Skip to content

Commit

Permalink
fix: support func type in code and class in schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
MichalAI21 committed Dec 12, 2024
1 parent 18c8cd4 commit 3101f90
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
25 changes: 19 additions & 6 deletions ai21/clients/common/beta/assistant/plans.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, Type

from pydantic import BaseModel

from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
from ai21.types import NOT_GIVEN, NotGiven
Expand All @@ -16,23 +19,33 @@ def create(
self,
*,
assistant_id: str,
code: str,
schemas: list[dict] | NotGiven = NOT_GIVEN,
code: str | callable,
schemas: list[dict] | list[Type[BaseModel]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> PlanResponse:
pass

def _create_body(
self,
*,
code: str,
schemas: list[dict] | NotGiven = NOT_GIVEN,
code: str | callable,
schemas: list[dict] | list[Type[BaseModel]] | NotGiven = NOT_GIVEN,
**kwargs,
) -> Dict[str, Any]:
if callable(code):
code = inspect.getsource(code).strip()

schema_dicts = []
for schema in schemas:
if inspect.isclass(schema) and issubclass(schema, BaseModel):
schema_dicts.append(schema.model_json_schema())
else:
schema_dicts.append(schema)

return remove_not_given(
{
"code": code,
"schemas": schemas,
"schemas": schema_dicts,
**kwargs,
}
)
Expand Down
9 changes: 5 additions & 4 deletions examples/studio/assistant/user_defined_plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
ask_for_approval=True,\n )\n return summary"""


def test_func():
pass


class ExampleSchema(BaseModel):
name: str
id: str
Expand All @@ -30,10 +34,7 @@ def main():

assistant = ai21_client.beta.assistants.create(name="My Assistant")

plan = ai21_client.beta.assistants.plans.create(
assistant_id=assistant.id, code=CODE_STR, schemas=[ExampleSchema.model_json_schema()]
)

plan = ai21_client.beta.assistants.plans.create(assistant_id=assistant.id, code=test_func, schemas=[ExampleSchema])
ai21_client.beta.assistants.routes.create(
assistant_id=assistant.id, plan_id=plan.id, name="My Route", examples=["hi"], description="My Route Description"
)
Expand Down

0 comments on commit 3101f90

Please sign in to comment.