Skip to content

Commit

Permalink
.Net: Add store and metadata properties to OpenAIPromptExecutionSetti…
Browse files Browse the repository at this point in the history
…ngs (#9936)

### Motivation and Context

Closes #9918 

### Description

<!-- Describe your changes, the overall approach, and the underlying design.
These notes will help reviewers understand how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft authored Dec 11, 2024
1 parent 11c80af commit d229179
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 13 deletions.
30 changes: 30 additions & 0 deletions dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,36 @@ public async Task ChatPromptWithInnerContentAsync()
OutputInnerContent(replyInnerContent!);
}

/// <summary>
/// Demonstrates how you can store the output of a chat completion request for use in the OpenAI model distillation or evals products.
/// </summary>
/// <remarks>
/// This sample adds metadata to the chat completion request which allows the requests to be filtered in the OpenAI dashboard.
/// </remarks>
[Fact]
public async Task ChatPromptStoreWithMetadataAsync()
{
    Assert.NotNull(TestConfiguration.OpenAI.ChatModelId);
    Assert.NotNull(TestConfiguration.OpenAI.ApiKey);

    // The prompt never changes, so a raw string literal is sufficient — no builder required.
    const string chatPrompt = """
        <message role="system">You are a librarian, expert about books</message>
        <message role="user">Hi, I'm looking for book suggestions about Artificial Intelligence</message>
        """;

    Kernel kernel = Kernel.CreateBuilder()
        .AddOpenAIChatCompletion(TestConfiguration.OpenAI.ChatModelId, TestConfiguration.OpenAI.ApiKey)
        .Build();

    // Store=true persists the completion on the OpenAI side; Metadata tags it for dashboard filtering.
    OpenAIPromptExecutionSettings executionSettings = new()
    {
        Store = true,
        Metadata = new Dictionary<string, string>() { { "concept", "chatcompletion" } }
    };

    var functionResult = await kernel.InvokePromptAsync(chatPrompt, new(executionSettings));

    // Unwrap the provider-specific completion object from the kernel's FunctionResult.
    var messageContent = functionResult.GetValue<ChatMessageContent>();
    var replyInnerContent = messageContent!.InnerContent as OpenAI.Chat.ChatCompletion;

    OutputInnerContent(replyInnerContent!);
}

private async Task StartChatAsync(IChatCompletionService chatGPT)
{
Console.WriteLine("Chat content:");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public void ItCreatesOpenAIExecutionSettingsWithCorrectDefaults()
Assert.Null(executionSettings.Logprobs);
Assert.Null(executionSettings.AzureChatDataSource);
Assert.Equal(maxTokensSettings, executionSettings.MaxTokens);
Assert.Null(executionSettings.Store);
Assert.Null(executionSettings.Metadata);
}

[Fact]
Expand All @@ -54,13 +56,24 @@ public void ItUsesExistingOpenAIExecutionSettings()
Logprobs = true,
TopLogprobs = 5,
TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
Seed = 123456,
Store = true,
Metadata = new Dictionary<string, string>() { { "foo", "bar" } }
};

// Act
AzureOpenAIPromptExecutionSettings executionSettings = AzureOpenAIPromptExecutionSettings.FromExecutionSettings(actualSettings);

// Assert
Assert.Equal(actualSettings, executionSettings);
Assert.Equal(actualSettings, executionSettings);
Assert.Equal(actualSettings.MaxTokens, executionSettings.MaxTokens);
Assert.Equal(actualSettings.Logprobs, executionSettings.Logprobs);
Assert.Equal(actualSettings.TopLogprobs, executionSettings.TopLogprobs);
Assert.Equal(actualSettings.TokenSelectionBiases, executionSettings.TokenSelectionBiases);
Assert.Equal(actualSettings.Seed, executionSettings.Seed);
Assert.Equal(actualSettings.Store, executionSettings.Store);
Assert.Equal(actualSettings.Metadata, executionSettings.Metadata);
}

[Fact]
Expand All @@ -71,7 +84,9 @@ public void ItCanUseOpenAIExecutionSettings()
{
ExtensionData = new Dictionary<string, object>() {
{ "max_tokens", 1000 },
{ "temperature", 0 }
{ "temperature", 0 },
{ "store", true },
{ "metadata", new Dictionary<string, string>() { { "foo", "bar" } } }
}
};

Expand All @@ -82,6 +97,8 @@ public void ItCanUseOpenAIExecutionSettings()
Assert.NotNull(executionSettings);
Assert.Equal(1000, executionSettings.MaxTokens);
Assert.Equal(0, executionSettings.Temperature);
Assert.True(executionSettings.Store);
Assert.Equal(new Dictionary<string, string>() { { "foo", "bar" } }, executionSettings.Metadata);
}

[Fact]
Expand All @@ -103,6 +120,8 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 },
{ "store", true },
{ "metadata", new Dictionary<string, string>() { { "foo", "bar" } } }
}
};

Expand Down Expand Up @@ -131,7 +150,9 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings()
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } },
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 }
{ "top_logprobs", 5 },
{ "store", true },
{ "metadata", new Dictionary<string, string>() { { "foo", "bar" } } }
}
};

Expand All @@ -158,7 +179,9 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase()
"max_tokens": 128,
"seed": 123456,
"logprobs": true,
"top_logprobs": 5
"top_logprobs": 5,
"store": true,
"metadata": { "foo": "bar" }
}
""";
var actualSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
Expand Down Expand Up @@ -217,7 +240,9 @@ public void PromptExecutionSettingsFreezeWorksAsExpected()
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"stop_sequences": [ "DONE" ],
"token_selection_biases": { "1": 2, "3": 4 }
"token_selection_biases": { "1": 2, "3": 4 },
"store": true,
"metadata": { "foo": "bar" }
}
""";
var executionSettings = JsonSerializer.Deserialize<AzureOpenAIPromptExecutionSettings>(configPayload);
Expand All @@ -232,6 +257,8 @@ public void PromptExecutionSettingsFreezeWorksAsExpected()
Assert.Throws<InvalidOperationException>(() => executionSettings.TopP = 1);
Assert.Throws<NotSupportedException>(() => executionSettings.StopSequences?.Add("STOP"));
Assert.Throws<NotSupportedException>(() => executionSettings.TokenSelectionBiases?.Add(5, 6));
Assert.Throws<InvalidOperationException>(() => executionSettings.Store = false);
Assert.Throws<NotSupportedException>(() => executionSettings.Metadata?.Add("bar", "foo"));

executionSettings!.Freeze(); // idempotent
Assert.True(executionSettings.IsFrozen);
Expand Down Expand Up @@ -267,7 +294,9 @@ public void ItCanCreateAzureOpenAIPromptExecutionSettingsFromOpenAIPromptExecuti
Logprobs = true,
Seed = 123456,
TopLogprobs = 5,
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions,
Store = true,
Metadata = new Dictionary<string, string>() { { "foo", "bar" } }
};

// Act
Expand Down Expand Up @@ -307,5 +336,7 @@ private static void AssertExecutionSettings(AzureOpenAIPromptExecutionSettings e
Assert.Equal(123456, executionSettings.Seed);
Assert.Equal(true, executionSettings.Logprobs);
Assert.Equal(5, executionSettings.TopLogprobs);
Assert.Equal(true, executionSettings.Store);
Assert.Equal(new Dictionary<string, string>() { { "foo", "bar" } }, executionSettings.Metadata);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ protected override ChatCompletionOptions CreateChatCompletionOptions(
EndUserId = executionSettings.User,
TopLogProbabilityCount = executionSettings.TopLogprobs,
IncludeLogProbabilities = executionSettings.Logprobs,
StoredOutputEnabled = executionSettings.Store,
};

var responseFormat = GetResponseFormat(executionSettings);
Expand Down Expand Up @@ -90,6 +91,14 @@ protected override ChatCompletionOptions CreateChatCompletionOptions(
}
}

if (executionSettings.Metadata is not null)
{
foreach (var kvp in executionSettings.Metadata)
{
options.Metadata.Add(kvp.Key, kvp.Value);
}
}

if (toolCallingConfig.Options?.AllowParallelCalls is not null)
{
options.AllowParallelToolCalls = toolCallingConfig.Options.AllowParallelCalls;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public void ItCreatesOpenAIExecutionSettingsWithCorrectDefaults()
Assert.Null(executionSettings.TopLogprobs);
Assert.Null(executionSettings.Logprobs);
Assert.Equal(128, executionSettings.MaxTokens);
Assert.Null(executionSettings.Store);
Assert.Null(executionSettings.Metadata);
}

[Fact]
Expand All @@ -44,12 +46,15 @@ public void ItUsesExistingOpenAIExecutionSettings()
TopP = 0.7,
FrequencyPenalty = 0.7,
PresencePenalty = 0.7,
StopSequences = new string[] { "foo", "bar" },
StopSequences = ["foo", "bar"],
ChatSystemPrompt = "chat system prompt",
MaxTokens = 128,
Logprobs = true,
TopLogprobs = 5,
TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
Seed = 123456,
Store = true,
Metadata = new Dictionary<string, string>() { { "foo", "bar" } }
};

// Act
Expand All @@ -58,7 +63,13 @@ public void ItUsesExistingOpenAIExecutionSettings()
// Assert
Assert.NotNull(executionSettings);
Assert.Equal(actualSettings, executionSettings);
Assert.Equal(128, executionSettings.MaxTokens);
Assert.Equal(actualSettings.MaxTokens, executionSettings.MaxTokens);
Assert.Equal(actualSettings.Logprobs, executionSettings.Logprobs);
Assert.Equal(actualSettings.TopLogprobs, executionSettings.TopLogprobs);
Assert.Equal(actualSettings.TokenSelectionBiases, executionSettings.TokenSelectionBiases);
Assert.Equal(actualSettings.Seed, executionSettings.Seed);
Assert.Equal(actualSettings.Store, executionSettings.Store);
Assert.Equal(actualSettings.Metadata, executionSettings.Metadata);
}

[Fact]
Expand All @@ -69,7 +80,9 @@ public void ItCanUseOpenAIExecutionSettings()
{
ExtensionData = new Dictionary<string, object>() {
{ "max_tokens", 1000 },
{ "temperature", 0 }
{ "temperature", 0 },
{ "store", true },
{ "metadata", new Dictionary<string, string>() { { "foo", "bar" } } }
}
};

Expand All @@ -80,6 +93,8 @@ public void ItCanUseOpenAIExecutionSettings()
Assert.NotNull(executionSettings);
Assert.Equal(1000, executionSettings.MaxTokens);
Assert.Equal(0, executionSettings.Temperature);
Assert.True(executionSettings.Store);
Assert.Equal(new Dictionary<string, string>() { { "foo", "bar" } }, executionSettings.Metadata);
}

[Fact]
Expand All @@ -102,6 +117,8 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 },
{ "store", true },
{ "metadata", new Dictionary<string, string>() { { "foo", "bar" } } }
}
};

Expand Down Expand Up @@ -131,7 +148,9 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings()
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } },
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 }
{ "top_logprobs", 5 },
{ "store", true },
{ "metadata", new Dictionary<string, string>() { { "foo", "bar" } } }
}
};

Expand Down Expand Up @@ -159,7 +178,9 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase()
"max_tokens": 128,
"seed": 123456,
"logprobs": true,
"top_logprobs": 5
"top_logprobs": 5,
"store": true,
"metadata": { "foo": "bar" }
}
""";
var actualSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
Expand Down Expand Up @@ -219,7 +240,12 @@ public void PromptExecutionSettingsFreezeWorksAsExpected()
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"stop_sequences": [ "DONE" ],
"token_selection_biases": { "1": 2, "3": 4 }
"token_selection_biases": { "1": 2, "3": 4 },
"seed": 123456,
"logprobs": true,
"top_logprobs": 5,
"store": true,
"metadata": { "foo": "bar" }
}
""";
var executionSettings = JsonSerializer.Deserialize<OpenAIPromptExecutionSettings>(configPayload);
Expand All @@ -234,6 +260,11 @@ public void PromptExecutionSettingsFreezeWorksAsExpected()
Assert.Throws<InvalidOperationException>(() => executionSettings.TopP = 1);
Assert.Throws<NotSupportedException>(() => executionSettings.StopSequences?.Add("STOP"));
Assert.Throws<NotSupportedException>(() => executionSettings.TokenSelectionBiases?.Add(5, 6));
Assert.Throws<InvalidOperationException>(() => executionSettings.Seed = 654321);
Assert.Throws<InvalidOperationException>(() => executionSettings.Logprobs = false);
Assert.Throws<InvalidOperationException>(() => executionSettings.TopLogprobs = 10);
Assert.Throws<InvalidOperationException>(() => executionSettings.Store = false);
Assert.Throws<NotSupportedException>(() => executionSettings.Metadata?.Add("bar", "baz"));

executionSettings!.Freeze(); // idempotent
Assert.True(executionSettings.IsFrozen);
Expand Down Expand Up @@ -285,5 +316,7 @@ private static void AssertExecutionSettings(OpenAIPromptExecutionSettings execut
Assert.Equal(123456, executionSettings.Seed);
Assert.Equal(true, executionSettings.Logprobs);
Assert.Equal(5, executionSettings.TopLogprobs);
Assert.Equal(true, executionSettings.Store);
Assert.Equal(new Dictionary<string, string>() { { "foo", "bar" } }, executionSettings.Metadata);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ protected virtual ChatCompletionOptions CreateChatCompletionOptions(
#pragma warning restore OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
EndUserId = executionSettings.User,
TopLogProbabilityCount = executionSettings.TopLogprobs,
IncludeLogProbabilities = executionSettings.Logprobs
IncludeLogProbabilities = executionSettings.Logprobs,
StoredOutputEnabled = executionSettings.Store,
};

var responseFormat = GetResponseFormat(executionSettings);
Expand Down Expand Up @@ -496,6 +497,14 @@ protected virtual ChatCompletionOptions CreateChatCompletionOptions(
options.AllowParallelToolCalls = toolCallingConfig.Options.AllowParallelCalls;
}

if (executionSettings.Metadata is not null)
{
foreach (var kvp in executionSettings.Metadata)
{
options.Metadata.Add(kvp.Key, kvp.Value);
}
}

return options;
}

Expand Down
Loading

0 comments on commit d229179

Please sign in to comment.