From 4650d27938ebbe8c1fa076fac6b78b6b184ef037 Mon Sep 17 00:00:00 2001 From: Chris <66376200+crickman@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:43:24 -0800 Subject: [PATCH] .Net Agents - Support name-based KernelFunction*Strategy (#9967) ### Motivation and Context Working with a customer looking to minimize token usage. ### Description KernelFunction _selection_ and _termination_ strategies that evaluate the agent name only can save tokens by not including message content. ### Contribution Checklist - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone :smile: --- .../Agents/ChatCompletion_ServiceSelection.cs | 2 +- .../Agents/OpenAIAssistant_ChartMaker.cs | 3 +- .../Agents/OpenAIAssistant_FileService.cs | 69 ---------------- .../Step04_KernelFunctionStrategies.cs | 2 + .../Chat/KernelFunctionSelectionStrategy.cs | 7 +- .../Chat/KernelFunctionTerminationStrategy.cs | 7 +- .../Core/Internal/ChatMessageForPrompt.cs | 15 ++-- .../Internal/ChatMessageForPromptTests.cs | 82 +++++++++++++++++++ 8 files changed, 107 insertions(+), 80 deletions(-) delete mode 100644 dotnet/samples/Concepts/Agents/OpenAIAssistant_FileService.cs create mode 100644 dotnet/src/Agents/UnitTests/Core/Internal/ChatMessageForPromptTests.cs diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs index 8921dd2a6f9e..783524adf7f1 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs @@ -11,7 +11,7 @@ namespace Agents; /// on and also 
providing override /// when calling /// -public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseTest(output) +public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseAgentsTest(output) { private const string ServiceKeyGood = "chat-good"; private const string ServiceKeyBad = "chat-bad"; diff --git a/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs b/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs index 9074e47b3057..83ea083ec674 100644 --- a/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs +++ b/dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs @@ -44,7 +44,7 @@ await OpenAIAssistantAgent.CreateAsync( { await InvokeAgentAsync( """ - Display this data using a bar-chart: + Display this data using a bar-chart (not stacked): Banding Brown Pink Yellow Sum X00000 339 433 126 898 @@ -55,7 +55,6 @@ Sum 426 1622 856 2904 """); await InvokeAgentAsync("Can you regenerate this same chart using the category names as the bar colors?"); - await InvokeAgentAsync("Perfect, can you regenerate this as a line chart?"); } finally { diff --git a/dotnet/samples/Concepts/Agents/OpenAIAssistant_FileService.cs b/dotnet/samples/Concepts/Agents/OpenAIAssistant_FileService.cs deleted file mode 100644 index a8f31622c753..000000000000 --- a/dotnet/samples/Concepts/Agents/OpenAIAssistant_FileService.cs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Connectors.OpenAI; -using Resources; - -namespace Agents; - -/// -/// Demonstrate using . -/// -public class OpenAIAssistant_FileService(ITestOutputHelper output) : BaseTest(output) -{ - /// - /// Retrieval tool not supported on Azure OpenAI. 
- /// - protected override bool ForceOpenAI => true; - - [Fact] - public async Task UploadAndRetrieveFilesAsync() - { -#pragma warning disable CS0618 // Type or member is obsolete - OpenAIFileService fileService = new(TestConfiguration.OpenAI.ApiKey); - - BinaryContent[] files = [ - new AudioContent(await EmbeddedResource.ReadAllAsync("test_audio.wav")!, mimeType: "audio/wav") { InnerContent = "test_audio.wav" }, - new ImageContent(await EmbeddedResource.ReadAllAsync("sample_image.jpg")!, mimeType: "image/jpeg") { InnerContent = "sample_image.jpg" }, - new ImageContent(await EmbeddedResource.ReadAllAsync("test_image.jpg")!, mimeType: "image/jpeg") { InnerContent = "test_image.jpg" }, - new BinaryContent(data: await EmbeddedResource.ReadAllAsync("travelinfo.txt"), mimeType: "text/plain") { InnerContent = "travelinfo.txt" } - ]; - - Dictionary fileContents = new(); - foreach (BinaryContent file in files) - { - OpenAIFileReference result = await fileService.UploadContentAsync(file, new(file.InnerContent!.ToString()!, OpenAIFilePurpose.FineTune)); - fileContents.Add(result.Id, file); - } - - foreach (OpenAIFileReference fileReference in await fileService.GetFilesAsync(OpenAIFilePurpose.FineTune)) - { - // Only interested in the files we uploaded - if (!fileContents.ContainsKey(fileReference.Id)) - { - continue; - } - - BinaryContent content = await fileService.GetFileContentAsync(fileReference.Id); - - string? mimeType = fileContents[fileReference.Id].MimeType; - string? fileName = fileContents[fileReference.Id].InnerContent!.ToString(); - ReadOnlyMemory data = content.Data ?? 
new(); - - BinaryContent typedContent = mimeType switch - { - "image/jpeg" => new ImageContent(data, mimeType) { Uri = content.Uri, InnerContent = fileName, Metadata = content.Metadata }, - "audio/wav" => new AudioContent(data, mimeType) { Uri = content.Uri, InnerContent = fileName, Metadata = content.Metadata }, - _ => new BinaryContent(data, mimeType) { Uri = content.Uri, InnerContent = fileName, Metadata = content.Metadata } - }; - - Console.WriteLine($"\nFile: {fileName} - {mimeType}"); - Console.WriteLine($"Type: {typedContent}"); - Console.WriteLine($"Uri: {typedContent.Uri}"); - - // Delete the test file remotely - await fileService.DeleteFileAsync(fileReference.Id); - } - -#pragma warning restore CS0618 // Type or member is obsolete - } -} diff --git a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs index f3916ad1e583..4c7930bd2533 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs @@ -122,6 +122,8 @@ No participant should take more than one turn in a row. HistoryVariableName = "history", // Save tokens by not including the entire history in the prompt HistoryReducer = strategyReducer, + // Only include the agent names and not the message content + EvaluateNameOnly = true, }, } }; diff --git a/dotnet/src/Agents/Core/Chat/KernelFunctionSelectionStrategy.cs b/dotnet/src/Agents/Core/Chat/KernelFunctionSelectionStrategy.cs index ca73ab5ccc8b..fcfea6e1fa93 100644 --- a/dotnet/src/Agents/Core/Chat/KernelFunctionSelectionStrategy.cs +++ b/dotnet/src/Agents/Core/Chat/KernelFunctionSelectionStrategy.cs @@ -53,6 +53,11 @@ public class KernelFunctionSelectionStrategy(KernelFunction function, Kernel ker /// public KernelFunction Function { get; } = function; + /// + /// Only include agent name in history when invoking . 
+ /// + public bool EvaluateNameOnly { get; init; } + /// /// Optionally specify a to reduce the history. /// @@ -79,7 +84,7 @@ protected sealed override async Task SelectAgentAsync(IReadOnlyList kvp.Key, kvp => kvp.Value)) { { this.AgentsVariableName, string.Join(",", agents.Select(a => a.Name)) }, - { this.HistoryVariableName, ChatMessageForPrompt.Format(history) }, + { this.HistoryVariableName, ChatMessageForPrompt.Format(history, this.EvaluateNameOnly) }, }; this.Logger.LogKernelFunctionSelectionStrategyInvokingFunction(nameof(NextAsync), this.Function.PluginName, this.Function.Name); diff --git a/dotnet/src/Agents/Core/Chat/KernelFunctionTerminationStrategy.cs b/dotnet/src/Agents/Core/Chat/KernelFunctionTerminationStrategy.cs index 622366bc768d..26ad20e747dc 100644 --- a/dotnet/src/Agents/Core/Chat/KernelFunctionTerminationStrategy.cs +++ b/dotnet/src/Agents/Core/Chat/KernelFunctionTerminationStrategy.cs @@ -53,6 +53,11 @@ public class KernelFunctionTerminationStrategy(KernelFunction function, Kernel k /// public KernelFunction Function { get; } = function; + /// + /// Only include agent name in history when invoking . + /// + public bool EvaluateNameOnly { get; init; } + /// /// A callback responsible for translating the /// to the termination criteria. @@ -74,7 +79,7 @@ protected sealed override async Task ShouldAgentTerminateAsync(Agent agent new(originalArguments, originalArguments.ExecutionSettings?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value)) { { this.AgentVariableName, agent.Name ?? 
agent.Id }, - { this.HistoryVariableName, ChatMessageForPrompt.Format(history) }, + { this.HistoryVariableName, ChatMessageForPrompt.Format(history, this.EvaluateNameOnly) }, }; this.Logger.LogKernelFunctionTerminationStrategyInvokingFunction(nameof(ShouldAgentTerminateAsync), this.Function.PluginName, this.Function.Name); diff --git a/dotnet/src/Agents/Core/Internal/ChatMessageForPrompt.cs b/dotnet/src/Agents/Core/Internal/ChatMessageForPrompt.cs index 2ec91664ce4b..8d970988466b 100644 --- a/dotnet/src/Agents/Core/Internal/ChatMessageForPrompt.cs +++ b/dotnet/src/Agents/Core/Internal/ChatMessageForPrompt.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using System.Text.Json; @@ -31,14 +32,16 @@ internal sealed class ChatMessageForPrompt(ChatMessageContent message) public string Content => message.Content ?? string.Empty; /// - /// Convenience method to reference a set of messages. + /// Convenience method to format a set of messages for use in a prompt. /// - public static IEnumerable Prepare(IEnumerable messages) => - messages.Where(m => !string.IsNullOrWhiteSpace(m.Content)).Select(m => new ChatMessageForPrompt(m)); + public static string Format(IEnumerable messages, bool useNameOnly = false) => + useNameOnly ? + JsonSerializer.Serialize(Prepare(messages, m => string.IsNullOrEmpty(m.AuthorName) ? m.Role.Label : m.AuthorName).ToArray(), s_jsonOptions) : + JsonSerializer.Serialize(Prepare(messages, m => new ChatMessageForPrompt(m)).ToArray(), s_jsonOptions); /// - /// Convenience method to format a set of messages for use in a prompt. + /// Convenience method to reference a set of messages. 
/// - public static string Format(IEnumerable messages) => - JsonSerializer.Serialize(Prepare(messages).ToArray(), s_jsonOptions); + internal static IEnumerable Prepare(IEnumerable messages, Func transform) => + messages.Where(m => !string.IsNullOrWhiteSpace(m.Content)).Select(m => transform.Invoke(m)); } diff --git a/dotnet/src/Agents/UnitTests/Core/Internal/ChatMessageForPromptTests.cs b/dotnet/src/Agents/UnitTests/Core/Internal/ChatMessageForPromptTests.cs new file mode 100644 index 000000000000..00a64cd68ca6 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/Core/Internal/ChatMessageForPromptTests.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.Internal; +using Microsoft.SemanticKernel.ChatCompletion; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.Core.Internal; + +/// +/// Unit testing of . +/// +public class ChatMessageForPromptTests +{ + /// + /// Verify formats history for prompt. + /// + [Fact] + public void VerifyFormatHistoryAsync() + { + // Arrange & Act + string history = ChatMessageForPrompt.Format([]); + // Assert + VerifyMessageCount(history, 0); + + // Arrange & Act + history = ChatMessageForPrompt.Format(CreatHistory()); + // Assert + ChatMessageForTest[] messages = VerifyMessageCount(history, 4); + Assert.Equal("test", messages[1].Name); + Assert.Equal(string.Empty, messages[2].Name); + Assert.Equal("test", messages[3].Name); + } + + /// + /// Verify formats history using name only. 
+ /// + [Fact] + public void VerifyFormatNamesAsync() + { + // Arrange & Act + string history = ChatMessageForPrompt.Format([], useNameOnly: true); + // Assert + VerifyMessageCount(history, 0); + + // Arrange & Act + history = ChatMessageForPrompt.Format(CreatHistory(), useNameOnly: true); + // Assert + string[] names = VerifyMessageCount(history, 4); + Assert.Equal("test", names[1]); + Assert.Equal(AuthorRole.Assistant.Label, names[2]); + Assert.Equal("test", names[3]); + } + + private static TResult[] VerifyMessageCount(string history, int expectedLength) + { + TResult[]? messages = JsonSerializer.Deserialize(history); + Assert.NotNull(messages); + Assert.Equal(expectedLength, messages.Length); + return messages; + } + + private static ChatHistory CreatHistory() + { + return + [ + new ChatMessageContent(AuthorRole.User, "content1"), + new ChatMessageContent(AuthorRole.Assistant, "content1") { AuthorName = "test" }, + new ChatMessageContent(AuthorRole.Assistant, "content1"), + new ChatMessageContent(AuthorRole.Assistant, "content1") { AuthorName = "test" }, + ]; + } + + private sealed class ChatMessageForTest + { + public string Role { get; init; } = string.Empty; + + public string? Name { get; init; } = string.Empty; + + public string Content { get; init; } = string.Empty; + } +}