Skip to content

Commit

Permalink
.Net Agents - Support name based KernelFunction*Strategy (#9967)
Browse files Browse the repository at this point in the history
### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Working with customer looking to minimize token usage.


### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

KernelFunction _selection_ and _termination_ strategies that evaluate
name only can save tokens by not including message content.

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman authored Dec 17, 2024
1 parent 4a21254 commit 4650d27
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Agents;
/// on <see cref="KernelAgent.Arguments"/> and also providing override <see cref="KernelArguments"/>
/// when calling <see cref="ChatCompletionAgent.InvokeAsync"/>
/// </summary>
public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseTest(output)
public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseAgentsTest(output)
{
private const string ServiceKeyGood = "chat-good";
private const string ServiceKeyBad = "chat-bad";
Expand Down
3 changes: 1 addition & 2 deletions dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ await OpenAIAssistantAgent.CreateAsync(
{
await InvokeAgentAsync(
"""
Display this data using a bar-chart:
Display this data using a bar-chart (not stacked):
Banding Brown Pink Yellow Sum
X00000 339 433 126 898
Expand All @@ -55,7 +55,6 @@ Sum 426 1622 856 2904
""");

await InvokeAgentAsync("Can you regenerate this same chart using the category names as the bar colors?");
await InvokeAgentAsync("Perfect, can you regenerate this as a line chart?");
}
finally
{
Expand Down
69 changes: 0 additions & 69 deletions dotnet/samples/Concepts/Agents/OpenAIAssistant_FileService.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ No participant should take more than one turn in a row.
HistoryVariableName = "history",
// Save tokens by not including the entire history in the prompt
HistoryReducer = strategyReducer,
// Only include the agent names and not the message content
EvaluateNameOnly = true,
},
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public class KernelFunctionSelectionStrategy(KernelFunction function, Kernel ker
/// </summary>
public KernelFunction Function { get; } = function;

/// <summary>
/// Only include agent name in history when invoking <see cref="KernelFunctionTerminationStrategy.Function"/>.
/// </summary>
public bool EvaluateNameOnly { get; init; }

/// <summary>
/// Optionally specify a <see cref="IChatHistoryReducer"/> to reduce the history.
/// </summary>
Expand All @@ -79,7 +84,7 @@ protected sealed override async Task<Agent> SelectAgentAsync(IReadOnlyList<Agent
new(originalArguments, originalArguments.ExecutionSettings?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value))
{
{ this.AgentsVariableName, string.Join(",", agents.Select(a => a.Name)) },
{ this.HistoryVariableName, ChatMessageForPrompt.Format(history) },
{ this.HistoryVariableName, ChatMessageForPrompt.Format(history, this.EvaluateNameOnly) },
};

this.Logger.LogKernelFunctionSelectionStrategyInvokingFunction(nameof(NextAsync), this.Function.PluginName, this.Function.Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public class KernelFunctionTerminationStrategy(KernelFunction function, Kernel k
/// </summary>
public KernelFunction Function { get; } = function;

/// <summary>
/// Only include agent name in history when invoking <see cref="KernelFunctionTerminationStrategy.Function"/>.
/// </summary>
public bool EvaluateNameOnly { get; init; }

/// <summary>
/// A callback responsible for translating the <see cref="FunctionResult"/>
/// to the termination criteria.
Expand All @@ -74,7 +79,7 @@ protected sealed override async Task<bool> ShouldAgentTerminateAsync(Agent agent
new(originalArguments, originalArguments.ExecutionSettings?.ToDictionary(kvp => kvp.Key, kvp => kvp.Value))
{
{ this.AgentVariableName, agent.Name ?? agent.Id },
{ this.HistoryVariableName, ChatMessageForPrompt.Format(history) },
{ this.HistoryVariableName, ChatMessageForPrompt.Format(history, this.EvaluateNameOnly) },
};

this.Logger.LogKernelFunctionTerminationStrategyInvokingFunction(nameof(ShouldAgentTerminateAsync), this.Function.PluginName, this.Function.Name);
Expand Down
15 changes: 9 additions & 6 deletions dotnet/src/Agents/Core/Internal/ChatMessageForPrompt.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
Expand Down Expand Up @@ -31,14 +32,16 @@ internal sealed class ChatMessageForPrompt(ChatMessageContent message)
public string Content => message.Content ?? string.Empty;

/// <summary>
/// Convenience method to reference a set of messages.
/// Convenience method to format a set of messages for use in a prompt.
/// </summary>
public static IEnumerable<ChatMessageForPrompt> Prepare(IEnumerable<ChatMessageContent> messages) =>
messages.Where(m => !string.IsNullOrWhiteSpace(m.Content)).Select(m => new ChatMessageForPrompt(m));
public static string Format(IEnumerable<ChatMessageContent> messages, bool useNameOnly = false) =>
useNameOnly ?
JsonSerializer.Serialize(Prepare(messages, m => string.IsNullOrEmpty(m.AuthorName) ? m.Role.Label : m.AuthorName).ToArray(), s_jsonOptions) :
JsonSerializer.Serialize(Prepare(messages, m => new ChatMessageForPrompt(m)).ToArray(), s_jsonOptions);

/// <summary>
/// Convenience method to format a set of messages for use in a prompt.
/// Convenience method to reference a set of messages.
/// </summary>
public static string Format(IEnumerable<ChatMessageContent> messages) =>
JsonSerializer.Serialize(Prepare(messages).ToArray(), s_jsonOptions);
internal static IEnumerable<TResult> Prepare<TResult>(IEnumerable<ChatMessageContent> messages, Func<ChatMessageContent, TResult> transform) =>
messages.Where(m => !string.IsNullOrWhiteSpace(m.Content)).Select(m => transform.Invoke(m));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents.Internal;
using Microsoft.SemanticKernel.ChatCompletion;
using Xunit;

namespace SemanticKernel.Agents.UnitTests.Core.Internal;

/// <summary>
/// Unit testing of <see cref="ChatMessageForPrompt"/>.
/// </summary>
public class ChatMessageForPromptTests
{
/// <summary>
/// Verify <see cref="ChatMessageForPrompt"/> formats history for prompt.
/// </summary>
[Fact]
public void VerifyFormatHistoryAsync()
{
// Arrange & Act
string history = ChatMessageForPrompt.Format([]);
// Assert
VerifyMessageCount<ChatMessageForTest>(history, 0);

// Arrange & Act
history = ChatMessageForPrompt.Format(CreatHistory());
// Assert
ChatMessageForTest[] messages = VerifyMessageCount<ChatMessageForTest>(history, 4);
Assert.Equal("test", messages[1].Name);
Assert.Equal(string.Empty, messages[2].Name);
Assert.Equal("test", messages[3].Name);
}

/// <summary>
/// Verify <see cref="ChatMessageForPrompt"/> formats history using name only.
/// </summary>
[Fact]
public void VerifyFormatNamesAsync()
{
// Arrange & Act
string history = ChatMessageForPrompt.Format([], useNameOnly: true);
// Assert
VerifyMessageCount<string>(history, 0);

// Arrange & Act
history = ChatMessageForPrompt.Format(CreatHistory(), useNameOnly: true);
// Assert
string[] names = VerifyMessageCount<string>(history, 4);
Assert.Equal("test", names[1]);
Assert.Equal(AuthorRole.Assistant.Label, names[2]);
Assert.Equal("test", names[3]);
}

private static TResult[] VerifyMessageCount<TResult>(string history, int expectedLength)
{
TResult[]? messages = JsonSerializer.Deserialize<TResult[]>(history);
Assert.NotNull(messages);
Assert.Equal(expectedLength, messages.Length);
return messages;
}

private static ChatHistory CreatHistory()
{
return
[
new ChatMessageContent(AuthorRole.User, "content1"),
new ChatMessageContent(AuthorRole.Assistant, "content1") { AuthorName = "test" },
new ChatMessageContent(AuthorRole.Assistant, "content1"),
new ChatMessageContent(AuthorRole.Assistant, "content1") { AuthorName = "test" },
];
}

private sealed class ChatMessageForTest
{
public string Role { get; init; } = string.Empty;

public string? Name { get; init; } = string.Empty;

public string Content { get; init; } = string.Empty;
}
}

0 comments on commit 4650d27

Please sign in to comment.