Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net Agents - Support IAutoFunctionInvocationFilter for OpenAIAssistantAgent #9690

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,19 +233,7 @@ private Kernel CreateKernelWithFilter()
{
IKernelBuilder builder = Kernel.CreateBuilder();

if (this.UseOpenAIConfig)
{
builder.AddOpenAIChatCompletion(
TestConfiguration.OpenAI.ChatModelId,
TestConfiguration.OpenAI.ApiKey);
}
else
{
builder.AddAzureOpenAIChatCompletion(
TestConfiguration.AzureOpenAI.ChatDeploymentName,
TestConfiguration.AzureOpenAI.Endpoint,
TestConfiguration.AzureOpenAI.ApiKey);
}
base.AddChatCompletionToKernel(builder);

builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(new AutoInvocationFilter());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Sum 426 1622 856 2904
async Task InvokeAgentAsync(string input)
{
ChatMessageContent message = new(AuthorRole.User, input);
chat.AddChatMessage(new(AuthorRole.User, input));
chat.AddChatMessage(message);
this.WriteAgentChatMessage(message);

await foreach (ChatMessageContent response in chat.InvokeAsync(agent))
Expand Down
218 changes: 218 additions & 0 deletions dotnet/samples/Concepts/Agents/OpenAIAssistant_FunctionFilters.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright (c) Microsoft. All rights reserved.
using System.ComponentModel;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.OpenAI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Agents;

/// <summary>
/// Demonstrate usage of <see cref="IAutoFunctionInvocationFilter"/> for and
/// <see cref="IFunctionInvocationFilter"/> filters with <see cref="OpenAIAssistantAgent"/>
/// via <see cref="AgentChat"/>.
/// </summary>
public class OpenAIAssistant_FunctionFilters(ITestOutputHelper output) : BaseAgentsTest(output)
{
protected override bool ForceOpenAI => true; // %%% REMOVE

[Fact]
public async Task UseFunctionInvocationFilterAsync()
{
// Define the agent
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithInvokeFilter());

// Invoke assistant agent (non streaming)
await InvokeAssistantAsync(agent);
}

[Fact]
public async Task UseFunctionInvocationFilterStreamingAsync()
{
// Define the agent
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithInvokeFilter());

// Invoke assistant agent (streaming)
await InvokeAssistantStreamingAsync(agent);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task UseAutoFunctionInvocationFilterAsync(bool terminate)
{
// Define the agent
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithAutoFilter(terminate));

// Invoke assistant agent (non streaming)
await InvokeAssistantAsync(agent);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocationAsync(bool terminate)
{
// Define the agent
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithAutoFilter(terminate));

// Invoke assistant agent (streaming)
await InvokeAssistantStreamingAsync(agent);
}

private async Task InvokeAssistantAsync(OpenAIAssistantAgent agent)
{
// Create a thread for the agent conversation.
AgentGroupChat chat = new();

try
{
// Respond to user input, invoking functions where appropriate.
ChatMessageContent message = new(AuthorRole.User, "What is the special soup?");
chat.AddChatMessage(message);
await chat.InvokeAsync(agent).ToArrayAsync();

// Display the entire chat history.
ChatMessageContent[] history = await chat.GetChatMessagesAsync().Reverse().ToArrayAsync();
this.WriteChatHistory(history);
}
finally
{
await chat.ResetAsync();
await agent.DeleteAsync();
}
}

private async Task InvokeAssistantStreamingAsync(OpenAIAssistantAgent agent)
{
// Create a thread for the agent conversation.
AgentGroupChat chat = new();

try
{
// Respond to user input, invoking functions where appropriate.
ChatMessageContent message = new(AuthorRole.User, "What is the special soup?");
chat.AddChatMessage(message);
await chat.InvokeStreamingAsync(agent).ToArrayAsync();

// Display the entire chat history.
ChatMessageContent[] history = await chat.GetChatMessagesAsync().Reverse().ToArrayAsync();
this.WriteChatHistory(history);
}
finally
{
await chat.ResetAsync();
await agent.DeleteAsync();
}
}

private void WriteChatHistory(IEnumerable<ChatMessageContent> history)
{
Console.WriteLine("\n================================");
Console.WriteLine("CHAT HISTORY");
Console.WriteLine("================================");
foreach (ChatMessageContent message in history)
{
this.WriteAgentChatMessage(message);
}
}

private async Task<OpenAIAssistantAgent> CreateAssistantAsync(Kernel kernel)
{
OpenAIAssistantAgent agent =
await OpenAIAssistantAgent.CreateAsync(
this.GetClientProvider(),
new OpenAIAssistantDefinition(base.Model)
{
Instructions = "Answer questions about the menu.",
Metadata = AssistantSampleMetadata,
},
kernel: kernel
);

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
agent.Kernel.Plugins.Add(plugin);

return agent;
}

private Kernel CreateKernelWithAutoFilter(bool terminate)
{
IKernelBuilder builder = Kernel.CreateBuilder();

base.AddChatCompletionToKernel(builder);

builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(new AutoInvocationFilter(terminate));

return builder.Build();
}

private Kernel CreateKernelWithInvokeFilter()
{
IKernelBuilder builder = Kernel.CreateBuilder();

base.AddChatCompletionToKernel(builder);

builder.Services.AddSingleton<IFunctionInvocationFilter>(new InvocationFilter());

return builder.Build();
}

private sealed class MenuPlugin
{
[KernelFunction, Description("Provides a list of specials from the menu.")]
[System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")]
public string GetSpecials()
{
return
"""
Special Soup: Clam Chowder
Special Salad: Cobb Salad
Special Drink: Chai Tea
""";
}

[KernelFunction, Description("Provides the price of the requested menu item.")]
public string GetItemPrice(
[Description("The name of the menu item.")]
string menuItem)
{
return "$9.99";
}
}

private sealed class InvocationFilter() : IFunctionInvocationFilter
{
public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
{
System.Console.WriteLine($"FILTER INVOKED {nameof(InvocationFilter)} - {context.Function.Name}");

// Execution the function
await next(context);

// Signal termination if the function is from the MenuPlugin
if (context.Function.PluginName == nameof(MenuPlugin))
{
context.Result = new FunctionResult(context.Function, "BLOCKED");
}
}
}

private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter
{
public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
{
System.Console.WriteLine($"FILTER INVOKED {nameof(AutoInvocationFilter)} - {context.Function.Name}");

// Execution the function
await next(context);

// Signal termination if the function is from the MenuPlugin
if (context.Function.PluginName == nameof(MenuPlugin))
{
context.Terminate = terminate;
}
}
}
}
1 change: 1 addition & 0 deletions dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/System/IListExtensions.cs" Link="%(RecursiveDir)System/%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/System/AppContextSwitchHelper.cs" Link="%(RecursiveDir)Utilities/%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Functions/FunctionName.cs" Link="%(RecursiveDir)Utilities/%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/connectors/AI/**/*.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

<ItemGroup>
Expand Down
40 changes: 28 additions & 12 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Azure;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.FunctionCalling;
using OpenAI.Assistants;

namespace Microsoft.SemanticKernel.Agents.OpenAI.Internal;
Expand Down Expand Up @@ -177,6 +178,9 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist

logger.LogOpenAIAssistantCreatedRun(nameof(InvokeAsync), run.Id, threadId);

FunctionCallsProcessor functionProcessor = new(logger); // %%% LOGGER TYPE ????
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true }; // %%% DYNAMIC ???

// Evaluate status and process steps and messages, as encountered.
HashSet<string> processedStepIds = [];
Dictionary<string, FunctionResultContent> functionSteps = [];
Expand Down Expand Up @@ -206,13 +210,18 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
if (functionCalls.Length > 0)
{
// Emit function-call content
yield return (IsVisible: false, Message: GenerateFunctionCallContent(agent.GetName(), functionCalls));
ChatMessageContent functionCallMessage = GenerateFunctionCallContent(agent.GetName(), functionCalls);
yield return (IsVisible: false, Message: functionCallMessage);

// Invoke functions for each tool-step
IEnumerable<Task<FunctionResultContent>> functionResultTasks = ExecuteFunctionSteps(agent, functionCalls, cancellationToken);

// Block for function results
FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false);
FunctionResultContent[] functionResults =
await functionProcessor.InvokeFunctionCallsAsync(
functionCallMessage,
(_) => true,
functionOptions,
kernel,
isStreaming: false,
cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);

// Capture function-call for message processing
foreach (FunctionResultContent functionCall in functionResults)
Expand Down Expand Up @@ -402,6 +411,9 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
List<RunStep> stepsToProcess = [];
ThreadRun? run = null;

FunctionCallsProcessor functionProcessor = new(logger); // %%% LOGGER TYPE ????
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true };

IAsyncEnumerable<StreamingUpdate> asyncUpdates = client.CreateRunStreamingAsync(threadId, agent.Id, options, cancellationToken);
do
{
Expand Down Expand Up @@ -495,13 +507,17 @@ await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellat
if (functionCalls.Length > 0)
{
// Emit function-call content
messages?.Add(GenerateFunctionCallContent(agent.GetName(), functionCalls));

// Invoke functions for each tool-step
IEnumerable<Task<FunctionResultContent>> functionResultTasks = ExecuteFunctionSteps(agent, functionCalls, cancellationToken);

// Block for function results
FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false);
ChatMessageContent functionCallMessage = GenerateFunctionCallContent(agent.GetName(), functionCalls);
messages?.Add(functionCallMessage);

FunctionResultContent[] functionResults =
await functionProcessor.InvokeFunctionCallsAsync(
functionCallMessage,
(_) => true,
functionOptions,
kernel,
isStreaming: true,
cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);

// Process tool output
ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone;

#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
public record PineconeAllTypes()
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
{
[VectorStoreRecordKey]
public string Id { get; init; }
Expand Down Expand Up @@ -62,3 +61,4 @@ public record PineconeAllTypes()
[VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)]
public ReadOnlyMemory<float>? Embedding { get; set; }
}
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone;

#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
public record PineconeHotel()
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.

{
[VectorStoreRecordKey]
public string HotelId { get; init; }
Expand Down Expand Up @@ -37,3 +37,4 @@ public record PineconeHotel()
[VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)]
public ReadOnlyMemory<float> DescriptionEmbedding { get; set; }
}
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
Loading
Loading