Skip to content

Commit

Permalink
Python: Fix Anthropic parallel tool call (#10005)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users by providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
When using the Anthropic connector with function calling, it's possible
that the model will request multiple functions in a single request. This
is referred to as [`parallel tool
use`](https://docs.anthropic.com/en/docs/build-with-claude/tool-use#disabling-parallel-tool-use)
by Anthropic.

When the model requests multiple functions, it expects the tool results
to be included in a single user message to be passed back to the model.
Right now, the Anthropic connector parses the tool results into multiple
user messages, which causes the model to throw an error.

This is a possible regression introduced by this PR:
#9938

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->
1. Fix the Anthropic connector to handle parallel tool calls.
2. Add unit tests to ensure that future changes don't break this
functionality.
3. Enable integration test on Anthropic since we have a service endpoint
now.
4. Refactoring.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
TaoChenOSU authored Dec 19, 2024
1 parent 16690ed commit f50117f
Show file tree
Hide file tree
Showing 5 changed files with 611 additions and 498 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any, ClassVar

if sys.version_info >= (3, 12):
Expand All @@ -26,7 +26,10 @@
from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import (
AnthropicChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS
from semantic_kernel.connectors.ai.anthropic.services.utils import (
MESSAGE_CONVERTERS,
update_settings_from_function_call_configuration,
)
from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
Expand All @@ -43,10 +46,10 @@
from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason
from semantic_kernel.exceptions.service_exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceInvalidResponseError,
ServiceResponseException,
)
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
trace_chat_completion,
Expand Down Expand Up @@ -130,6 +133,19 @@ def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"]
def service_url(self) -> str | None:
return str(self.async_client.base_url)

@override
def _update_function_choice_settings_callback(
    self,
) -> Callable[[FunctionCallChoiceConfiguration, "PromptExecutionSettings", FunctionChoiceType], None]:
    """Return the callback that applies a function-choice configuration to execution settings.

    The base class invokes this callback to translate the kernel's function-choice
    configuration into Anthropic-specific `tools`/`tool_choice` settings.
    """
    return update_settings_from_function_call_configuration

@override
def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
if hasattr(settings, "tool_choice"):
settings.tool_choice = None
if hasattr(settings, "tools"):
settings.tools = None

@override
@trace_chat_completion(MODEL_PROVIDER_NAME)
async def _inner_get_chat_message_contents(
Expand Down Expand Up @@ -171,6 +187,7 @@ async def _inner_get_streaming_chat_message_contents(
async for message in response:
yield message

@override
def _prepare_chat_history_for_request(
self,
chat_history: "ChatHistory",
Expand All @@ -194,14 +211,37 @@ def _prepare_chat_history_for_request(
system_message_content = None
system_message_count = 0
formatted_messages: list[dict[str, Any]] = []
for message in chat_history.messages:
# Skip system messages after the first one is found
if message.role == AuthorRole.SYSTEM:
for i in range(len(chat_history)):
prev_message = chat_history[i - 1] if i > 0 else None
curr_message = chat_history[i]
if curr_message.role == AuthorRole.SYSTEM:
# Skip system messages after the first one is found
if system_message_count == 0:
system_message_content = message.content
system_message_content = curr_message.content
system_message_count += 1
elif curr_message.role == AuthorRole.USER or curr_message.role == AuthorRole.ASSISTANT:
formatted_messages.append(MESSAGE_CONVERTERS[curr_message.role](curr_message))
elif curr_message.role == AuthorRole.TOOL:
if prev_message is None:
# Under no circumstances should a tool message be the first message in the chat history
raise ServiceInvalidRequestError("Tool message found without a preceding message.")
if prev_message.role == AuthorRole.USER or prev_message.role == AuthorRole.SYSTEM:
# A tool message should not be found after a user or system message
# Please NOTE that in SK there are the USER role and the TOOL role, but in Anthropic
# the tool messages are considered as USER messages. We are checking against the SK roles.
raise ServiceInvalidRequestError("Tool message found after a user or system message.")

formatted_message = MESSAGE_CONVERTERS[curr_message.role](curr_message)
if prev_message.role == AuthorRole.ASSISTANT:
# The first tool message after an assistant message should be a new message
formatted_messages.append(formatted_message)
else:
# Append the tool message to the previous tool message.
# This indicates that the assistant message requested multiple parallel tool calls.
# Anthropic requires that parallel Tool messages are grouped together in a single message.
formatted_messages[-1][content_key] += formatted_message[content_key]
else:
formatted_messages.append(MESSAGE_CONVERTERS[message.role](message))
raise ServiceInvalidRequestError(f"Unsupported role in chat history: {curr_message.role}")

if system_message_count > 1:
logger.warning(
Expand Down Expand Up @@ -277,50 +317,6 @@ def _create_streaming_chat_message_content(
items=items,
)

def update_settings_from_function_call_configuration_anthropic(
self,
function_choice_configuration: FunctionCallChoiceConfiguration,
settings: "PromptExecutionSettings",
type: "FunctionChoiceType",
) -> None:
"""Update the settings from a FunctionChoiceConfiguration."""
if (
function_choice_configuration.available_functions
and hasattr(settings, "tools")
and hasattr(settings, "tool_choice")
):
settings.tools = [
self.kernel_function_metadata_to_function_call_format_anthropic(f)
for f in function_choice_configuration.available_functions
]

if (
settings.function_choice_behavior
and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED
) or type == FunctionChoiceType.REQUIRED:
settings.tool_choice = {"type": "any"}
else:
settings.tool_choice = {"type": type.value}

def kernel_function_metadata_to_function_call_format_anthropic(
self,
metadata: KernelFunctionMetadata,
) -> dict[str, Any]:
"""Convert the kernel function metadata to function calling format."""
return {
"name": metadata.fully_qualified_name,
"description": metadata.description or "",
"input_schema": {
"type": "object",
"properties": {p.name: p.schema_data for p in metadata.parameters},
"required": [p.name for p in metadata.parameters if p.is_required],
},
}

@override
def _update_function_choice_settings_callback(self):
return self.update_settings_from_function_call_configuration_anthropic

async def _send_chat_request(self, settings: AnthropicChatPromptExecutionSettings) -> list["ChatMessageContent"]:
"""Send the chat request."""
try:
Expand Down Expand Up @@ -382,10 +378,3 @@ def _get_tool_calls_from_message(self, message: Message) -> list[FunctionCallCon
)

return tool_calls

@override
def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
if hasattr(settings, "tool_choice"):
settings.tool_choice = None
if hasattr(settings, "tools"):
settings.tools = None
74 changes: 59 additions & 15 deletions python/semantic_kernel/connectors/ai/anthropic/services/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from collections.abc import Callable, Mapping
from typing import Any

from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,29 +54,32 @@ def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]:
"type": "tool_use",
"id": item.id or "",
"name": item.name or "",
"input": item.arguments if isinstance(item.arguments, Mapping) else json.loads(item.arguments or ""),
"input": item.arguments
if isinstance(item.arguments, Mapping)
else json.loads(item.arguments)
if item.arguments
else {},
})
else:
logger.warning(
f"Unsupported item type in Assistant message while formatting chat history for Anthropic: {type(item)}"
)

formatted_message: dict[str, Any] = {"role": "assistant", "content": []}

if message.content:
# Only include the text content if it is not empty.
# Otherwise, the Anthropic client will throw an error.
formatted_message["content"].append({ # type: ignore
"type": "text",
"text": message.content,
})
if tool_calls:
return {
"role": "assistant",
"content": [
{
"type": "text",
"text": message.content,
},
*tool_calls,
],
}
# Only include the tool calls if there are any.
# Otherwise, the Anthropic client will throw an error.
formatted_message["content"].extend(tool_calls) # type: ignore

return {
"role": "assistant",
"content": message.content,
}
return formatted_message


def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]:
Expand Down Expand Up @@ -108,3 +115,40 @@ def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]:
AuthorRole.ASSISTANT: _format_assistant_message,
AuthorRole.TOOL: _format_tool_message,
}


def update_settings_from_function_call_configuration(
    function_choice_configuration: FunctionCallChoiceConfiguration,
    settings: PromptExecutionSettings,
    type: FunctionChoiceType,
) -> None:
    """Apply a FunctionCallChoiceConfiguration to the prompt execution settings.

    Populates ``settings.tools`` with an Anthropic tool definition for each
    available function and sets ``settings.tool_choice`` to match the choice
    type. Does nothing when there are no available functions or when the
    settings object lacks the ``tools``/``tool_choice`` attributes.
    """
    available = function_choice_configuration.available_functions
    if not (available and hasattr(settings, "tools") and hasattr(settings, "tool_choice")):
        return

    settings.tools = [kernel_function_metadata_to_function_call_format(metadata) for metadata in available]

    behavior = settings.function_choice_behavior
    behavior_requires_tool = bool(behavior) and behavior.type_ == FunctionChoiceType.REQUIRED
    if behavior_requires_tool or type == FunctionChoiceType.REQUIRED:
        # Anthropic expresses "a tool must be used" as tool_choice type "any".
        settings.tool_choice = {"type": "any"}
    else:
        settings.tool_choice = {"type": type.value}


def kernel_function_metadata_to_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]:
    """Translate kernel function metadata into Anthropic's tool definition schema.

    Returns a dict with the function's fully qualified name, its description
    (empty string when unset), and a JSON-schema-style ``input_schema`` built
    from the function's parameters.
    """
    properties: dict[str, Any] = {}
    required_names: list[str] = []
    for parameter in metadata.parameters:
        properties[parameter.name] = parameter.schema_data
        if parameter.is_required:
            required_names.append(parameter.name)

    return {
        "name": metadata.fully_qualified_name,
        "description": metadata.description or "",
        "input_schema": {
            "type": "object",
            "properties": properties,
            "required": required_names,
        },
    }
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@
onnx_setup: bool = is_service_setup_for_testing(
["ONNX_GEN_AI_CHAT_MODEL_FOLDER"], raise_if_not_set=False
) # Tests are optional for ONNX
anthropic_setup: bool = is_service_setup_for_testing(
["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"], raise_if_not_set=False
) # We don't have an Anthropic deployment
anthropic_setup: bool = is_service_setup_for_testing(["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"])
# When testing Bedrock, after logging into AWS CLI this has been set, so we can use it to check if the service is setup
bedrock_setup: bool = is_service_setup_for_testing(["AWS_DEFAULT_REGION"], raise_if_not_set=False)

Expand Down
Loading

0 comments on commit f50117f

Please sign in to comment.