From f50117f75a9e587f8f8cc2c8d7f5be3bd36b9ea5 Mon Sep 17 00:00:00 2001
From: Tao Chen
Date: Wed, 18 Dec 2024 21:16:21 -0800
Subject: [PATCH] Python: Fix Anthropic parallel tool call (#10005)

### Motivation and Context

When using the Anthropic connector with function calling, the model may request multiple functions in a single response. Anthropic refers to this as [`parallel tool use`](https://docs.anthropic.com/en/docs/build-with-claude/tool-use#disabling-parallel-tool-use). When the model requests multiple functions, all of the tool results must be passed back to the model in a single user message (see the sketch below). Currently, the Anthropic connector splits the tool results across multiple user messages, which causes the service to return an error. This is possibly a regression introduced by this PR: https://github.com/microsoft/semantic-kernel/pull/9938

### Description

1. Fix the Anthropic connector to handle parallel tool calls.
2. Add unit tests to ensure that future changes don't break this functionality.
3. Enable the Anthropic integration tests, now that a service endpoint is available.
4. General refactoring.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone :smile:
---
 .../services/anthropic_chat_completion.py     | 107 ++--
 .../connectors/ai/anthropic/services/utils.py |  74 ++-
 .../completions/chat_completion_test_base.py  |   4 +-
 .../unit/connectors/ai/anthropic/conftest.py  | 400 +++++++++++++
 .../test_anthropic_chat_completion.py         | 524 ++++--------------
 5 files changed, 611 insertions(+), 498 deletions(-)
 create mode 100644 python/tests/unit/connectors/ai/anthropic/conftest.py
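For context, a minimal sketch (not part of the patch) of the exchange Anthropic expects for parallel tool use. The tool names and arguments mirror the test fixtures added in this PR; the `toolu_*` IDs are invented:

```python
# Sketch of Anthropic's parallel tool use exchange (hypothetical IDs).
# The assistant turn carries one tool_use block per requested function...
assistant_message = {
    "role": "assistant",
    "content": [
        {"type": "tool_use", "id": "toolu_01", "name": "math-Add", "input": {"input": 3, "amount": 3}},
        {"type": "tool_use", "id": "toolu_02", "name": "math-Subtract", "input": {"input": 6, "amount": 3}},
    ],
}

# ...and ALL matching tool_result blocks must come back in ONE user message.
tool_results_message = {
    "role": "user",
    "content": [
        {"type": "tool_result", "tool_use_id": "toolu_01", "content": "6"},
        {"type": "tool_result", "tool_use_id": "toolu_02", "content": "3"},
    ],
}

# Before this fix, the connector emitted one user message per tool result;
# the API rejects that, since the first user message leaves toolu_02 unanswered.
```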
diff --git a/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py b/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py
index ed2616ba71aa..4c4c9da92d60 100644
--- a/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py
+++ b/python/semantic_kernel/connectors/ai/anthropic/services/anthropic_chat_completion.py
@@ -3,7 +3,7 @@
 import json
 import logging
 import sys
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Callable
 from typing import Any, ClassVar
 
 if sys.version_info >= (3, 12):
@@ -26,7 +26,10 @@
 from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import (
     AnthropicChatPromptExecutionSettings,
 )
-from semantic_kernel.connectors.ai.anthropic.services.utils import MESSAGE_CONVERTERS
+from semantic_kernel.connectors.ai.anthropic.services.utils import (
+    MESSAGE_CONVERTERS,
+    update_settings_from_function_call_configuration,
+)
 from semantic_kernel.connectors.ai.anthropic.settings.anthropic_settings import AnthropicSettings
 from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
 from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
@@ -43,10 +46,10 @@
 from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason
 from semantic_kernel.exceptions.service_exceptions import (
     ServiceInitializationError,
+    ServiceInvalidRequestError,
     ServiceInvalidResponseError,
     ServiceResponseException,
 )
-from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
 from semantic_kernel.utils.experimental_decorator import experimental_class
 from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
     trace_chat_completion,
@@ -130,6 +133,19 @@ def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"]
     def service_url(self) -> str | None:
         return str(self.async_client.base_url)
 
+    @override
+    def _update_function_choice_settings_callback(
+        self,
+    ) -> Callable[[FunctionCallChoiceConfiguration, "PromptExecutionSettings", FunctionChoiceType], None]:
+        return update_settings_from_function_call_configuration
+
+    @override
+    def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
+        if hasattr(settings, "tool_choice"):
+            settings.tool_choice = None
+        if hasattr(settings, "tools"):
+            settings.tools = None
+
     @override
     @trace_chat_completion(MODEL_PROVIDER_NAME)
     async def _inner_get_chat_message_contents(
@@ -171,6 +187,7 @@ async def _inner_get_streaming_chat_message_contents(
         async for message in response:
             yield message
 
+    @override
     def _prepare_chat_history_for_request(
         self,
         chat_history: "ChatHistory",
@@ -194,14 +211,37 @@ def _prepare_chat_history_for_request(
         system_message_content = None
         system_message_count = 0
         formatted_messages: list[dict[str, Any]] = []
-        for message in chat_history.messages:
-            # Skip system messages after the first one is found
-            if message.role == AuthorRole.SYSTEM:
+        for i in range(len(chat_history)):
+            prev_message = chat_history[i - 1] if i > 0 else None
+            curr_message = chat_history[i]
+            if curr_message.role == AuthorRole.SYSTEM:
+                # Skip system messages after the first one is found
                 if system_message_count == 0:
-                    system_message_content = message.content
+                    system_message_content = curr_message.content
                 system_message_count += 1
+            elif curr_message.role == AuthorRole.USER or curr_message.role == AuthorRole.ASSISTANT:
+                formatted_messages.append(MESSAGE_CONVERTERS[curr_message.role](curr_message))
+            elif curr_message.role == AuthorRole.TOOL:
+                if prev_message is None:
+                    # Under no circumstances should a tool message be the first message in the chat history
+                    raise ServiceInvalidRequestError("Tool message found without a preceding message.")
+                if prev_message.role == AuthorRole.USER or prev_message.role == AuthorRole.SYSTEM:
+                    # A tool message must not directly follow a user or system message.
+                    # Note that SK has distinct USER and TOOL roles, whereas Anthropic represents
+                    # tool results as USER messages; the checks here are against the SK roles.
+                    raise ServiceInvalidRequestError("Tool message found after a user or system message.")
+
+                formatted_message = MESSAGE_CONVERTERS[curr_message.role](curr_message)
+                if prev_message.role == AuthorRole.ASSISTANT:
+                    # The first tool message after an assistant message starts a new message.
+                    formatted_messages.append(formatted_message)
+                else:
+                    # Merge this tool message into the previous tool message; the assistant
+                    # message requested multiple parallel tool calls, and Anthropic requires
+                    # that parallel tool results be grouped together in a single message.
+                    formatted_messages[-1][content_key] += formatted_message[content_key]
             else:
-                formatted_messages.append(MESSAGE_CONVERTERS[message.role](message))
+                raise ServiceInvalidRequestError(f"Unsupported role in chat history: {curr_message.role}")
 
         if system_message_count > 1:
             logger.warning(
@@ -277,50 +317,6 @@ def _create_streaming_chat_message_content(
             items=items,
         )
 
-    def update_settings_from_function_call_configuration_anthropic(
-        self,
-        function_choice_configuration: FunctionCallChoiceConfiguration,
-        settings: "PromptExecutionSettings",
-        type: "FunctionChoiceType",
-    ) -> None:
-        """Update the settings from a FunctionChoiceConfiguration."""
-        if (
-            function_choice_configuration.available_functions
-            and hasattr(settings, "tools")
-            and hasattr(settings, "tool_choice")
-        ):
-            settings.tools = [
-                self.kernel_function_metadata_to_function_call_format_anthropic(f)
-                for f in function_choice_configuration.available_functions
-            ]
-
-            if (
-                settings.function_choice_behavior
-                and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED
-            ) or type == FunctionChoiceType.REQUIRED:
-                settings.tool_choice = {"type": "any"}
-            else:
-                settings.tool_choice = {"type": type.value}
-
-    def kernel_function_metadata_to_function_call_format_anthropic(
-        self,
-        metadata: KernelFunctionMetadata,
-    ) -> dict[str, Any]:
-        """Convert the kernel function metadata to function calling format."""
-        return {
-            "name": metadata.fully_qualified_name,
-            "description": metadata.description or "",
-            "input_schema": {
-                "type": "object",
-                "properties": {p.name: p.schema_data for p in metadata.parameters},
-                "required": [p.name for p in metadata.parameters if p.is_required],
-            },
-        }
-
-    @override
-    def _update_function_choice_settings_callback(self):
-        return self.update_settings_from_function_call_configuration_anthropic
-
     async def _send_chat_request(self, settings: AnthropicChatPromptExecutionSettings) -> list["ChatMessageContent"]:
         """Send the chat request."""
         try:
@@ -382,10 +378,3 @@ def _get_tool_calls_from_message(self, message: Message) -> list[FunctionCallCon
         )
 
         return tool_calls
-
-    @override
-    def _reset_function_choice_settings(self, settings: "PromptExecutionSettings") -> None:
-        if hasattr(settings, "tool_choice"):
-            settings.tool_choice = None
-        if hasattr(settings, "tools"):
-            settings.tools = None
diff --git a/python/semantic_kernel/connectors/ai/anthropic/services/utils.py b/python/semantic_kernel/connectors/ai/anthropic/services/utils.py
index 774d93615927..31acecb0468f 100644
--- a/python/semantic_kernel/connectors/ai/anthropic/services/utils.py
+++ b/python/semantic_kernel/connectors/ai/anthropic/services/utils.py
@@ -5,11 +5,15 @@
 from collections.abc import Callable, Mapping
 from typing import Any
 
+from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
+from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType
+from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.function_call_content import FunctionCallContent
 from semantic_kernel.contents.function_result_content import FunctionResultContent
 from semantic_kernel.contents.text_content import TextContent
 from semantic_kernel.contents.utils.author_role import AuthorRole
+from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
 
 logger: logging.Logger = logging.getLogger(__name__)
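For reference, a minimal sketch (again, not part of the patch) of the tool-definition shape that `kernel_function_metadata_to_function_call_format`, moved into utils.py in the hunks below, produces for the Anthropic `tools` parameter. The function name and schema values are illustrative, taken from a hypothetical `math.Add` kernel function:

```python
# Illustrative output of kernel_function_metadata_to_function_call_format
# for a hypothetical math.Add kernel function (schema values are examples only).
tool_definition = {
    "name": "math-Add",
    "description": "Returns the addition result of the values provided.",
    "input_schema": {
        "type": "object",
        "properties": {
            "input": {"type": "integer", "description": "the first number to add"},
            "amount": {"type": "integer", "description": "the second number to add"},
        },
        "required": ["input", "amount"],
    },
}

# update_settings_from_function_call_configuration then populates the settings:
# settings.tools = [tool_definition, ...]
# settings.tool_choice = {"type": "auto"}  # or {"type": "any"} when the choice is REQUIRED
```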
@@ -50,29 +54,32 @@ def _format_assistant_message(message: ChatMessageContent) -> dict[str, Any]: "type": "tool_use", "id": item.id or "", "name": item.name or "", - "input": item.arguments if isinstance(item.arguments, Mapping) else json.loads(item.arguments or ""), + "input": item.arguments + if isinstance(item.arguments, Mapping) + else json.loads(item.arguments) + if item.arguments + else {}, }) else: logger.warning( f"Unsupported item type in Assistant message while formatting chat history for Anthropic: {type(item)}" ) + formatted_message: dict[str, Any] = {"role": "assistant", "content": []} + + if message.content: + # Only include the text content if it is not empty. + # Otherwise, the Anthropic client will throw an error. + formatted_message["content"].append({ # type: ignore + "type": "text", + "text": message.content, + }) if tool_calls: - return { - "role": "assistant", - "content": [ - { - "type": "text", - "text": message.content, - }, - *tool_calls, - ], - } + # Only include the tool calls if there are any. + # Otherwise, the Anthropic client will throw an error. + formatted_message["content"].extend(tool_calls) # type: ignore - return { - "role": "assistant", - "content": message.content, - } + return formatted_message def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]: @@ -108,3 +115,40 @@ def _format_tool_message(message: ChatMessageContent) -> dict[str, Any]: AuthorRole.ASSISTANT: _format_assistant_message, AuthorRole.TOOL: _format_tool_message, } + + +def update_settings_from_function_call_configuration( + function_choice_configuration: FunctionCallChoiceConfiguration, + settings: PromptExecutionSettings, + type: FunctionChoiceType, +) -> None: + """Update the settings from a FunctionChoiceConfiguration.""" + if ( + function_choice_configuration.available_functions + and hasattr(settings, "tools") + and hasattr(settings, "tool_choice") + ): + settings.tools = [ + kernel_function_metadata_to_function_call_format(f) + for f in function_choice_configuration.available_functions + ] + + if ( + settings.function_choice_behavior and settings.function_choice_behavior.type_ == FunctionChoiceType.REQUIRED + ) or type == FunctionChoiceType.REQUIRED: + settings.tool_choice = {"type": "any"} + else: + settings.tool_choice = {"type": type.value} + + +def kernel_function_metadata_to_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: + """Convert the kernel function metadata to function calling format.""" + return { + "name": metadata.fully_qualified_name, + "description": metadata.description or "", + "input_schema": { + "type": "object", + "properties": {p.name: p.schema_data for p in metadata.parameters}, + "required": [p.name for p in metadata.parameters if p.is_required], + }, + } diff --git a/python/tests/integration/completions/chat_completion_test_base.py b/python/tests/integration/completions/chat_completion_test_base.py index 61152512ae11..a31882951c9b 100644 --- a/python/tests/integration/completions/chat_completion_test_base.py +++ b/python/tests/integration/completions/chat_completion_test_base.py @@ -66,9 +66,7 @@ onnx_setup: bool = is_service_setup_for_testing( ["ONNX_GEN_AI_CHAT_MODEL_FOLDER"], raise_if_not_set=False ) # Tests are optional for ONNX -anthropic_setup: bool = is_service_setup_for_testing( - ["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"], raise_if_not_set=False -) # We don't have an Anthropic deployment +anthropic_setup: bool = is_service_setup_for_testing(["ANTHROPIC_API_KEY", "ANTHROPIC_CHAT_MODEL_ID"]) # When 
testing Bedrock, after logging into AWS CLI this has been set, so we can use it to check if the service is setup bedrock_setup: bool = is_service_setup_for_testing(["AWS_DEFAULT_REGION"], raise_if_not_set=False) diff --git a/python/tests/unit/connectors/ai/anthropic/conftest.py b/python/tests/unit/connectors/ai/anthropic/conftest.py new file mode 100644 index 000000000000..dc7d54cae463 --- /dev/null +++ b/python/tests/unit/connectors/ai/anthropic/conftest.py @@ -0,0 +1,400 @@ +# Copyright (c) Microsoft. All rights reserved. +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock + +import pytest +from anthropic import AsyncAnthropic +from anthropic.lib.streaming import TextEvent +from anthropic.lib.streaming._types import InputJsonEvent +from anthropic.types import ( + ContentBlockStopEvent, + InputJSONDelta, + Message, + MessageDeltaUsage, + MessageStopEvent, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawMessageDeltaEvent, + RawMessageStartEvent, + TextBlock, + TextDelta, + ToolUseBlock, + Usage, +) +from anthropic.types.raw_message_delta_event import Delta + +from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import ( + AnthropicChatPromptExecutionSettings, +) +from semantic_kernel.contents.chat_message_content import ( + ChatMessageContent, + FunctionCallContent, + FunctionResultContent, + TextContent, +) +from semantic_kernel.contents.const import ContentTypes +from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent, StreamingTextContent +from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.contents.utils.finish_reason import FinishReason + + +@pytest.fixture +def mock_tool_calls_message() -> ChatMessageContent: + return ChatMessageContent( + ai_model_id="claude-3-opus-20240229", + metadata={}, + content_type="message", + role=AuthorRole.ASSISTANT, + name=None, + items=[ + TextContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="text", + text="", + encoding=None, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="test_function_call_content", + index=1, + name="math-Add", + function_name="Add", + plugin_name="math", + arguments={"input": 3, "amount": 3}, + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_parallel_tool_calls_message() -> ChatMessageContent: + return ChatMessageContent( + ai_model_id="claude-3-opus-20240229", + metadata={}, + content_type="message", + role=AuthorRole.ASSISTANT, + name=None, + items=[ + TextContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="text", + text="", + encoding=None, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="test_function_call_content_1", + index=1, + name="math-Add", + function_name="Add", + plugin_name="math", + arguments={"input": 3, "amount": 3}, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="test_function_call_content_2", + index=1, + name="math-Subtract", + function_name="Subtract", + plugin_name="math", + arguments={"input": 6, "amount": 3}, + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def 
mock_streaming_tool_calls_message() -> list: + stream_events = [ + RawMessageStartEvent( + message=Message( + id="test_message_id", + content=[], + model="claude-3-opus-20240229", + role="assistant", + stop_reason=None, + stop_sequence=None, + type="message", + usage=Usage(input_tokens=1720, output_tokens=2), + ), + type="message_start", + ), + RawContentBlockStartEvent(content_block=TextBlock(text="", type="text"), index=0, type="content_block_start"), + RawContentBlockDeltaEvent( + delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" + ), + TextEvent(type="text", text="", snapshot=""), + RawContentBlockDeltaEvent( + delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" + ), + TextEvent(type="text", text="", snapshot=""), + ContentBlockStopEvent( + index=0, type="content_block_stop", content_block=TextBlock(text="", type="text") + ), + RawContentBlockStartEvent( + content_block=ToolUseBlock(id="test_tool_use_message_id", input={}, name="math-Add", type="tool_use"), + index=1, + type="content_block_start", + ), + RawContentBlockDeltaEvent( + delta=InputJSONDelta(partial_json='{"input": 3, "amount": 3}', type="input_json_delta"), + index=1, + type="content_block_delta", + ), + InputJsonEvent(type="input_json", partial_json='{"input": 3, "amount": 3}', snapshot={"input": 3, "amount": 3}), + ContentBlockStopEvent( + index=1, + type="content_block_stop", + content_block=ToolUseBlock( + id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" + ), + ), + RawMessageDeltaEvent( + delta=Delta(stop_reason="tool_use", stop_sequence=None), + type="message_delta", + usage=MessageDeltaUsage(output_tokens=159), + ), + MessageStopEvent( + type="message_stop", + message=Message( + id="test_message_id", + content=[ + TextBlock(text="", type="text"), + ToolUseBlock( + id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" + ), + ], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=100, output_tokens=100), + ), + ), + ] + + async def async_generator(): + for event in stream_events: + yield event + + stream_mock = AsyncMock() + stream_mock.__aenter__.return_value = async_generator() + + return stream_mock + + +@pytest.fixture +def mock_tool_call_result_message() -> ChatMessageContent: + return ChatMessageContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="message", + role=AuthorRole.TOOL, + name=None, + items=[ + FunctionResultContent( + id="test_function_call_content", + result=6, + ) + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_parallel_tool_call_result_message() -> ChatMessageContent: + return ChatMessageContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type="message", + role=AuthorRole.TOOL, + name=None, + items=[ + FunctionResultContent( + id="test_function_call_content_1", + result=6, + ), + FunctionResultContent( + id="test_function_call_content_2", + result=3, + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_streaming_chat_message_content() -> StreamingChatMessageContent: + return StreamingChatMessageContent( + choice_index=0, + ai_model_id="claude-3-opus-20240229", + metadata={}, + role=AuthorRole.ASSISTANT, + name=None, + items=[ + StreamingTextContent( + inner_content=None, + ai_model_id=None, + metadata={}, + 
content_type="text", + text="", + encoding=None, + choice_index=0, + ), + FunctionCallContent( + inner_content=None, + ai_model_id=None, + metadata={}, + content_type=ContentTypes.FUNCTION_CALL_CONTENT, + id="tool_id", + index=0, + name="math-Add", + function_name="Add", + plugin_name="math", + arguments='{"input": 3, "amount": 3}', + ), + ], + encoding=None, + finish_reason=FinishReason.TOOL_CALLS, + ) + + +@pytest.fixture +def mock_settings() -> AnthropicChatPromptExecutionSettings: + return AnthropicChatPromptExecutionSettings() + + +@pytest.fixture +def mock_chat_message_response() -> Message: + return Message( + id="test_message_id", + content=[TextBlock(text="Hello, how are you?", type="text")], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=10, output_tokens=10), + ) + + +@pytest.fixture +def mock_streaming_message_response() -> AsyncGenerator: + raw_message_start_event = RawMessageStartEvent( + message=Message( + id="test_message_id", + content=[], + model="claude-3-opus-20240229", + role="assistant", + stop_reason=None, + stop_sequence=None, + type="message", + usage=Usage(input_tokens=41, output_tokens=3), + ), + type="message_start", + ) + + raw_content_block_start_event = RawContentBlockStartEvent( + content_block=TextBlock(text="", type="text"), + index=0, + type="content_block_start", + ) + + raw_content_block_delta_event = RawContentBlockDeltaEvent( + delta=TextDelta(text="Hello! It", type="text_delta"), + index=0, + type="content_block_delta", + ) + + text_event = TextEvent( + type="text", + text="Hello! It", + snapshot="Hello! It", + ) + + content_block_stop_event = ContentBlockStopEvent( + index=0, + type="content_block_stop", + content_block=TextBlock(text="Hello! It's nice to meet you.", type="text"), + ) + + raw_message_delta_event = RawMessageDeltaEvent( + delta=Delta(stop_reason="end_turn", stop_sequence=None), + type="message_delta", + usage=MessageDeltaUsage(output_tokens=84), + ) + + message_stop_event = MessageStopEvent( + type="message_stop", + message=Message( + id="test_message_stop_id", + content=[TextBlock(text="Hello! 
It's nice to meet you.", type="text")], + model="claude-3-opus-20240229", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=41, output_tokens=84), + ), + ) + + # Combine all mock events into a list + stream_events = [ + raw_message_start_event, + raw_content_block_start_event, + raw_content_block_delta_event, + text_event, + content_block_stop_event, + raw_message_delta_event, + message_stop_event, + ] + + async def async_generator(): + for event in stream_events: + yield event + + # Create an AsyncMock for the stream + stream_mock = AsyncMock() + stream_mock.__aenter__.return_value = async_generator() + + return stream_mock + + +@pytest.fixture +def mock_anthropic_client_completion(mock_chat_message_response: Message) -> AsyncAnthropic: + client = MagicMock(spec=AsyncAnthropic) + messages_mock = MagicMock() + messages_mock.create = AsyncMock(return_value=mock_chat_message_response) + client.messages = messages_mock + return client + + +@pytest.fixture +def mock_anthropic_client_completion_stream(mock_streaming_message_response: AsyncGenerator) -> AsyncAnthropic: + client = MagicMock(spec=AsyncAnthropic) + messages_mock = MagicMock() + messages_mock.stream.return_value = mock_streaming_message_response + client.messages = messages_mock + return client diff --git a/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py b/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py index d368dd901c4d..bff83bfe89d6 100644 --- a/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py +++ b/python/tests/unit/connectors/ai/anthropic/services/test_anthropic_chat_completion.py @@ -1,27 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. 
-from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import pytest from anthropic import AsyncAnthropic -from anthropic.lib.streaming import TextEvent -from anthropic.lib.streaming._types import InputJsonEvent -from anthropic.types import ( - ContentBlockStopEvent, - InputJSONDelta, - Message, - MessageDeltaUsage, - MessageStopEvent, - RawContentBlockDeltaEvent, - RawContentBlockStartEvent, - RawMessageDeltaEvent, - RawMessageStartEvent, - TextBlock, - TextDelta, - ToolUseBlock, - Usage, -) -from anthropic.types.raw_message_delta_event import Delta +from anthropic.types import Message from semantic_kernel.connectors.ai.anthropic.prompt_execution_settings.anthropic_prompt_execution_settings import ( AnthropicChatPromptExecutionSettings, @@ -33,406 +15,20 @@ OpenAIChatPromptExecutionSettings, ) from semantic_kernel.contents.chat_history import ChatHistory -from semantic_kernel.contents.chat_message_content import ( - ChatMessageContent, - FunctionCallContent, - FunctionResultContent, - TextContent, -) -from semantic_kernel.contents.const import ContentTypes -from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent, StreamingTextContent +from semantic_kernel.contents.chat_message_content import ChatMessageContent, FunctionCallContent, TextContent +from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole -from semantic_kernel.contents.utils.finish_reason import FinishReason from semantic_kernel.exceptions import ( ServiceInitializationError, ServiceInvalidExecutionSettingsError, ServiceResponseException, ) -from semantic_kernel.functions.function_result import FunctionResult +from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.functions.kernel_function_decorator import kernel_function -from semantic_kernel.functions.kernel_function_from_method import KernelFunctionMetadata -from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata from semantic_kernel.kernel import Kernel -@pytest.fixture -def mock_tool_calls_message() -> ChatMessageContent: - return ChatMessageContent( - inner_content=Message( - id="test_message_id", - content=[ - TextBlock(text="", type="text"), - ToolUseBlock( - id="test_tool_use_blocks", - input={"input": 3, "amount": 3}, - name="math-Add", - type="tool_use", - ), - ], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=1720, output_tokens=194), - ), - ai_model_id="claude-3-opus-20240229", - metadata={}, - content_type="message", - role=AuthorRole.ASSISTANT, - name=None, - items=[ - FunctionCallContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type=ContentTypes.FUNCTION_CALL_CONTENT, - id="test_function_call_content", - index=1, - name="math-Add", - function_name="Add", - plugin_name="math", - arguments={"input": 3, "amount": 3}, - ), - TextContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type="text", - text="", - encoding=None, - ), - ], - encoding=None, - finish_reason=FinishReason.TOOL_CALLS, - ) - - -@pytest.fixture -def mock_streaming_tool_calls_message() -> list: - stream_events = [ - RawMessageStartEvent( - message=Message( - id="test_message_id", - content=[], - 
model="claude-3-opus-20240229", - role="assistant", - stop_reason=None, - stop_sequence=None, - type="message", - usage=Usage(input_tokens=1720, output_tokens=2), - ), - type="message_start", - ), - RawContentBlockStartEvent(content_block=TextBlock(text="", type="text"), index=0, type="content_block_start"), - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - TextEvent(type="text", text="", snapshot=""), - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - TextEvent(type="text", text="", snapshot=""), - ContentBlockStopEvent( - index=0, type="content_block_stop", content_block=TextBlock(text="", type="text") - ), - RawContentBlockStartEvent( - content_block=ToolUseBlock(id="test_tool_use_message_id", input={}, name="math-Add", type="tool_use"), - index=1, - type="content_block_start", - ), - RawContentBlockDeltaEvent( - delta=InputJSONDelta(partial_json='{"input": 3, "amount": 3}', type="input_json_delta"), - index=1, - type="content_block_delta", - ), - InputJsonEvent(type="input_json", partial_json='{"input": 3, "amount": 3}', snapshot={"input": 3, "amount": 3}), - ContentBlockStopEvent( - index=1, - type="content_block_stop", - content_block=ToolUseBlock( - id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" - ), - ), - RawMessageDeltaEvent( - delta=Delta(stop_reason="tool_use", stop_sequence=None), - type="message_delta", - usage=MessageDeltaUsage(output_tokens=159), - ), - MessageStopEvent( - type="message_stop", - message=Message( - id="test_message_id", - content=[ - TextBlock(text="", type="text"), - ToolUseBlock( - id="test_tool_use_block_id", input={"input": 3, "amount": 3}, name="math-Add", type="tool_use" - ), - ], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="tool_use", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=100, output_tokens=100), - ), - ), - ] - - async def async_generator(): - for event in stream_events: - yield event - - stream_mock = AsyncMock() - stream_mock.__aenter__.return_value = async_generator() - - return stream_mock - - -@pytest.fixture -def mock_tool_call_result_message() -> ChatMessageContent: - return ChatMessageContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type="message", - role=AuthorRole.TOOL, - name=None, - items=[ - FunctionResultContent( - id="tool_01", - inner_content=FunctionResult( - function=KernelFunctionMetadata( - name="Add", - plugin_name="math", - description="Returns the Addition result of the values provided.", - parameters=[ - KernelParameterMetadata( - name="input", - description="the first number to add", - default_value=None, - type_="int", - is_required=True, - type_object=int, - schema_data={"type": "integer", "description": "the first number to add"}, - function_schema_include=True, - ), - KernelParameterMetadata( - name="amount", - description="the second number to add", - default_value=None, - type_="int", - is_required=True, - type_object=int, - schema_data={"type": "integer", "description": "the second number to add"}, - function_schema_include=True, - ), - ], - is_prompt=False, - is_asynchronous=False, - return_parameter=KernelParameterMetadata( - name="return", - description="the output is a number", - default_value=None, - type_="int", - is_required=True, - type_object=int, - schema_data={"type": "integer", "description": "the output is a number"}, - 
function_schema_include=True, - ), - additional_properties={}, - ), - value=6, - metadata={}, - ), - value=6, - ) - ], - encoding=None, - finish_reason=FinishReason.TOOL_CALLS, - ) - - -# mock StreamingChatMessageContent -@pytest.fixture -def mock_streaming_chat_message_content() -> StreamingChatMessageContent: - return StreamingChatMessageContent( - choice_index=0, - inner_content=[ - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - RawContentBlockDeltaEvent( - delta=TextDelta(text="", type="text_delta"), index=0, type="content_block_delta" - ), - ContentBlockStopEvent( - index=1, - type="content_block_stop", - content_block=ToolUseBlock( - id="tool_id", - input={"input": 3, "amount": 3}, - name="math-Add", - type="tool_use", - ), - ), - RawMessageDeltaEvent( - delta=Delta(stop_reason="tool_use", stop_sequence=None), - type="message_delta", - usage=MessageDeltaUsage(output_tokens=175), - ), - ], - ai_model_id="claude-3-opus-20240229", - metadata={}, - role=AuthorRole.ASSISTANT, - name=None, - items=[ - StreamingTextContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type="text", - text="", - encoding=None, - choice_index=0, - ), - FunctionCallContent( - inner_content=None, - ai_model_id=None, - metadata={}, - content_type=ContentTypes.FUNCTION_CALL_CONTENT, - id="tool_id", - index=0, - name="math-Add", - function_name="Add", - plugin_name="math", - arguments='{"input": 3, "amount": 3}', - ), - ], - encoding=None, - finish_reason=FinishReason.TOOL_CALLS, - ) - - -@pytest.fixture -def mock_settings() -> AnthropicChatPromptExecutionSettings: - return AnthropicChatPromptExecutionSettings() - - -@pytest.fixture -def mock_chat_message_response() -> Message: - return Message( - id="test_message_id", - content=[TextBlock(text="Hello, how are you?", type="text")], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=10, output_tokens=10), - ) - - -@pytest.fixture -def mock_streaming_message_response() -> AsyncGenerator: - raw_message_start_event = RawMessageStartEvent( - message=Message( - id="test_message_id", - content=[], - model="claude-3-opus-20240229", - role="assistant", - stop_reason=None, - stop_sequence=None, - type="message", - usage=Usage(input_tokens=41, output_tokens=3), - ), - type="message_start", - ) - - raw_content_block_start_event = RawContentBlockStartEvent( - content_block=TextBlock(text="", type="text"), - index=0, - type="content_block_start", - ) - - raw_content_block_delta_event = RawContentBlockDeltaEvent( - delta=TextDelta(text="Hello! It", type="text_delta"), - index=0, - type="content_block_delta", - ) - - text_event = TextEvent( - type="text", - text="Hello! It", - snapshot="Hello! It", - ) - - content_block_stop_event = ContentBlockStopEvent( - index=0, - type="content_block_stop", - content_block=TextBlock(text="Hello! It's nice to meet you.", type="text"), - ) - - raw_message_delta_event = RawMessageDeltaEvent( - delta=Delta(stop_reason="end_turn", stop_sequence=None), - type="message_delta", - usage=MessageDeltaUsage(output_tokens=84), - ) - - message_stop_event = MessageStopEvent( - type="message_stop", - message=Message( - id="test_message_stop_id", - content=[TextBlock(text="Hello! 
It's nice to meet you.", type="text")], - model="claude-3-opus-20240229", - role="assistant", - stop_reason="end_turn", - stop_sequence=None, - type="message", - usage=Usage(input_tokens=41, output_tokens=84), - ), - ) - - # Combine all mock events into a list - stream_events = [ - raw_message_start_event, - raw_content_block_start_event, - raw_content_block_delta_event, - text_event, - content_block_stop_event, - raw_message_delta_event, - message_stop_event, - ] - - async def async_generator(): - for event in stream_events: - yield event - - # Create an AsyncMock for the stream - stream_mock = AsyncMock() - stream_mock.__aenter__.return_value = async_generator() - - return stream_mock - - -@pytest.fixture -def mock_anthropic_client_completion(mock_chat_message_response: Message) -> AsyncAnthropic: - client = MagicMock(spec=AsyncAnthropic) - messages_mock = MagicMock() - messages_mock.create = AsyncMock(return_value=mock_chat_message_response) - client.messages = messages_mock - return client - - -@pytest.fixture -def mock_anthropic_client_completion_stream(mock_streaming_message_response: AsyncGenerator) -> AsyncAnthropic: - client = MagicMock(spec=AsyncAnthropic) - messages_mock = MagicMock() - messages_mock.stream.return_value = mock_streaming_message_response - client.messages = messages_mock - return client - - async def test_complete_chat_contents( kernel: Kernel, mock_settings: AnthropicChatPromptExecutionSettings, @@ -753,7 +349,7 @@ async def test_prepare_chat_history_for_request_with_system_message(mock_anthrop assert system_message_content == "System message" assert remaining_messages == [ {"role": AuthorRole.USER, "content": "User message"}, - {"role": AuthorRole.ASSISTANT, "content": "Assistant message"}, + {"role": AuthorRole.ASSISTANT, "content": [{"type": "text", "text": "Assistant message"}]}, ] assert not any(msg["role"] == AuthorRole.SYSTEM for msg in remaining_messages) @@ -780,35 +376,121 @@ async def test_prepare_chat_history_for_request_with_tool_message( ) assert system_message_content is None - assert len(remaining_messages) == 3 + assert remaining_messages == [ + {"role": AuthorRole.USER, "content": "What is 3+3?"}, + { + "role": AuthorRole.ASSISTANT, + "content": [ + {"type": "text", "text": mock_tool_calls_message.items[0].text}, + { + "type": "tool_use", + "id": mock_tool_calls_message.items[1].id, + "name": mock_tool_calls_message.items[1].name, + "input": mock_tool_calls_message.items[1].arguments, + }, + ], + }, + { + "role": AuthorRole.USER, + "content": [ + { + "type": "tool_result", + "tool_use_id": mock_tool_call_result_message.items[0].id, + "content": str(mock_tool_call_result_message.items[0].result), + } + ], + }, + ] -async def test_prepare_chat_history_for_request_with_tool_message_streaming( +async def test_prepare_chat_history_for_request_with_parallel_tool_message( + mock_anthropic_client_completion_stream: MagicMock, + mock_parallel_tool_calls_message: ChatMessageContent, + mock_parallel_tool_call_result_message: ChatMessageContent, +): + chat_history = ChatHistory() + chat_history.add_user_message("What is 3+3?") + chat_history.add_message(mock_parallel_tool_calls_message) + chat_history.add_message(mock_parallel_tool_call_result_message) + + chat_completion_client = AnthropicChatCompletion( + ai_model_id="test_model_id", + service_id="test", + api_key="", + async_client=mock_anthropic_client_completion_stream, + ) + + remaining_messages, system_message_content = chat_completion_client._prepare_chat_history_for_request( + chat_history, 
role_key="role", content_key="content" + ) + + assert system_message_content is None + assert remaining_messages == [ + {"role": AuthorRole.USER, "content": "What is 3+3?"}, + { + "role": AuthorRole.ASSISTANT, + "content": [ + {"type": "text", "text": mock_parallel_tool_calls_message.items[0].text}, + *[ + { + "type": "tool_use", + "id": function_call_content.id, + "name": function_call_content.name, + "input": function_call_content.arguments, + } + for function_call_content in mock_parallel_tool_calls_message.items[1:] + ], + ], + }, + { + "role": AuthorRole.USER, + "content": [ + { + "type": "tool_result", + "tool_use_id": function_result_content.id, + "content": str(function_result_content.result), + } + for function_result_content in mock_parallel_tool_call_result_message.items + ], + }, + ] + + +async def test_prepare_chat_history_for_request_with_tool_message_right_after_user_message( mock_anthropic_client_completion_stream: MagicMock, - mock_streaming_chat_message_content: StreamingChatMessageContent, mock_tool_call_result_message: ChatMessageContent, ): chat_history = ChatHistory() chat_history.add_user_message("What is 3+3?") - chat_history.add_message(mock_streaming_chat_message_content) chat_history.add_message(mock_tool_call_result_message) - chat_completion = AnthropicChatCompletion( + chat_completion_client = AnthropicChatCompletion( ai_model_id="test_model_id", service_id="test", api_key="", async_client=mock_anthropic_client_completion_stream, ) - remaining_messages, system_message_content = chat_completion._prepare_chat_history_for_request( - chat_history, - role_key="role", - content_key="content", - stream=True, + with pytest.raises(ServiceInvalidRequestError, match="Tool message found after a user or system message."): + chat_completion_client._prepare_chat_history_for_request(chat_history, role_key="role", content_key="content") + + +async def test_prepare_chat_history_for_request_with_tool_message_as_the_first_message( + mock_anthropic_client_completion_stream: MagicMock, + mock_tool_call_result_message: ChatMessageContent, +): + chat_history = ChatHistory() + chat_history.add_message(mock_tool_call_result_message) + + chat_completion_client = AnthropicChatCompletion( + ai_model_id="test_model_id", + service_id="test", + api_key="", + async_client=mock_anthropic_client_completion_stream, ) - assert system_message_content is None - assert len(remaining_messages) == 3 + with pytest.raises(ServiceInvalidRequestError, match="Tool message found without a preceding message."): + chat_completion_client._prepare_chat_history_for_request(chat_history, role_key="role", content_key="content") async def test_send_chat_stream_request_tool_calls(