diff --git a/python/samples/concepts/filtering/retry_with_filters.py b/python/samples/concepts/filtering/retry_with_filters.py new file mode 100644 index 000000000000..92131ad1d292 --- /dev/null +++ b/python/samples/concepts/filtering/retry_with_filters.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import logging +from collections.abc import Callable, Coroutine +from typing import Any + +from samples.concepts.setup.chat_completion_services import Services, get_chat_completion_service_and_request_settings +from semantic_kernel import Kernel +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior +from semantic_kernel.contents import ChatHistory +from semantic_kernel.filters import FunctionInvocationContext +from semantic_kernel.filters.filter_types import FilterTypes +from semantic_kernel.functions import kernel_function + +# This sample shows how to use a filter for retrying a function invocation. +# This sample requires the following components: +# - a ChatCompletionService: This component is responsible for generating responses to user messages. +# - a ChatHistory: This component is responsible for keeping track of the chat history. +# - a Kernel: This component is responsible for managing plugins and filters. +# - a mock plugin: This plugin contains a function that simulates a call to an external service. +# - a filter: This filter retries the function invocation if it fails. + +logger = logging.getLogger(__name__) + +# The maximum number of retries for the filter +MAX_RETRIES = 3 + + +class WeatherPlugin: + MAX_FAILURES = 2 + + def __init__(self): + self._invocation_count = 0 + + @kernel_function(name="GetWeather", description="Get the weather of the day at the current location.") + def get_wather(self) -> str: + """Get the weather of the day at the current location. + + Simulates a call to an external service to get the weather. + This function is designed to fail a certain number of times before succeeding. + """ + if self._invocation_count < self.MAX_FAILURES: + self._invocation_count += 1 + print(f"Number of attempts: {self._invocation_count}") + raise Exception("Failed to get the weather") + + return "Sunny" + + +async def retry_filter( + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Coroutine[Any, Any, None]], +) -> None: + """A filter that retries the function invocation if it fails. + + The filter uses a binary exponential backoff strategy to retry the function invocation. + """ + for i in range(MAX_RETRIES): + try: + await next(context) + return + except Exception as e: + logger.warning(f"Failed to execute the function: {e}") + backoff = 2**i + logger.info(f"Sleeping for {backoff} seconds before retrying") + + +async def main() -> None: + kernel = Kernel() + # Register the plugin to the kernel + kernel.add_plugin(WeatherPlugin(), plugin_name="WeatherPlugin") + # Add the filter to the kernel as a function invocation filter + # A function invocation filter is called during when the kernel executes a function + kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, retry_filter) + + chat_history = ChatHistory() + chat_history.add_user_message("What is the weather today?") + + chat_completion_service, request_settings = get_chat_completion_service_and_request_settings(Services.OPENAI) + # Need to set the function choice behavior to auto such that the + # service will automatically invoke the function in the response. + request_settings.function_choice_behavior = FunctionChoiceBehavior.Auto() + + response = await chat_completion_service.get_chat_message_content( + chat_history=chat_history, + settings=request_settings, + # Need to pass the kernel to the chat completion service so that it has access to the plugins and filters + kernel=kernel, + ) + + print(response) + + # Sample output: + # Number of attempts: 1 + # Failed to execute the function: Failed to get the weather + # Number of attempts: 2 + # Failed to execute the function: Failed to get the weather + # The weather today is Sunny + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/semantic_kernel/kernel.py b/python/semantic_kernel/kernel.py index 3fd5d33dcc1d..5ddef255f355 100644 --- a/python/semantic_kernel/kernel.py +++ b/python/semantic_kernel/kernel.py @@ -65,8 +65,6 @@ class Kernel(KernelFilterExtension, KernelFunctionExtension, KernelServicesExten plugins: A dict with the plugins registered with the Kernel, from KernelFunctionExtension. services: A dict with the services registered with the Kernel, from KernelServicesExtension. ai_service_selector: The AI service selector to be used by the kernel, from KernelServicesExtension. - retry_mechanism: The retry mechanism to be used by the kernel, from KernelReliabilityExtension. - """ def __init__( @@ -84,12 +82,8 @@ def __init__( plugins: The plugins to be used by the kernel, will be rewritten to a dict with plugin name as key services: The services to be used by the kernel, will be rewritten to a dict with service_id as key ai_service_selector: The AI service selector to be used by the kernel, - default is based on order of execution settings. - **kwargs: Additional fields to be passed to the Kernel model, - these are limited to retry_mechanism and function_invoking_handlers - and function_invoked_handlers, the best way to add function_invoking_handlers - and function_invoked_handlers is to use the add_function_invoking_handler - and add_function_invoked_handler methods. + default is based on order of execution settings. + **kwargs: Additional fields to be passed to the Kernel model, these are limited to filters. """ args = { "services": services, diff --git a/python/semantic_kernel/reliability/kernel_reliability_extension.py b/python/semantic_kernel/reliability/kernel_reliability_extension.py index 82a020cfdeff..9c89766c47db 100644 --- a/python/semantic_kernel/reliability/kernel_reliability_extension.py +++ b/python/semantic_kernel/reliability/kernel_reliability_extension.py @@ -4,6 +4,7 @@ from abc import ABC from pydantic import Field +from typing_extensions import deprecated from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.reliability.pass_through_without_retry import PassThroughWithoutRetry @@ -15,4 +16,7 @@ class KernelReliabilityExtension(KernelBaseModel, ABC): """Kernel reliability extension.""" - retry_mechanism: RetryMechanismBase = Field(default_factory=PassThroughWithoutRetry) + retry_mechanism: RetryMechanismBase = Field( + default_factory=PassThroughWithoutRetry, + deprecated=deprecated("retry_mechanism is deprecated; This property doesn't have any effect on the kernel."), + ) diff --git a/python/tests/unit/kernel/test_kernel.py b/python/tests/unit/kernel/test_kernel.py index 4180994792dd..38b1608d150f 100644 --- a/python/tests/unit/kernel/test_kernel.py +++ b/python/tests/unit/kernel/test_kernel.py @@ -18,10 +18,7 @@ from semantic_kernel.contents import ChatMessageContent from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.function_call_content import FunctionCallContent -from semantic_kernel.exceptions import ( - KernelFunctionAlreadyExistsError, - KernelServiceNotFoundError, -) +from semantic_kernel.exceptions import KernelFunctionAlreadyExistsError, KernelServiceNotFoundError from semantic_kernel.exceptions.content_exceptions import FunctionCallInvalidArgumentsException from semantic_kernel.exceptions.kernel_exceptions import ( KernelFunctionNotFoundError,