diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs index e1612bfc83c1..48fb10ba9cdc 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs @@ -233,19 +233,7 @@ private Kernel CreateKernelWithFilter() { IKernelBuilder builder = Kernel.CreateBuilder(); - if (this.UseOpenAIConfig) - { - builder.AddOpenAIChatCompletion( - TestConfiguration.OpenAI.ChatModelId, - TestConfiguration.OpenAI.ApiKey); - } - else - { - builder.AddAzureOpenAIChatCompletion( - TestConfiguration.AzureOpenAI.ChatDeploymentName, - TestConfiguration.AzureOpenAI.Endpoint, - TestConfiguration.AzureOpenAI.ApiKey); - } + base.AddChatCompletionToKernel(builder); builder.Services.AddSingleton(new AutoInvocationFilter()); diff --git a/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs b/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs index 807d03ecc130..9074e47b3057 100644 --- a/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs +++ b/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs @@ -66,7 +66,7 @@ Sum 426 1622 856 2904 async Task InvokeAgentAsync(string input) { ChatMessageContent message = new(AuthorRole.User, input); - chat.AddChatMessage(new(AuthorRole.User, input)); + chat.AddChatMessage(message); this.WriteAgentChatMessage(message); await foreach (ChatMessageContent response in chat.InvokeAsync(agent)) diff --git a/dotnet/samples/Concepts/Agents/OpenAIAssistant_FunctionFilters.cs b/dotnet/samples/Concepts/Agents/OpenAIAssistant_FunctionFilters.cs new file mode 100644 index 000000000000..db212b758c40 --- /dev/null +++ b/dotnet/samples/Concepts/Agents/OpenAIAssistant_FunctionFilters.cs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.ComponentModel; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.OpenAI; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Agents; + +/// +/// Demonstrate usage of for and +/// filters with +/// via . +/// +public class OpenAIAssistant_FunctionFilters(ITestOutputHelper output) : BaseAgentsTest(output) +{ + protected override bool ForceOpenAI => true; // %%% REMOVE + + [Fact] + public async Task UseFunctionInvocationFilterAsync() + { + // Define the agent + OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithInvokeFilter()); + + // Invoke assistant agent (non streaming) + await InvokeAssistantAsync(agent); + } + + [Fact] + public async Task UseFunctionInvocationFilterStreamingAsync() + { + // Define the agent + OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithInvokeFilter()); + + // Invoke assistant agent (streaming) + await InvokeAssistantStreamingAsync(agent); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseAutoFunctionInvocationFilterAsync(bool terminate) + { + // Define the agent + OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithAutoFilter(terminate)); + + // Invoke assistant agent (non streaming) + await InvokeAssistantAsync(agent); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocationAsync(bool terminate) + { + // Define the agent + OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithAutoFilter(terminate)); + + // Invoke assistant agent (streaming) + await InvokeAssistantStreamingAsync(agent); + } + + private async Task InvokeAssistantAsync(OpenAIAssistantAgent agent) + { + // Create a thread for the agent conversation. + AgentGroupChat chat = new(); + + try + { + // Respond to user input, invoking functions where appropriate. + ChatMessageContent message = new(AuthorRole.User, "What is the special soup?"); + chat.AddChatMessage(message); + await chat.InvokeAsync(agent).ToArrayAsync(); + + // Display the entire chat history. + ChatMessageContent[] history = await chat.GetChatMessagesAsync().Reverse().ToArrayAsync(); + this.WriteChatHistory(history); + } + finally + { + await chat.ResetAsync(); + await agent.DeleteAsync(); + } + } + + private async Task InvokeAssistantStreamingAsync(OpenAIAssistantAgent agent) + { + // Create a thread for the agent conversation. + AgentGroupChat chat = new(); + + try + { + // Respond to user input, invoking functions where appropriate. + ChatMessageContent message = new(AuthorRole.User, "What is the special soup?"); + chat.AddChatMessage(message); + await chat.InvokeStreamingAsync(agent).ToArrayAsync(); + + // Display the entire chat history. + ChatMessageContent[] history = await chat.GetChatMessagesAsync().Reverse().ToArrayAsync(); + this.WriteChatHistory(history); + } + finally + { + await chat.ResetAsync(); + await agent.DeleteAsync(); + } + } + + private void WriteChatHistory(IEnumerable history) + { + Console.WriteLine("\n================================"); + Console.WriteLine("CHAT HISTORY"); + Console.WriteLine("================================"); + foreach (ChatMessageContent message in history) + { + this.WriteAgentChatMessage(message); + } + } + + private async Task CreateAssistantAsync(Kernel kernel) + { + OpenAIAssistantAgent agent = + await OpenAIAssistantAgent.CreateAsync( + this.GetClientProvider(), + new OpenAIAssistantDefinition(base.Model) + { + Instructions = "Answer questions about the menu.", + Metadata = AssistantSampleMetadata, + }, + kernel: kernel + ); + + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + agent.Kernel.Plugins.Add(plugin); + + return agent; + } + + private Kernel CreateKernelWithAutoFilter(bool terminate) + { + IKernelBuilder builder = Kernel.CreateBuilder(); + + base.AddChatCompletionToKernel(builder); + + builder.Services.AddSingleton(new AutoInvocationFilter(terminate)); + + return builder.Build(); + } + + private Kernel CreateKernelWithInvokeFilter() + { + IKernelBuilder builder = Kernel.CreateBuilder(); + + base.AddChatCompletionToKernel(builder); + + builder.Services.AddSingleton(new InvocationFilter()); + + return builder.Build(); + } + + private sealed class MenuPlugin + { + [KernelFunction, Description("Provides a list of specials from the menu.")] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")] + public string GetSpecials() + { + return + """ + Special Soup: Clam Chowder + Special Salad: Cobb Salad + Special Drink: Chai Tea + """; + } + + [KernelFunction, Description("Provides the price of the requested menu item.")] + public string GetItemPrice( + [Description("The name of the menu item.")] + string menuItem) + { + return "$9.99"; + } + } + + private sealed class InvocationFilter() : IFunctionInvocationFilter + { + public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func next) + { + System.Console.WriteLine($"FILTER INVOKED {nameof(InvocationFilter)} - {context.Function.Name}"); + + // Execution the function + await next(context); + + // Signal termination if the function is from the MenuPlugin + if (context.Function.PluginName == nameof(MenuPlugin)) + { + context.Result = new FunctionResult(context.Function, "BLOCKED"); + } + } + } + + private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + System.Console.WriteLine($"FILTER INVOKED {nameof(AutoInvocationFilter)} - {context.Function.Name}"); + + // Execution the function + await next(context); + + // Signal termination if the function is from the MenuPlugin + if (context.Function.PluginName == nameof(MenuPlugin)) + { + context.Terminate = terminate; + } + } + } +} diff --git a/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj b/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj index a5a4cde76d6f..71747e21ffad 100644 --- a/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj +++ b/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj @@ -26,6 +26,7 @@ + diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs index a3af1cfb6626..d2c3aaebd735 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs @@ -11,6 +11,7 @@ using Azure; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.FunctionCalling; using OpenAI.Assistants; namespace Microsoft.SemanticKernel.Agents.OpenAI.Internal; @@ -177,6 +178,9 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist logger.LogOpenAIAssistantCreatedRun(nameof(InvokeAsync), run.Id, threadId); + FunctionCallsProcessor functionProcessor = new(logger); // %%% LOGGER TYPE ???? + FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true }; // %%% DYNAMIC ??? + // Evaluate status and process steps and messages, as encountered. HashSet processedStepIds = []; Dictionary functionSteps = []; @@ -206,13 +210,18 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist if (functionCalls.Length > 0) { // Emit function-call content - yield return (IsVisible: false, Message: GenerateFunctionCallContent(agent.GetName(), functionCalls)); + ChatMessageContent functionCallMessage = GenerateFunctionCallContent(agent.GetName(), functionCalls); + yield return (IsVisible: false, Message: functionCallMessage); // Invoke functions for each tool-step - IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, functionCalls, cancellationToken); - - // Block for function results - FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false); + FunctionResultContent[] functionResults = + await functionProcessor.InvokeFunctionCallsAsync( + functionCallMessage, + (_) => true, + functionOptions, + kernel, + isStreaming: false, + cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false); // Capture function-call for message processing foreach (FunctionResultContent functionCall in functionResults) @@ -402,6 +411,9 @@ public static async IAsyncEnumerable InvokeStreamin List stepsToProcess = []; ThreadRun? run = null; + FunctionCallsProcessor functionProcessor = new(logger); // %%% LOGGER TYPE ???? + FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true }; + IAsyncEnumerable asyncUpdates = client.CreateRunStreamingAsync(threadId, agent.Id, options, cancellationToken); do { @@ -495,13 +507,17 @@ await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellat if (functionCalls.Length > 0) { // Emit function-call content - messages?.Add(GenerateFunctionCallContent(agent.GetName(), functionCalls)); - - // Invoke functions for each tool-step - IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, functionCalls, cancellationToken); - - // Block for function results - FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false); + ChatMessageContent functionCallMessage = GenerateFunctionCallContent(agent.GetName(), functionCalls); + messages?.Add(functionCallMessage); + + FunctionResultContent[] functionResults = + await functionProcessor.InvokeFunctionCallsAsync( + functionCallMessage, + (_) => true, + functionOptions, + kernel, + isStreaming: true, + cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false); // Process tool output ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs index 7e640ea968e1..7067781987bc 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs @@ -8,7 +8,6 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. public record PineconeAllTypes() -#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. { [VectorStoreRecordKey] public string Id { get; init; } @@ -62,3 +61,4 @@ public record PineconeAllTypes() [VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)] public ReadOnlyMemory? Embedding { get; set; } } +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs index 3603f2b5ef04..54185830d5c0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs @@ -9,7 +9,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. public record PineconeHotel() -#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. + { [VectorStoreRecordKey] public string HotelId { get; init; } @@ -37,3 +37,4 @@ public record PineconeHotel() [VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)] public ReadOnlyMemory DescriptionEmbedding { get; set; } } +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index 1b591b78db77..5cb5c70eb155 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -135,87 +136,52 @@ public FunctionCallsProcessor(ILogger? logger = null) bool isStreaming, CancellationToken cancellationToken) { - var functionCalls = FunctionCallContent.GetFunctionCalls(chatMessageContent).ToList(); - - this._logger.LogFunctionCalls(functionCalls); - // Add the result message to the caller's chat history; // this is required for AI model to understand the function results. chatHistory.Add(chatMessageContent); - var functionTasks = options.AllowConcurrentInvocation && functionCalls.Count > 1 ? - new List>(functionCalls.Count) : - null; + FunctionCallContent[] functionCalls = FunctionCallContent.GetFunctionCalls(chatMessageContent).ToArray(); + + this._logger.LogFunctionCalls(functionCalls); + + List>? functionTasks = + options.AllowConcurrentInvocation && functionCalls.Length > 1 ? + new(functionCalls.Length) : + null; // We must send back a result for every function call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. - for (int functionCallIndex = 0; functionCallIndex < functionCalls.Count; functionCallIndex++) + for (int functionCallIndex = 0; functionCallIndex < functionCalls.Length; functionCallIndex++) { FunctionCallContent functionCall = functionCalls[functionCallIndex]; - // Check if the function call has an exception. - if (functionCall.Exception is not null) + // Check if the function call is valid to execute. + if (!TryValidateFunctionCall(functionCall, checkIfFunctionAdvertised, kernel, out KernelFunction? function, out string? errorMessage)) { - this.AddFunctionCallResultToChatHistory(chatHistory, functionCall, result: null, errorMessage: $"Error: Function call processing failed. {functionCall.Exception.Message}"); - continue; - } - - // Make sure the requested function is one of the functions that was advertised to the AI model. - if (!checkIfFunctionAdvertised(functionCall)) - { - this.AddFunctionCallResultToChatHistory(chatHistory, functionCall, result: null, errorMessage: "Error: Function call request for a function that wasn't defined."); - continue; - } - - // Look up the function in the kernel - if (!kernel!.Plugins.TryGetFunction(functionCall.PluginName, functionCall.FunctionName, out KernelFunction? function)) - { - this.AddFunctionCallResultToChatHistory(chatHistory, functionCall, result: null, errorMessage: "Error: Requested function could not be found."); + this.AddFunctionCallErrorToChatHistory(chatHistory, functionCall, errorMessage); continue; } // Prepare context for the auto function invocation filter and invoke it. - AutoFunctionInvocationContext invocationContext = new(kernel, function, new(function) { Culture = kernel.Culture }, chatHistory, chatMessageContent) - { - Arguments = functionCall.Arguments, - RequestSequenceIndex = requestIndex, - FunctionSequenceIndex = functionCallIndex, - FunctionCount = functionCalls.Count, - CancellationToken = cancellationToken, - IsStreaming = isStreaming, - ToolCallId = functionCall.Id - }; - - var functionTask = Task.Run<(string? Result, string? ErrorMessage, FunctionCallContent FunctionCall, AutoFunctionInvocationContext Context)>(async () => - { - s_inflightAutoInvokes.Value++; - try + AutoFunctionInvocationContext invocationContext = + new(kernel!, // Kernel cannot be null if function-call is valid + function, + result: new(function) { Culture = kernel!.Culture }, + chatHistory, + chatMessageContent) { - invocationContext = await this.OnAutoFunctionInvocationAsync(kernel, invocationContext, async (context) => - { - // Check if filter requested termination. - if (context.Terminate) - { - return; - } + Arguments = functionCall.Arguments, + RequestSequenceIndex = requestIndex, + FunctionSequenceIndex = functionCallIndex, + FunctionCount = functionCalls.Length, + CancellationToken = cancellationToken, + IsStreaming = isStreaming, + ToolCallId = functionCall.Id + }; - // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any - // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, - // as the called function could in turn telling the model about itself as a possible candidate for invocation. - context.Result = await function.InvokeAsync(kernel, invocationContext.Arguments, cancellationToken: cancellationToken).ConfigureAwait(false); - }).ConfigureAwait(false); - } -#pragma warning disable CA1031 // Do not catch general exception types - catch (Exception e) -#pragma warning restore CA1031 // Do not catch general exception types - { - return (null, $"Error: Exception while invoking function. {e.Message}", functionCall, invocationContext); - } + s_inflightAutoInvokes.Value++; - // Apply any changes from the auto function invocation filters context to final result. - var stringResult = ProcessFunctionResult(invocationContext.Result.GetValue() ?? string.Empty); - return (stringResult, null, functionCall, invocationContext); - }, cancellationToken); + Task functionTask = this.ExecuteFunctionCallAsync(invocationContext, functionCall, function, kernel, cancellationToken); // If concurrent invocation is enabled, add the task to the list for later waiting. Otherwise, join with it now. if (functionTasks is not null) @@ -224,8 +190,8 @@ public FunctionCallsProcessor(ILogger? logger = null) } else { - var functionResult = await functionTask.ConfigureAwait(false); - this.AddFunctionCallResultToChatHistory(chatHistory, functionResult.FunctionCall, functionResult.Result, functionResult.ErrorMessage); + FunctionResultContext functionResult = await functionTask.ConfigureAwait(false); + this.AddFunctionCallResultToChatHistory(chatHistory, functionResult); // If filter requested termination, return last chat history message. if (functionResult.Context.Terminate) @@ -243,14 +209,14 @@ public FunctionCallsProcessor(ILogger? logger = null) // Wait for all of the function invocations to complete, then add the results to the chat, but stop when we hit a // function for which termination was requested. - await Task.WhenAll(functionTasks).ConfigureAwait(false); - foreach (var functionTask in functionTasks) + FunctionResultContext[] resultContexts = await Task.WhenAll(functionTasks).ConfigureAwait(false); + foreach (FunctionResultContext resultContext in resultContexts) { - this.AddFunctionCallResultToChatHistory(chatHistory, functionTask.Result.FunctionCall, functionTask.Result.Result, functionTask.Result.ErrorMessage); + this.AddFunctionCallResultToChatHistory(chatHistory, resultContext); - if (functionTask.Result.Context.Terminate) + if (resultContext.Context.Terminate) { - this._logger.LogAutoFunctionInvocationProcessTermination(functionTask.Result.Context); + this._logger.LogAutoFunctionInvocationProcessTermination(resultContext.Context); terminationRequested = true; } } @@ -265,14 +231,195 @@ public FunctionCallsProcessor(ILogger? logger = null) return null; } + /// + /// Processes function calls specifically for Open AI Assistant API. In this context, the chat-history is not + /// present in local memory. + /// + /// The chat message content representing AI model response and containing function calls. + /// Callback to check if a function was advertised to AI model or not. + /// Function choice behavior options. + /// The . + /// Boolean flag which indicates whether an operation is invoked within streaming or non-streaming mode. + /// The to monitor for cancellation requests. + /// Last chat history message if function invocation filter requested processing termination, otherwise null. + public async IAsyncEnumerable InvokeFunctionCallsAsync( + ChatMessageContent chatMessageContent, + Func checkIfFunctionAdvertised, + FunctionChoiceBehaviorOptions options, + Kernel kernel, + bool isStreaming, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + FunctionCallContent[] functionCalls = FunctionCallContent.GetFunctionCalls(chatMessageContent).ToArray(); + + this._logger.LogFunctionCalls(functionCalls); + + List> functionTasks = new(functionCalls.Length); + + // We must send back a result for every function call, regardless of whether we successfully executed it or not. + // If we successfully execute it, we'll add the result. If we don't, we'll add an error. + for (int functionCallIndex = 0; functionCallIndex < functionCalls.Length; functionCallIndex++) + { + FunctionCallContent functionCall = functionCalls[functionCallIndex]; + ChatMessageContent functionCallMessage = + new() + { + Items = [.. functionCalls] + }; + + // Check if the function call is valid to execute. + if (!TryValidateFunctionCall(functionCall, checkIfFunctionAdvertised, kernel, out KernelFunction? function, out string? errorMessage)) + { + yield return this.GenerateResultContent(functionCall, result: null, errorMessage); + continue; + } + + // Prepare context for the auto function invocation filter and invoke it. + AutoFunctionInvocationContext invocationContext = + new(kernel!, // Kernel cannot be null if function-call is valid + function, + result: new(function) { Culture = kernel!.Culture }, + [], + functionCallMessage) + { + Arguments = functionCall.Arguments, + FunctionSequenceIndex = functionCallIndex, + FunctionCount = functionCalls.Length, + CancellationToken = cancellationToken, + IsStreaming = isStreaming, + ToolCallId = functionCall.Id + }; + + s_inflightAutoInvokes.Value++; + + functionTasks.Add(this.ExecuteFunctionCallAsync(invocationContext, functionCall, function, kernel, cancellationToken)); + } + + // Wait for all of the function invocations to complete, then add the results to the chat, but stop when we hit a + // function for which termination was requested. + FunctionResultContext[] resultContexts = await Task.WhenAll(functionTasks).ConfigureAwait(false); + foreach (FunctionResultContext resultContext in resultContexts) + { + yield return this.GenerateResultContent(resultContext); + } + } + + private static bool TryValidateFunctionCall( + FunctionCallContent functionCall, + Func checkIfFunctionAdvertised, + Kernel? kernel, + [NotNullWhen(true)] out KernelFunction? function, + out string? errorMessage) + { + function = null; + + // Check if the function call has an exception. + if (functionCall.Exception is not null) + { + errorMessage = $"Error: Function call processing failed. {functionCall.Exception.Message}"; + return false; + } + + // Make sure the requested function is one of the functions that was advertised to the AI model. + if (!checkIfFunctionAdvertised(functionCall)) + { + errorMessage = "Error: Function call request for a function that wasn't defined."; + return false; + } + + // Look up the function in the kernel + if (kernel?.Plugins.TryGetFunction(functionCall.PluginName, functionCall.FunctionName, out function) ?? false) + { + errorMessage = null; + return true; + } + + errorMessage = "Error: Requested function could not be found."; + return false; + } + + private record struct FunctionResultContext(AutoFunctionInvocationContext Context, FunctionCallContent FunctionCall, string? Result, string? ErrorMessage); + + private async Task ExecuteFunctionCallAsync( + AutoFunctionInvocationContext invocationContext, + FunctionCallContent functionCall, + KernelFunction function, + Kernel kernel, + CancellationToken cancellationToken) + { + try + { + invocationContext = + await this.OnAutoFunctionInvocationAsync( + kernel, + invocationContext, + async (context) => + { + // Check if filter requested termination. + if (context.Terminate) + { + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. + context.Result = await function.InvokeAsync(kernel, invocationContext.Arguments, cancellationToken: cancellationToken).ConfigureAwait(false); + }).ConfigureAwait(false); + } +#pragma warning disable CA1031 // Do not catch general exception types + catch (Exception e) +#pragma warning restore CA1031 // Do not catch general exception types + { + return new FunctionResultContext(invocationContext, functionCall, null, $"Error: Exception while invoking function. {e.Message}"); + } + + // Apply any changes from the auto function invocation filters context to final result. + string stringResult = ProcessFunctionResult(invocationContext.Result.GetValue() ?? string.Empty); + return new FunctionResultContext(invocationContext, functionCall, stringResult, null); + } + + /// + /// Adds the function call result or error message to the chat history. + /// + /// The chat history to add the function call result to. + /// The function result context. + private void AddFunctionCallResultToChatHistory(ChatHistory chatHistory, FunctionResultContext resultContext) + { + var message = new ChatMessageContent(role: AuthorRole.Tool, content: resultContext.Result); + message.Items.Add(this.GenerateResultContent(resultContext)); + chatHistory.Add(message); + } + /// /// Adds the function call result or error message to the chat history. /// /// The chat history to add the function call result to. - /// The function call. - /// The function result to add to the chat history. - /// The error message to add to the chat history. - private void AddFunctionCallResultToChatHistory(ChatHistory chatHistory, FunctionCallContent functionCall, string? result, string? errorMessage = null) + /// The function call content. + /// An error message. + private void AddFunctionCallErrorToChatHistory(ChatHistory chatHistory, FunctionCallContent functionCall, string? errorMessage) + { + var message = new ChatMessageContent(role: AuthorRole.Tool, content: errorMessage); + message.Items.Add(this.GenerateResultContent(functionCall, result: null, errorMessage)); + chatHistory.Add(message); + } + + /// + /// Creates a instance. + /// + /// The function result context. + private FunctionResultContent GenerateResultContent(FunctionResultContext resultContext) + { + return this.GenerateResultContent(resultContext.FunctionCall, resultContext.Result, resultContext.ErrorMessage); + } + + /// + /// Creates a instance. + /// + /// The function call content. + /// The function result, if available + /// An error message. + private FunctionResultContent GenerateResultContent(FunctionCallContent functionCall, string? result, string? errorMessage) { // Log any error if (errorMessage is not null) @@ -280,12 +427,7 @@ private void AddFunctionCallResultToChatHistory(ChatHistory chatHistory, Functio this._logger.LogFunctionCallRequestFailure(functionCall, errorMessage); } - result ??= errorMessage ?? string.Empty; - - var message = new ChatMessageContent(role: AuthorRole.Tool, content: result); - message.Items.Add(new FunctionResultContent(functionCall.FunctionName, functionCall.PluginName, functionCall.Id, result)); - - chatHistory.Add(message); + return new FunctionResultContent(functionCall.FunctionName, functionCall.PluginName, functionCall.Id, result ?? errorMessage ?? string.Empty); } /// @@ -338,7 +480,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( /// /// The result of the function call. /// A string representation of the function result. - public static string? ProcessFunctionResult(object functionResult) + public static string ProcessFunctionResult(object functionResult) { if (functionResult is string stringResult) { diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs index ad6c2e033af0..2dee9c9786af 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessorLoggerExtensions.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.Extensions.Logging; @@ -72,7 +71,7 @@ public static void LogFunctionChoiceBehaviorConfiguration(this ILogger logger, F /// /// Logs function calls. /// - public static void LogFunctionCalls(this ILogger logger, List functionCalls) + public static void LogFunctionCalls(this ILogger logger, FunctionCallContent[] functionCalls) { if (logger.IsEnabled(LogLevel.Debug)) { diff --git a/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseAgentsTest.cs b/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseAgentsTest.cs index 2174ad307557..989005333946 100644 --- a/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseAgentsTest.cs +++ b/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseAgentsTest.cs @@ -15,7 +15,7 @@ /// /// Base class for samples that demonstrate the usage of agents. /// -public abstract class BaseAgentsTest(ITestOutputHelper output) : BaseTest(output) +public abstract class BaseAgentsTest(ITestOutputHelper output) : BaseTest(output, redirectSystemConsoleOutput: true) { /// /// Metadata key to indicate the assistant as created for a sample. @@ -81,7 +81,7 @@ protected void WriteAgentChatMessage(ChatMessageContent message) } else if (item is FunctionResultContent functionResult) { - Console.WriteLine($" [{item.GetType().Name}] {functionResult.CallId}"); + Console.WriteLine($" [{item.GetType().Name}] {functionResult.CallId} - {functionResult.Result?.AsJson() ?? "*"}"); } }