From 3bab8486f5ce6ff48aeec456f46cf80c4e1f64a0 Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Thu, 19 Dec 2024 09:47:12 +0100 Subject: [PATCH] Python: improved text_search folder testing (#9984) ### Motivation and Context Added unit tests to up the test coverage across the data folder. Also redoes the exceptions for everything vector. This does introduce some breaking changes in the exceptions returned by the different methods of the vector stores; they are still marked experimental, and this will go a long way in moving towards release. Closes #9485 ### Description ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone :smile: --- python/.vscode/settings.json | 7 +- .../azure_ai_search_collection.py | 21 +- .../azure_ai_search/azure_ai_search_store.py | 4 +- .../azure_cosmos_db_no_sql_base.py | 18 +- .../azure_cosmos_db_no_sql_collection.py | 60 +-- .../azure_cosmos_db_no_sql_store.py | 4 +- .../memory/azure_cosmos_db/utils.py | 46 +- .../memory/postgres/postgres_collection.py | 224 +++++----- .../memory/qdrant/qdrant_collection.py | 14 +- .../connectors/memory/qdrant/qdrant_store.py | 6 +- .../memory/redis/redis_collection.py | 13 +- .../connectors/memory/redis/redis_store.py | 4 +- .../connectors/memory/redis/utils.py | 2 +- .../connectors/memory/weaviate/utils.py | 2 +- .../memory/weaviate/weaviate_collection.py | 84 ++-- .../memory/weaviate/weaviate_store.py | 12 +- .../vector_store_model_decorator.py | 116 ++--- .../vector_store_model_definition.py | 20 +- .../vector_store_model_protocols.py | 49 +-- .../vector_store_record_fields.py | 2 +- .../vector_store_record_utils.py | 2 +- python/semantic_kernel/data/search_filter.py | 4 +- .../data/text_search/text_search.py | 2 +- .../semantic_kernel/data/text_search/utils.py | 28 +- .../text_search/vector_store_text_search.py | 6 +- .../data/vector_search/vector_search.py | 22 +- .../vector_search/vector_search_filter.py | 4 +- .../data/vector_search/vector_text_search.py | 20 +- .../vector_search/vectorizable_text_search.py | 20 +- .../data/vector_search/vectorized_search.py | 20 +- .../data/vector_storage/vector_store.py | 2 +- .../vector_store_record_collection.py | 361 ++++++++------- python/semantic_kernel/exceptions/__init__.py | 1 + .../exceptions/memory_connector_exceptions.py | 35 -- .../exceptions/search_exceptions.py | 28 -- .../exceptions/vector_store_exceptions.py | 78 ++++ python/tests/conftest.py | 38 -- .../completions/chat_completion_test_base.py | 6 +- ...t_chat_completion_with_function_calling.py | 7 +- ...completion_with_image_input_text_output.py | 11 +- .../completions/test_chat_completions.py | 21 +- .../test_conversation_summary_plugin.py | 2 +- .../completions/test_text_completion.py | 29 +- .../azure_ai_search/test_azure_ai_search.py | 20 +- .../test_azure_cosmos_db_no_sql_collection.py | 19 +- .../test_azure_cosmos_db_no_sql_store.py | 8 +- .../connectors/memory/qdrant/test_qdrant.py | 18 +- .../memory/redis/test_redis_store.py | 14 +- .../weaviate/test_weaviate_collection.py | 16 +- .../memory/weaviate/test_weaviate_store.py | 5 +- python/tests/unit/data/conftest.py | 73 ++-
python/tests/unit/data/test_text_search.py | 97 +++- .../unit/data/test_vector_search_base.py | 41 ++ .../unit/data/test_vector_search_mixins.py | 45 ++ .../data/test_vector_store_model_decorator.py | 8 +- .../test_vector_store_record_collection.py | 414 +++++++++--------- .../test_vector_store_record_definition.py | 62 +++ .../data/test_vector_store_record_utils.py | 11 + .../data/test_vector_store_text_search.py | 88 +++- python/tests/utils.py | 5 +- 60 files changed, 1439 insertions(+), 960 deletions(-) create mode 100644 python/semantic_kernel/exceptions/vector_store_exceptions.py create mode 100644 python/tests/unit/data/test_vector_search_base.py create mode 100644 python/tests/unit/data/test_vector_search_mixins.py diff --git a/python/.vscode/settings.json b/python/.vscode/settings.json index 93b973d6fc73..9d465bfc9a7a 100644 --- a/python/.vscode/settings.json +++ b/python/.vscode/settings.json @@ -13,11 +13,6 @@ "locale": "en-US" } ], - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, "[python]": { "editor.codeActionsOnSave": { "source.organizeImports": "explicit", @@ -29,7 +24,7 @@ "notebook.formatOnSave.enabled": true, "notebook.codeActionsOnSave": { "source.fixAll": true, - "source.organizeImports": true + "source.organizeImports": false }, "python.analysis.extraPaths": [ "${workspaceFolder}/samples/learn_resources" diff --git a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py index 1eec8028a1b8..d103997db200 100644 --- a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py +++ b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py @@ -34,7 +34,11 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions import MemoryConnectorException, MemoryConnectorInitializationError +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorStoreInitializationException, + VectorStoreOperationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class logger: logging.Logger = logging.getLogger(__name__) @@ -91,7 +95,7 @@ def __init__( if not collection_name: collection_name = search_client._index_name elif search_client._index_name != collection_name: - raise MemoryConnectorInitializationError( + raise VectorStoreInitializationException( "Search client and search index client have different index names." 
) super().__init__( @@ -107,7 +111,7 @@ def __init__( if search_index_client: if not collection_name: - raise MemoryConnectorInitializationError("Collection name is required.") + raise VectorStoreInitializationException("Collection name is required.") super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -133,14 +137,14 @@ def __init__( index_name=collection_name, ) except ValidationError as exc: - raise MemoryConnectorInitializationError("Failed to create Azure Cognitive Search settings.") from exc + raise VectorStoreInitializationException("Failed to create Azure Cognitive Search settings.") from exc search_index_client = get_search_index_client( azure_ai_search_settings=azure_ai_search_settings, azure_credential=kwargs.get("azure_credentials"), token_credential=kwargs.get("token_credentials"), ) if not azure_ai_search_settings.index_name: - raise MemoryConnectorInitializationError("Collection name is required.") + raise VectorStoreInitializationException("Collection name is required.") super().__init__( data_model_type=data_model_type, @@ -211,7 +215,7 @@ async def create_collection(self, **kwargs) -> None: if isinstance(index, SearchIndex): await self.search_index_client.create_index(index=index, **kwargs) return - raise MemoryConnectorException("Invalid index type supplied.") + raise VectorStoreOperationException("Invalid index type supplied, should be a SearchIndex object.") await self.search_index_client.create_index( index=data_model_definition_to_azure_ai_search_index( collection_name=self.collection_name, @@ -279,7 +283,10 @@ async def _inner_search( for name, field in self.data_model_definition.fields.items() if not isinstance(field, VectorStoreRecordVectorField) ] - raw_results = await self.search_client.search(**search_args) + try: + raw_results = await self.search_client.search(**search_args) + except Exception as exc: + raise VectorSearchExecutionException("Failed to search the collection.") from exc return KernelSearchResults( results=self._get_vector_search_results_from_results(raw_results, options), total_count=await raw_results.get_count() if options.include_total_count else None, diff --git a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py index fcaecf62f97f..4c4693abb6d7 100644 --- a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py +++ b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py @@ -19,7 +19,7 @@ from semantic_kernel.connectors.memory.azure_ai_search.utils import get_search_client, get_search_index_client from semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage import VectorStore -from semantic_kernel.exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -78,7 +78,7 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as exc: - raise MemoryConnectorInitializationError("Failed to create Azure AI Search settings.") from exc + raise VectorStoreInitializationException("Failed to create Azure AI Search settings.") from exc search_index_client = get_search_index_client( azure_ai_search_settings=azure_ai_search_settings, azure_credential=azure_credentials, diff --git 
a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py index afe8fa75f140..c9a49ba4e546 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py @@ -6,9 +6,9 @@ from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_settings import AzureCosmosDBNoSQLSettings from semantic_kernel.connectors.memory.azure_cosmos_db.utils import CosmosClientWrapper -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorInitializationError, - MemoryConnectorResourceNotFound, +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, + VectorStoreOperationException, ) from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.utils.authentication.async_default_azure_credential_wrapper import ( @@ -63,10 +63,10 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as e: - raise MemoryConnectorInitializationError("Failed to validate Azure Cosmos DB NoSQL settings.") from e + raise VectorStoreInitializationException("Failed to validate Azure Cosmos DB NoSQL settings.") from e if cosmos_db_nosql_settings.database_name is None: - raise MemoryConnectorInitializationError("The name of the Azure Cosmos DB NoSQL database is missing.") + raise VectorStoreInitializationException("The name of the Azure Cosmos DB NoSQL database is missing.") if cosmos_client is None: if cosmos_db_nosql_settings.key is not None: @@ -94,7 +94,7 @@ async def _does_database_exist(self) -> bool: except CosmosResourceNotFoundError: return False except Exception as e: - raise MemoryConnectorResourceNotFound( + raise VectorStoreOperationException( f"Failed to check if database '{self.database_name}' exists, with message {e}" ) from e @@ -106,9 +106,9 @@ async def _get_database_proxy(self, **kwargs) -> DatabaseProxy: if self.create_database: return await self.cosmos_client.create_database(self.database_name, **kwargs) - raise MemoryConnectorResourceNotFound(f"Database '{self.database_name}' does not exist.") + raise VectorStoreOperationException(f"Database '{self.database_name}' does not exist.") except Exception as e: - raise MemoryConnectorResourceNotFound(f"Failed to get database proxy for '{id}'.") from e + raise VectorStoreOperationException(f"Failed to get database proxy for '{id}'.") from e async def _get_container_proxy(self, container_name: str, **kwargs) -> ContainerProxy: """Gets the container proxy.""" @@ -116,4 +116,4 @@ async def _get_container_proxy(self, container_name: str, **kwargs) -> Container database_proxy = await self._get_database_proxy(**kwargs) return database_proxy.get_container_client(container_name) except Exception as e: - raise MemoryConnectorResourceNotFound(f"Failed to get container proxy for '{container_name}'.") from e + raise VectorStoreOperationException(f"Failed to get container proxy for '{container_name}'.") from e diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py index 00a6813c1802..aa8633ecb54e 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py +++ 
b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py @@ -11,7 +11,7 @@ from typing_extensions import override # pragma: no cover from azure.cosmos.aio import CosmosClient -from azure.cosmos.exceptions import CosmosBatchOperationError, CosmosHttpResponseError, CosmosResourceNotFoundError +from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.partition_key import PartitionKey from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_base import AzureCosmosDBNoSQLBase @@ -36,10 +36,10 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorResourceNotFound, +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, VectorStoreModelDeserializationException, + VectorStoreOperationException, ) from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class @@ -120,15 +120,8 @@ async def _inner_upsert( **kwargs: Any, ) -> Sequence[TKey]: container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) - try: - results = await asyncio.gather(*(container_proxy.upsert_item(record) for record in records)) - return [result[COSMOS_ITEM_ID_PROPERTY_NAME] for result in results] - except CosmosResourceNotFoundError as e: - raise MemoryConnectorResourceNotFound( - "The collection does not exist yet. Create the collection first." - ) from e - except (CosmosBatchOperationError, CosmosHttpResponseError) as e: - raise MemoryConnectorException("Failed to upsert items.") from e + results = await asyncio.gather(*(container_proxy.upsert_item(record) for record in records)) + return [result[COSMOS_ITEM_ID_PROPERTY_NAME] for result in results] @override async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any] | None: @@ -140,14 +133,7 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any parameters: list[dict[str, Any]] = [{"name": f"@id{i}", "value": get_key(key)} for i, key in enumerate(keys)] container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) - try: - return [item async for item in container_proxy.query_items(query=query, parameters=parameters)] - except CosmosResourceNotFoundError as e: - raise MemoryConnectorResourceNotFound( - "The collection does not exist yet. Create the collection first." 
- ) from e - except Exception as e: - raise MemoryConnectorException("Failed to read items.") from e + return [item async for item in container_proxy.query_items(query=query, parameters=parameters)] @override async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: @@ -158,7 +144,7 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: ) exceptions = [result for result in results if isinstance(result, Exception)] if exceptions: - raise MemoryConnectorException("Failed to delete item(s).", exceptions) + raise VectorStoreOperationException("Failed to delete item(s).", exceptions) @override async def _inner_search( @@ -177,12 +163,12 @@ async def _inner_search( query = self._build_vector_query(options) params.append({"name": "@vector", "value": vector}) else: - raise ValueError("Either search_text or vector must be provided.") + raise VectorSearchExecutionException("Either search_text or vector must be provided.") container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) try: results = container_proxy.query_items(query, parameters=params) - except Exception as e: - raise MemoryConnectorException("Failed to search items.") from e + except Exception as exc: + raise VectorSearchExecutionException("Failed to search items.") from exc return KernelSearchResults( results=self._get_vector_search_results_from_results(results, options), total_count=None, @@ -286,26 +272,26 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A @override async def create_collection(self, **kwargs) -> None: + indexing_policy = kwargs.pop("indexing_policy", create_default_indexing_policy(self.data_model_definition)) + vector_embedding_policy = kwargs.pop( + "vector_embedding_policy", create_default_vector_embedding_policy(self.data_model_definition) + ) + database_proxy = await self._get_database_proxy(**kwargs) try: - database_proxy = await self._get_database_proxy(**kwargs) await database_proxy.create_container_if_not_exists( id=self.collection_name, partition_key=self.partition_key, - indexing_policy=kwargs.pop( - "indexing_policy", create_default_indexing_policy(self.data_model_definition) - ), - vector_embedding_policy=kwargs.pop( - "vector_embedding_policy", create_default_vector_embedding_policy(self.data_model_definition) - ), + indexing_policy=indexing_policy, + vector_embedding_policy=vector_embedding_policy, **kwargs, ) except CosmosHttpResponseError as e: - raise MemoryConnectorException("Failed to create container.") from e + raise VectorStoreOperationException("Failed to create container.") from e @override async def does_collection_exist(self, **kwargs) -> bool: + container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) try: - container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) await container_proxy.read(**kwargs) return True except CosmosHttpResponseError: @@ -313,11 +299,11 @@ async def does_collection_exist(self, **kwargs) -> bool: @override async def delete_collection(self, **kwargs) -> None: + database_proxy = await self._get_database_proxy(**kwargs) try: - database_proxy = await self._get_database_proxy(**kwargs) await database_proxy.delete_container(self.collection_name) - except CosmosHttpResponseError as e: - raise MemoryConnectorException("Container could not be deleted.") from e + except Exception as e: + raise VectorStoreOperationException("Container could not be deleted.") from e @override async def __aexit__(self, exc_type, exc_value, traceback) 
-> None: diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py index 9cf7b0629d30..45ca18b58c55 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py @@ -18,7 +18,7 @@ from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage.vector_store import VectorStore from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException +from semantic_kernel.exceptions import VectorStoreOperationException from semantic_kernel.utils.experimental_decorator import experimental_class TModel = TypeVar("TModel") @@ -93,7 +93,7 @@ async def list_collection_names(self, **kwargs) -> Sequence[str]: containers = database.list_containers() return [container["id"] async for container in containers] except Exception as e: - raise MemoryConnectorException("Failed to list collection names.") from e + raise VectorStoreOperationException("Failed to list collection names.") from e async def __aexit__(self, exc_type, exc_value, traceback) -> None: """Exit the context manager.""" diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py index 7955153a8135..40115cb647ef 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py @@ -20,7 +20,7 @@ VectorStoreRecordDataField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException def to_vector_index_policy_type(index_kind: IndexKind | None) -> str: @@ -34,6 +34,9 @@ def to_vector_index_policy_type(index_kind: IndexKind | None) -> str: Returns: str: The vector index policy type. + + Raises: + VectorStoreModelException: If the index kind is not supported by Azure Cosmos DB NoSQL container. """ if index_kind is None: # Use IndexKind.FLAT as the default index kind. @@ -46,7 +49,18 @@ def to_vector_index_policy_type(index_kind: IndexKind | None) -> str: def to_distance_function(distance_function: DistanceFunction | None) -> str: - """Converts the distance function to the distance function for Azure Cosmos DB NoSQL container.""" + """Converts the distance function to the distance function for Azure Cosmos DB NoSQL container. + + Args: + distance_function: The distance function. + + Returns: + str: The distance function as defined by Azure Cosmos DB NoSQL container. + + Raises: + VectorStoreModelException: If the distance function is not supported by Azure Cosmos DB NoSQL container. + + """ if distance_function is None: # Use DistanceFunction.COSINE_SIMILARITY as the default distance function. 
return DISTANCE_FUNCTION_MAPPING[DistanceFunction.COSINE_SIMILARITY] @@ -60,7 +74,18 @@ def to_distance_function(distance_function: DistanceFunction | None) -> str: def to_datatype(property_type: str | None) -> str: - """Converts the property type to the data type for Azure Cosmos DB NoSQL container.""" + """Converts the property type to the data type for Azure Cosmos DB NoSQL container. + + Args: + property_type: The property type. + + Returns: + str: The data type as defined by Azure Cosmos DB NoSQL container. + + Raises: + VectorStoreModelException: If the property type is not supported by Azure Cosmos DB NoSQL container + + """ if property_type is None: # Use the default data type. return DATATYPES_MAPPING["default"] @@ -83,8 +108,11 @@ def create_default_indexing_policy(data_model_definition: VectorStoreRecordDefin Returns: dict[str, Any]: The indexing policy. + + Raises: + VectorStoreModelException: If the field is not full text searchable and not filterable. """ - indexing_policy = { + indexing_policy: dict[str, Any] = { "automatic": True, "includedPaths": [ { @@ -103,15 +131,15 @@ def create_default_indexing_policy(data_model_definition: VectorStoreRecordDefin if isinstance(field, VectorStoreRecordDataField) and ( not field.is_full_text_searchable and not field.is_filterable ): - indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) # type: ignore + indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) if isinstance(field, VectorStoreRecordVectorField): - indexing_policy["vectorIndexes"].append({ # type: ignore + indexing_policy["vectorIndexes"].append({ "path": f'/"{field.name}"', "type": to_vector_index_policy_type(field.index_kind), }) # Exclude the vector field from the index for performance optimization. - indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) # type: ignore + indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) return indexing_policy @@ -126,6 +154,10 @@ def create_default_vector_embedding_policy(data_model_definition: VectorStoreRec Returns: dict[str, Any]: The vector embedding policy. + + Raises: + VectorStoreModelException: If the datatype or distance function is not supported by Azure Cosmos DB NoSQL. + """ vector_embedding_policy: dict[str, Any] = {"vectorEmbeddings": []} diff --git a/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py b/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py index a833efcc855a..6de863646dc3 100644 --- a/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py +++ b/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py @@ -11,7 +11,6 @@ from typing_extensions import override # pragma: no cover from psycopg import sql -from psycopg.errors import DatabaseError from psycopg_pool import AsyncConnectionPool from pydantic import PrivateAttr @@ -30,9 +29,9 @@ VectorStoreRecordVectorField, ) from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, +from semantic_kernel.exceptions import ( VectorStoreModelValidationError, + VectorStoreOperationException, ) from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class @@ -138,46 +137,42 @@ async def _inner_upsert( The keys of the upserted records. 
""" if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) keys = [] - try: - async with ( - self.connection_pool.connection() as conn, - conn.transaction(), - conn.cursor() as cur, - ): - # Split the records into batches - max_rows_per_transaction = self._settings.max_rows_per_transaction - for i in range(0, len(records), max_rows_per_transaction): - record_batch = records[i : i + max_rows_per_transaction] - - fields = list(self.data_model_definition.fields.items()) - - row_values = [convert_dict_to_row(record, fields) for record in record_batch] - - # Execute the INSERT statement for each batch - await cur.executemany( - sql.SQL("INSERT INTO {}.{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}").format( - sql.Identifier(self.db_schema), - sql.Identifier(self.collection_name), - sql.SQL(", ").join(sql.Identifier(field.name) for _, field in fields), - sql.SQL(", ").join(sql.Placeholder() * len(fields)), - sql.Identifier(self.data_model_definition.key_field.name), - sql.SQL(", ").join( - sql.SQL("{field} = EXCLUDED.{field}").format(field=sql.Identifier(field.name)) - for _, field in fields - if field.name != self.data_model_definition.key_field.name - ), + async with ( + self.connection_pool.connection() as conn, + conn.transaction(), + conn.cursor() as cur, + ): + # Split the records into batches + max_rows_per_transaction = self._settings.max_rows_per_transaction + for i in range(0, len(records), max_rows_per_transaction): + record_batch = records[i : i + max_rows_per_transaction] + + fields = list(self.data_model_definition.fields.items()) + + row_values = [convert_dict_to_row(record, fields) for record in record_batch] + + # Execute the INSERT statement for each batch + await cur.executemany( + sql.SQL("INSERT INTO {}.{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}").format( + sql.Identifier(self.db_schema), + sql.Identifier(self.collection_name), + sql.SQL(", ").join(sql.Identifier(field.name) for _, field in fields), + sql.SQL(", ").join(sql.Placeholder() * len(fields)), + sql.Identifier(self.data_model_definition.key_field.name), + sql.SQL(", ").join( + sql.SQL("{field} = EXCLUDED.{field}").format(field=sql.Identifier(field.name)) + for _, field in fields + if field.name != self.data_model_definition.key_field.name ), - row_values, - ) - keys.extend(record.get(self.data_model_definition.key_field.name) for record in record_batch) - - except DatabaseError as error: - # Rollback happens automatically if an exception occurs within the transaction block - raise MemoryConnectorException(f"Error upserting records: {error}") from error - + ), + row_values, + ) + keys.extend(record.get(self.data_model_definition.key_field.name) for record in record_batch) return keys @override @@ -192,27 +187,25 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[dic The records from the store, not deserialized. """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." 
+ ) fields = [(field.name, field) for field in self.data_model_definition.fields.values()] - try: - async with self.connection_pool.connection() as conn, conn.cursor() as cur: - await cur.execute( - sql.SQL("SELECT {} FROM {}.{} WHERE {} IN ({})").format( - sql.SQL(", ").join(sql.Identifier(name) for (name, _) in fields), - sql.Identifier(self.db_schema), - sql.Identifier(self.collection_name), - sql.Identifier(self.data_model_definition.key_field.name), - sql.SQL(", ").join(sql.Literal(key) for key in keys), - ) + async with self.connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute( + sql.SQL("SELECT {} FROM {}.{} WHERE {} IN ({})").format( + sql.SQL(", ").join(sql.Identifier(name) for (name, _) in fields), + sql.Identifier(self.db_schema), + sql.Identifier(self.collection_name), + sql.Identifier(self.data_model_definition.key_field.name), + sql.SQL(", ").join(sql.Literal(key) for key in keys), ) - rows = await cur.fetchall() - if not rows: - return None - return [convert_row_to_dict(row, fields) for row in rows] - - except DatabaseError as error: - raise MemoryConnectorException(f"Error getting records: {error}") from error + ) + rows = await cur.fetchall() + if not rows: + return None + return [convert_row_to_dict(row, fields) for row in rows] @override async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: @@ -223,32 +216,29 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: **kwargs: Additional arguments. """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") - - try: - async with ( - self.connection_pool.connection() as conn, - conn.transaction(), - conn.cursor() as cur, - ): - # Split the keys into batches - max_rows_per_transaction = self._settings.max_rows_per_transaction - for i in range(0, len(keys), max_rows_per_transaction): - key_batch = keys[i : i + max_rows_per_transaction] - - # Execute the DELETE statement for each batch - await cur.execute( - sql.SQL("DELETE FROM {}.{} WHERE {} IN ({})").format( - sql.Identifier(self.db_schema), - sql.Identifier(self.collection_name), - sql.Identifier(self.data_model_definition.key_field.name), - sql.SQL(", ").join(sql.Literal(key) for key in key_batch), - ) - ) + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." 
+ ) - except DatabaseError as error: - # Rollback happens automatically if an exception occurs within the transaction block - raise MemoryConnectorException(f"Error deleting records: {error}") from error + async with ( + self.connection_pool.connection() as conn, + conn.transaction(), + conn.cursor() as cur, + ): + # Split the keys into batches + max_rows_per_transaction = self._settings.max_rows_per_transaction + for i in range(0, len(keys), max_rows_per_transaction): + key_batch = keys[i : i + max_rows_per_transaction] + + # Execute the DELETE statement for each batch + await cur.execute( + sql.SQL("DELETE FROM {}.{} WHERE {} IN ({})").format( + sql.Identifier(self.db_schema), + sql.Identifier(self.collection_name), + sql.Identifier(self.data_model_definition.key_field.name), + sql.SQL(", ").join(sql.Literal(key) for key in key_batch), + ) + ) @override def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: @@ -276,7 +266,9 @@ async def create_collection(self, **kwargs: Any) -> None: **kwargs: Additional arguments """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) column_definitions = [] table_name = self.collection_name @@ -310,20 +302,16 @@ async def create_collection(self, **kwargs: Any) -> None: sql.Identifier(self.db_schema), sql.Identifier(table_name), columns_str ) - try: - async with self.connection_pool.connection() as conn, conn.cursor() as cur: - await cur.execute(create_table_query) - await conn.commit() - - logger.info(f"Postgres table '{table_name}' created successfully.") + async with self.connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute(create_table_query) + await conn.commit() - # If the vector field defines an index, apply it - for vector_field in self.data_model_definition.vector_fields: - if vector_field.index_kind: - await self._create_index(table_name, vector_field) + logger.info(f"Postgres table '{table_name}' created successfully.") - except DatabaseError as error: - raise MemoryConnectorException(f"Error creating table: {error}") from error + # If the vector field defines an index, apply it + for vector_field in self.data_model_definition.vector_fields: + if vector_field.index_kind: + await self._create_index(table_name, vector_field) async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVectorField) -> None: """Create an index on a column in the table. @@ -333,14 +321,16 @@ async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVe vector_field: The vector field definition that the index is based on. """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) column_name = vector_field.name index_name = f"{table_name}_{column_name}_idx" # Only support creating HNSW indexes through the vector store if vector_field.index_kind != IndexKind.HNSW: - raise MemoryConnectorException( + raise VectorStoreOperationException( f"Unsupported index kind: {vector_field.index_kind}. " "If you need to create an index of this type, please do so manually. " "Only HNSW indexes are supported through the vector store." 
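Taken together, the Postgres hunks above drop the blanket `DatabaseError` → `MemoryConnectorException` wrapping and reserve `VectorStoreOperationException` for misuse: a missing connection pool, a non-HNSW index kind, or an unset distance function. A minimal, hedged sketch of how this surfaces to a caller — the `ensure_collection` helper and the way the collection is obtained are assumptions for illustration; the exception and method names come from the diff:

```python
from semantic_kernel.connectors.memory.postgres import PostgresCollection
from semantic_kernel.exceptions import VectorStoreOperationException


async def ensure_collection(collection: PostgresCollection) -> None:
    # Outside the context manager the connection pool is unset, so every
    # operation now raises VectorStoreOperationException up front instead
    # of a wrapped MemoryConnectorException.
    async with collection:
        if not await collection.does_collection_exist():
            try:
                await collection.create_collection()
            except VectorStoreOperationException:
                # Also raised when a vector field declares a non-HNSW index
                # kind; such indexes must be created manually.
                raise
```

Driver-level failures (e.g. `psycopg` errors inside a transaction) now propagate unwrapped, so callers who previously caught `MemoryConnectorException` for those need to adjust.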
@@ -348,37 +338,35 @@ async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVe # Require the distance function to be set for HNSW indexes if not vector_field.distance_function: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Distance function must be set for HNSW indexes. " "Please set the distance function in the vector field definition." ) ops_str = get_vector_index_ops_str(vector_field.distance_function) - try: - async with self.connection_pool.connection() as conn, conn.cursor() as cur: - await cur.execute( - sql.SQL("CREATE INDEX {} ON {}.{} USING {} ({} {})").format( - sql.Identifier(index_name), - sql.Identifier(self.db_schema), - sql.Identifier(table_name), - sql.SQL(vector_field.index_kind), - sql.Identifier(column_name), - sql.SQL(ops_str), - ) + async with self.connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute( + sql.SQL("CREATE INDEX {} ON {}.{} USING {} ({} {})").format( + sql.Identifier(index_name), + sql.Identifier(self.db_schema), + sql.Identifier(table_name), + sql.SQL(vector_field.index_kind), + sql.Identifier(column_name), + sql.SQL(ops_str), ) - await conn.commit() - - logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.") + ) + await conn.commit() - except DatabaseError as error: - raise MemoryConnectorException(f"Error creating index: {error}") from error + logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.") @override async def does_collection_exist(self, **kwargs: Any) -> bool: """Check if the collection exists.""" if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) async with self.connection_pool.connection() as conn, conn.cursor() as cur: await cur.execute( @@ -396,7 +384,9 @@ async def does_collection_exist(self, **kwargs: Any) -> bool: async def delete_collection(self, **kwargs: Any) -> None: """Delete the collection.""" if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." 
+ ) async with self.connection_pool.connection() as conn, conn.cursor() as cur: await cur.execute( diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py index cb30fa0cdc76..b0cfd8244299 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py @@ -23,11 +23,11 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin from semantic_kernel.exceptions import ( - MemoryConnectorInitializationError, + VectorSearchExecutionException, + VectorStoreInitializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException -from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent @@ -125,14 +125,14 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant settings.", ex) from ex if APP_INFO: kwargs.setdefault("metadata", {}) kwargs["metadata"] = prepend_semantic_kernel_to_user_agent(kwargs["metadata"]) try: client = AsyncQdrantClientWrapper(**settings.model_dump(exclude_none=True), **kwargs) except ValueError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant client.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant client.", ex) from ex super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -274,7 +274,7 @@ async def create_collection(self, **kwargs) -> None: vector = self.data_model_definition.fields[field] assert isinstance(vector, VectorStoreRecordVectorField) # nosec if not vector.dimensions: - raise MemoryConnectorException("Vector field must have dimensions.") + raise VectorStoreOperationException("Vector field must have dimensions.") vectors_config[field] = VectorParams( size=vector.dimensions, distance=DISTANCE_FUNCTION_MAP[vector.distance_function or "default"], @@ -284,7 +284,7 @@ async def create_collection(self, **kwargs) -> None: vector = self.data_model_definition.fields[self.data_model_definition.vector_field_names[0]] assert isinstance(vector, VectorStoreRecordVectorField) # nosec if not vector.dimensions: - raise MemoryConnectorException("Vector field must have dimensions.") + raise VectorStoreOperationException("Vector field must have dimensions.") vectors_config = VectorParams( size=vector.dimensions, distance=DISTANCE_FUNCTION_MAP[vector.distance_function or "default"], diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py index f0fe5000cd8b..0fd00bc59532 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py @@ -16,7 +16,7 @@ from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection from 
semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage import VectorStore -from semantic_kernel.exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent @@ -94,14 +94,14 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant settings.", ex) from ex if APP_INFO: kwargs.setdefault("metadata", {}) kwargs["metadata"] = prepend_semantic_kernel_to_user_agent(kwargs["metadata"]) try: client = AsyncQdrantClient(**settings.model_dump(exclude_none=True), **kwargs) except ValueError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant client.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant client.", ex) from ex super().__init__(qdrant_client=client) def get_collection( diff --git a/python/semantic_kernel/connectors/memory/redis/redis_collection.py b/python/semantic_kernel/connectors/memory/redis/redis_collection.py index f4f59a1ecded..73f3a2bd4dea 100644 --- a/python/semantic_kernel/connectors/memory/redis/redis_collection.py +++ b/python/semantic_kernel/connectors/memory/redis/redis_collection.py @@ -47,11 +47,12 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreInitializationException, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException, VectorSearchOptionsException from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.list_handler import desync_list @@ -111,7 +112,7 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Redis settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Redis settings.", ex) from ex super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -149,7 +150,7 @@ async def create_collection(self, **kwargs) -> None: fields, definition=index_definition, **kwargs ) return - raise MemoryConnectorException("Invalid index type supplied.") + raise VectorStoreOperationException("Invalid index type supplied.") fields = data_model_definition_to_redis_fields(self.data_model_definition, self.collection_type) index_definition = IndexDefinition( prefix=f"{self.collection_name}:", index_type=INDEX_TYPE_MAP[self.collection_type] diff --git a/python/semantic_kernel/connectors/memory/redis/redis_store.py b/python/semantic_kernel/connectors/memory/redis/redis_store.py index 4e0629ba4271..8764027e0cd8 100644 --- a/python/semantic_kernel/connectors/memory/redis/redis_store.py +++ 
b/python/semantic_kernel/connectors/memory/redis/redis_store.py @@ -18,7 +18,7 @@ from semantic_kernel.connectors.memory.redis.utils import RedisWrapper from semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage import VectorStore, VectorStoreRecordCollection -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException from semantic_kernel.utils.experimental_decorator import experimental_class logger: logging.Logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Redis settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Redis settings.", ex) from ex super().__init__(redis_database=RedisWrapper.from_url(redis_settings.connection_string.get_secret_value())) @override diff --git a/python/semantic_kernel/connectors/memory/redis/utils.py b/python/semantic_kernel/connectors/memory/redis/utils.py index 68054281d156..35305c4118d8 100644 --- a/python/semantic_kernel/connectors/memory/redis/utils.py +++ b/python/semantic_kernel/connectors/memory/redis/utils.py @@ -27,7 +27,7 @@ VectorStoreRecordVectorField, ) from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter -from semantic_kernel.exceptions.search_exceptions import VectorSearchOptionsException +from semantic_kernel.exceptions import VectorSearchOptionsException from semantic_kernel.memory.memory_record import MemoryRecord diff --git a/python/semantic_kernel/connectors/memory/weaviate/utils.py b/python/semantic_kernel/connectors/memory/weaviate/utils.py index 363d4dae49fe..3f295d1076b7 100644 --- a/python/semantic_kernel/connectors/memory/weaviate/utils.py +++ b/python/semantic_kernel/connectors/memory/weaviate/utils.py @@ -18,7 +18,7 @@ VectorStoreRecordVectorField, ) from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter -from semantic_kernel.exceptions.memory_connector_exceptions import ( +from semantic_kernel.exceptions import ( VectorStoreModelDeserializationException, ) diff --git a/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py b/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py index 869d80e4df22..947188a3c819 100644 --- a/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py +++ b/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py @@ -41,10 +41,11 @@ from semantic_kernel.data.vector_search.vectorizable_text_search import VectorizableTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin from semantic_kernel.exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, + VectorSearchExecutionException, + VectorStoreInitializationException, + VectorStoreModelValidationError, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelValidationError from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class @@ -140,7 +141,7 @@ def __init__( " a local Weaviate instance, or the client embedding options.", ) except Exception as e: - raise MemoryConnectorInitializationError(f"Failed to initialize Weaviate client: {e}") + raise 
VectorStoreInitializationException(f"Failed to initialize Weaviate client: {e}") super().__init__( data_model_type=data_model_type, @@ -170,47 +171,24 @@ async def _inner_upsert( **kwargs: Any, ) -> Sequence[TKey]: assert all([isinstance(record, DataObject) for record in records]) # nosec - - try: - collection: CollectionAsync = self.async_client.collections.get(self.collection_name) - response = await collection.data.insert_many(records) - except WeaviateClosedClientError as ex: - raise MemoryConnectorException( - "Client is closed, please use the context manager or self.async_client.connect." - ) from ex - except Exception as ex: - raise MemoryConnectorException(f"Failed to upsert records: {ex}") - + collection: CollectionAsync = self.async_client.collections.get(self.collection_name) + response = await collection.data.insert_many(records) return [str(v) for _, v in response.uuids.items()] @override async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any] | None: - try: - collection: CollectionAsync = self.async_client.collections.get(self.collection_name) - result = await collection.query.fetch_objects( - filters=Filter.any_of([Filter.by_id().equal(key) for key in keys]), - include_vector=kwargs.get("include_vectors", False), - ) + collection: CollectionAsync = self.async_client.collections.get(self.collection_name) + result = await collection.query.fetch_objects( + filters=Filter.any_of([Filter.by_id().equal(key) for key in keys]), + include_vector=kwargs.get("include_vectors", False), + ) - return result.objects - except WeaviateClosedClientError as ex: - raise MemoryConnectorException( - "Client is closed, please use the context manager or self.async_client.connect." - ) from ex - except Exception as ex: - raise MemoryConnectorException(f"Failed to get records: {ex}") + return result.objects @override async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: - try: - collection: CollectionAsync = self.async_client.collections.get(self.collection_name) - await collection.data.delete_many(where=Filter.any_of([Filter.by_id().equal(key) for key in keys])) - except WeaviateClosedClientError as ex: - raise MemoryConnectorException( - "Client is closed, please use the context manager or self.async_client.connect." 
- ) from ex - except Exception as ex: - raise MemoryConnectorException(f"Failed to delete records: {ex}") + collection: CollectionAsync = self.async_client.collections.get(self.collection_name) + await collection.data.delete_many(where=Filter.any_of([Filter.by_id().equal(key) for key in keys])) @override async def _inner_search( @@ -236,7 +214,7 @@ async def _inner_search( elif vector: results = await self._inner_vectorized_search(collection, vector, vector_field, args) else: - raise MemoryConnectorException("No search criteria provided.") + raise VectorSearchExecutionException("No search criteria provided.") return KernelSearchResults( results=self._get_vector_search_results_from_results(results.objects), total_count=len(results.objects) @@ -252,7 +230,7 @@ async def _inner_vector_text_search( **args, ) except Exception as ex: - raise MemoryConnectorException(f"Failed searching using a text: {ex}") from ex + raise VectorSearchExecutionException(f"Failed searching using a text: {ex}") from ex async def _inner_vectorizable_text_search( self, @@ -262,7 +240,7 @@ async def _inner_vectorizable_text_search( args: dict[str, Any], ) -> Any: if self.named_vectors and not vector_field: - raise MemoryConnectorException( + raise VectorSearchExecutionException( "Vectorizable text search requires a vector field to be specified in the options." ) try: @@ -281,7 +259,7 @@ async def _inner_vectorizable_text_search( "Alternatively you could use a existing collection that has a vectorizer setup." "See also: https://weaviate.io/developers/weaviate/manage-data/collections#create-a-collection" ) - raise MemoryConnectorException(f"Failed searching using a vectorizable text: {ex}") from ex + raise VectorSearchExecutionException(f"Failed searching using a vectorizable text: {ex}") from ex async def _inner_vectorized_search( self, @@ -291,7 +269,7 @@ async def _inner_vectorized_search( args: dict[str, Any], ) -> Any: if self.named_vectors and not vector_field: - raise MemoryConnectorException( + raise VectorSearchExecutionException( "Vectorizable text search requires a vector field to be specified in the options." ) try: @@ -302,11 +280,11 @@ async def _inner_vectorized_search( **args, ) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorSearchExecutionException( "Client is closed, please use the context manager or self.async_client.connect." ) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed searching using a vector: {ex}") from ex + raise VectorSearchExecutionException(f"Failed searching using a vector: {ex}") from ex def _get_record_from_result(self, result: Any) -> Any: """Get the record from the returned search result.""" @@ -361,18 +339,18 @@ async def create_collection(self, **kwargs) -> None: Make sure to check the arguments of that method for the specifications. """ if not self.named_vectors and len(self.data_model_definition.vector_field_names) != 1: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Named vectors must be enabled if there is not exactly one vector field in the data model definition." ) if kwargs: try: await self.async_client.collections.create(**kwargs) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." 
) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex + raise VectorStoreOperationException(f"Failed to create collection: {ex}") from ex try: await self.async_client.collections.create( name=self.collection_name, @@ -387,11 +365,11 @@ async def create_collection(self, **kwargs) -> None: else None, ) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." ) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex + raise VectorStoreOperationException(f"Failed to create collection: {ex}") from ex @override async def does_collection_exist(self, **kwargs) -> bool: @@ -406,11 +384,11 @@ async def does_collection_exist(self, **kwargs) -> bool: try: return await self.async_client.collections.exists(self.collection_name) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." ) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to check if collection exists: {ex}") from ex + raise VectorStoreOperationException(f"Failed to check if collection exists: {ex}") from ex @override async def delete_collection(self, **kwargs) -> None: @@ -422,11 +400,11 @@ async def delete_collection(self, **kwargs) -> None: try: await self.async_client.collections.delete(self.collection_name) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." 
            ) from ex
        except Exception as ex:
-            raise MemoryConnectorException(f"Failed to delete collection: {ex}") from ex
+            raise VectorStoreOperationException(f"Failed to delete collection: {ex}") from ex
 
     @override
     async def __aenter__(self) -> "WeaviateCollection":
diff --git a/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py b/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py
index 82c05c769361..9d57a6e588c2 100644
--- a/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py
+++ b/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py
@@ -19,9 +19,9 @@
 from semantic_kernel.data.vector_storage.vector_store import VectorStore
 from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
 from semantic_kernel.exceptions import (
-    MemoryConnectorConnectionException,
-    MemoryConnectorException,
-    MemoryConnectorInitializationError,
+    VectorStoreException,
+    VectorStoreInitializationException,
+    VectorStoreOperationException,
 )
 from semantic_kernel.utils.experimental_decorator import experimental_class
 
@@ -95,7 +95,7 @@ def __init__(
                     " a local Weaviate instance, or the client embedding options.",
                 )
             except Exception as e:
-                raise MemoryConnectorInitializationError(f"Failed to initialize Weaviate client: {e}")
+                raise VectorStoreInitializationException(f"Failed to initialize Weaviate client: {e}")
 
         super().__init__(async_client=async_client, managed_client=managed_client)
 
@@ -124,7 +124,7 @@ async def list_collection_names(self, **kwargs) -> Sequence[str]:
             collections = await self.async_client.collections.list_all()
             return [collection.name for collection in collections]
         except Exception as e:
-            raise MemoryConnectorException(f"Failed to list Weaviate collections: {e}")
+            raise VectorStoreOperationException(f"Failed to list Weaviate collections: {e}")
 
     @override
     async def __aenter__(self) -> "VectorStore":
@@ -133,7 +133,7 @@ async def __aenter__(self) -> "VectorStore":
         try:
             await self.async_client.connect()
         except WeaviateConnectionError as exc:
-            raise MemoryConnectorConnectionException("Weaviate client cannot connect.") from exc
+            raise VectorStoreException("Weaviate client cannot connect.") from exc
         return self
 
     @override
diff --git a/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py b/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py
index 6d5a89719292..d9244076bcbb 100644
--- a/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py
+++ b/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py
@@ -1,8 +1,8 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import logging
-from inspect import _empty, signature
-from types import NoneType
+from inspect import Parameter, _empty, signature
+from types import MappingProxyType, NoneType
 from typing import Any
 
 from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
@@ -10,7 +10,7 @@
     VectorStoreRecordField,
     VectorStoreRecordVectorField,
 )
-from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException
+from semantic_kernel.exceptions import VectorStoreModelException
 from semantic_kernel.utils.experimental_decorator import experimental_function
 
 logger = logging.getLogger(__name__)
@@ -27,6 +27,7 @@ def vectorstoremodel(
     - The class must have at least one field with an annotation,
         of type VectorStoreRecordKeyField, VectorStoreRecordDataField or VectorStoreRecordVectorField.
     - The class must have exactly one field with the VectorStoreRecordKeyField annotation.
+    - When a field has multiple VectorStoreRecordKeyField annotations, the first one found is used.
     Optionally, when there are VectorStoreRecordDataFields that specify an embedding property name,
     there must be a corresponding VectorStoreRecordVectorField with the same name.
@@ -59,63 +60,66 @@ def wrap(cls: Any):
     return wrap(cls)
 
 
-def _parse_signature_to_definition(parameters) -> VectorStoreRecordDefinition:
+def _parse_signature_to_definition(parameters: MappingProxyType[str, Parameter]) -> VectorStoreRecordDefinition:
     if len(parameters) == 0:
         raise VectorStoreModelException(
             "There must be at least one field in the datamodel. If you are using this with a @dataclass, "
             "you might have inverted the order of the decorators, the vectorstoremodel decorator should be the top one."
         )
-    fields: dict[str, VectorStoreRecordField] = {}
-    for field in parameters.values():
-        annotation = field.annotation
-        # check first if there are any annotations
-        if not hasattr(annotation, "__metadata__"):
-            if field._default is _empty:
-                raise VectorStoreModelException(
-                    "Fields that do not have a VectorStoreRecord* annotation must have a default value."
-                )
-            logger.info(
-                f'Field "{field.name}" does not have a VectorStoreRecord* annotation, will not be part of the record.'
-            )
-            continue
-        property_type = annotation.__origin__
+    return VectorStoreRecordDefinition(
+        fields={
+            field.name: field for field in [_parse_parameter_to_field(field) for field in parameters.values()] if field
+        }
+    )
+
+
+def _parse_parameter_to_field(field: Parameter) -> VectorStoreRecordField | None:
+    for field_annotation in getattr(field.annotation, "__metadata__", []):
+        if isinstance(field_annotation, VectorStoreRecordField):
+            return _parse_vector_store_record_field_instance(field_annotation, field)
+        if isinstance(field_annotation, type(VectorStoreRecordField)):
+            return _parse_vector_store_record_field_class(field_annotation, field)
+
+    # This means there are either no annotations or no VectorStoreRecordField annotations.
+    if field.default is _empty:
+        raise VectorStoreModelException(
+            "Fields that do not have a VectorStoreRecord* annotation must have a default value."
+        )
+    logger.debug(
+        f'Field "{field.name}" does not have a VectorStoreRecordField annotation, will not be part of the record.'
+ ) + return None + + +def _parse_vector_store_record_field_instance( + record_field: VectorStoreRecordField, field: Parameter +) -> VectorStoreRecordField: + if not record_field.name or record_field.name != field.name: + record_field.name = field.name + if not record_field.property_type: + property_type = field.annotation.__origin__ if (args := getattr(property_type, "__args__", None)) and NoneType in args and len(args) == 2: property_type = args[0] - metadata = annotation.__metadata__ - field_type = None - for item in metadata: - if isinstance(item, VectorStoreRecordField): - field_type = item - if not field_type.name or field_type.name != field.name: - field_type.name = field.name - if not field_type.property_type: - if hasattr(property_type, "__args__"): - if isinstance(item, VectorStoreRecordVectorField): - field_type.property_type = property_type.__args__[0].__name__ - elif property_type.__name__ == "list": - field_type.property_type = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" - else: - field_type.property_type = property_type.__name__ - - else: - field_type.property_type = property_type.__name__ - elif isinstance(item, type(VectorStoreRecordField)): - if hasattr(property_type, "__args__") and property_type.__name__ == "list": - property_type_name = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" - else: - property_type_name = property_type.__name__ - field_type = item(name=field.name, property_type=property_type_name) - if not field_type: - if field._default is _empty: - raise VectorStoreModelException( - "Fields that do not have a VectorStoreRecord* annotation must have a default value." - ) - logger.debug( - f'Field "{field.name}" does not have a VectorStoreRecordField ' - "annotation, will not be part of the record." 
- ) - continue - # field name is set either when not None or by instantiating a new field - assert field_type.name is not None # nosec - fields[field_type.name] = field_type - return VectorStoreRecordDefinition(fields=fields) + if hasattr(property_type, "__args__"): + if isinstance(record_field, VectorStoreRecordVectorField): + record_field.property_type = property_type.__args__[0].__name__ + elif property_type.__name__ == "list": + record_field.property_type = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" + else: + record_field.property_type = property_type.__name__ + else: + record_field.property_type = property_type.__name__ + return record_field + + +def _parse_vector_store_record_field_class( + field_type: type[VectorStoreRecordField], field: Parameter +) -> VectorStoreRecordField: + property_type = field.annotation.__origin__ + if (args := getattr(property_type, "__args__", None)) and NoneType in args and len(args) == 2: + property_type = args[0] + if hasattr(property_type, "__args__") and property_type.__name__ == "list": + property_type_name = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" + else: + property_type_name = property_type.__name__ + return field_type(name=field.name, property_type=property_type_name) diff --git a/python/semantic_kernel/data/record_definition/vector_store_model_definition.py b/python/semantic_kernel/data/record_definition/vector_store_model_definition.py index 5d7e01faa42b..adc993ff22e3 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_model_definition.py +++ b/python/semantic_kernel/data/record_definition/vector_store_model_definition.py @@ -4,10 +4,10 @@ from typing import TypeVar from semantic_kernel.data.record_definition.vector_store_model_protocols import ( - DeserializeProtocol, - FromDictProtocol, - SerializeProtocol, - ToDictProtocol, + DeserializeFunctionProtocol, + FromDictFunctionProtocol, + SerializeFunctionProtocol, + ToDictFunctionProtocol, ) from semantic_kernel.data.record_definition.vector_store_record_fields import ( VectorStoreRecordDataField, @@ -15,7 +15,7 @@ VectorStoreRecordKeyField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException from semantic_kernel.utils.experimental_decorator import experimental_class VectorStoreRecordFields = TypeVar("VectorStoreRecordFields", bound=VectorStoreRecordField) @@ -40,10 +40,10 @@ class VectorStoreRecordDefinition: key_field_name: str = field(init=False) fields: FieldsType container_mode: bool = False - to_dict: ToDictProtocol | None = None - from_dict: FromDictProtocol | None = None - serialize: SerializeProtocol | None = None - deserialize: DeserializeProtocol | None = None + to_dict: ToDictFunctionProtocol | None = None + from_dict: FromDictFunctionProtocol | None = None + serialize: SerializeFunctionProtocol | None = None + deserialize: DeserializeFunctionProtocol | None = None @property def field_names(self) -> list[str]: @@ -126,7 +126,7 @@ def __post_init__(self): for name, value in self.fields.items(): if not name: raise VectorStoreModelException("Fields must have a name.") - if value.name is None: + if not value.name: value.name = name if ( isinstance(value, VectorStoreRecordDataField) diff --git a/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py b/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py index 452827857b48..0c83a131c965 
100644 --- a/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py +++ b/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py @@ -10,8 +10,8 @@ @experimental_class @runtime_checkable -class VectorStoreModelFunctionSerdeProtocol(Protocol): - """Data model serialization and deserialization protocol. +class SerializeMethodProtocol(Protocol): + """Data model serialization protocol. This can optionally be implemented to allow single step serialization and deserialization for using your data model with a specific datastore. @@ -21,46 +21,21 @@ def serialize(self, **kwargs: Any) -> Any: """Serialize the object to the format required by the data store.""" ... # pragma: no cover - @classmethod - def deserialize(cls: type[TModel], obj: Any, **kwargs: Any) -> TModel: - """Deserialize the output of the data store to an object.""" - ... # pragma: no cover - @experimental_class @runtime_checkable -class VectorStoreModelPydanticProtocol(Protocol): - """Class used internally to make sure a datamodel has model_dump and model_validate.""" - - def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - """Serialize the object to the format required by the data store.""" - ... # pragma: no cover - - @classmethod - def model_validate(cls: type[TModel], *args: Any, **kwargs: Any) -> TModel: - """Deserialize the output of the data store to an object.""" - ... # pragma: no cover - - -@experimental_class -@runtime_checkable -class VectorStoreModelToDictFromDictProtocol(Protocol): - """Class used internally to check if a model has to_dict and from_dict methods.""" +class ToDictMethodProtocol(Protocol): + """Class used internally to check if a model has a to_dict method.""" def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Serialize the object to the format required by the data store.""" ... # pragma: no cover - @classmethod - def from_dict(cls: type[TModel], *args: Any, **kwargs: Any) -> TModel: - """Deserialize the output of the data store to an object.""" - ... # pragma: no cover - @experimental_class @runtime_checkable -class ToDictProtocol(Protocol): - """Protocol for to_dict method. +class ToDictFunctionProtocol(Protocol): + """Protocol for to_dict function. Args: record: The record to be serialized. @@ -75,8 +50,8 @@ def __call__(self, record: Any, **kwargs: Any) -> Sequence[dict[str, Any]]: ... @experimental_class @runtime_checkable -class FromDictProtocol(Protocol): - """Protocol for from_dict method. +class FromDictFunctionProtocol(Protocol): + """Protocol for from_dict function. Args: records: A list of dictionaries. @@ -91,8 +66,8 @@ def __call__(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Any: ... @experimental_class @runtime_checkable -class SerializeProtocol(Protocol): - """Protocol for serialize method. +class SerializeFunctionProtocol(Protocol): + """Protocol for serialize function. Args: record: The record to be serialized. @@ -108,8 +83,8 @@ def __call__(self, record: Any, **kwargs: Any) -> Any: ... # noqa: D102 @experimental_class @runtime_checkable -class DeserializeProtocol(Protocol): - """Protocol for deserialize method. +class DeserializeFunctionProtocol(Protocol): + """Protocol for deserialize function. Args: records: The serialized record directly from the store. 
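The method protocols above are all decorated with @runtime_checkable, which is what lets the record collection detect a user-supplied serialize or to_dict method with a plain isinstance check (see the _serialize_data_model_to_store_model changes further down). A minimal sketch of that mechanism, using a stand-in protocol and a hypothetical MyRecord type rather than the semantic_kernel imports:

    from typing import Any, Protocol, runtime_checkable

    @runtime_checkable
    class SerializeMethodProtocol(Protocol):
        """Stand-in for the protocol above: any object with a serialize method matches."""

        def serialize(self, **kwargs: Any) -> Any: ...

    class MyRecord:
        """Hypothetical record type; note that it never imports or subclasses the protocol."""

        def __init__(self, id: str, content: str) -> None:
            self.id = id
            self.content = content

        def serialize(self, **kwargs: Any) -> Any:
            # The shape of the returned store model is illustrative only.
            return {"id": self.id, "content": self.content}

    # Structural check: isinstance verifies that a member named `serialize` exists;
    # it does not validate the signature or the return type.
    assert isinstance(MyRecord("1", "hello"), SerializeMethodProtocol)

Because the isinstance check only proves that the method exists, the serialization paths in the record collection still wrap any unexpected error from a user-supplied implementation.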
diff --git a/python/semantic_kernel/data/record_definition/vector_store_record_fields.py b/python/semantic_kernel/data/record_definition/vector_store_record_fields.py index d3b782c49912..536482b1069d 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_record_fields.py +++ b/python/semantic_kernel/data/record_definition/vector_store_record_fields.py @@ -17,7 +17,7 @@ class VectorStoreRecordField(ABC): """Base class for all Vector Store Record Fields.""" - name: str | None = None + name: str = "" property_type: str | None = None diff --git a/python/semantic_kernel/data/record_definition/vector_store_record_utils.py b/python/semantic_kernel/data/record_definition/vector_store_record_utils.py index 96494ff6239b..00436fe8e199 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_record_utils.py +++ b/python/semantic_kernel/data/record_definition/vector_store_record_utils.py @@ -8,7 +8,7 @@ VectorStoreRecordDataField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class diff --git a/python/semantic_kernel/data/search_filter.py b/python/semantic_kernel/data/search_filter.py index a6b844e141f8..4d0d84b5a7b9 100644 --- a/python/semantic_kernel/data/search_filter.py +++ b/python/semantic_kernel/data/search_filter.py @@ -4,9 +4,9 @@ from typing import TypeVar if sys.version_info >= (3, 11): - from typing import Self + from typing import Self # pragma: no cover else: - from typing_extensions import Self + from typing_extensions import Self # pragma: no cover from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase diff --git a/python/semantic_kernel/data/text_search/text_search.py b/python/semantic_kernel/data/text_search/text_search.py index ff7b6c416435..d40f7169786c 100644 --- a/python/semantic_kernel/data/text_search/text_search.py +++ b/python/semantic_kernel/data/text_search/text_search.py @@ -294,7 +294,7 @@ async def _map_results( return [self._default_map_to_string(result) async for result in results.results] @staticmethod - def _default_map_to_string(result: Any) -> str: + def _default_map_to_string(result: BaseModel | object) -> str: """Default mapping function for text search results.""" if isinstance(result, BaseModel): return result.model_dump_json() diff --git a/python/semantic_kernel/data/text_search/utils.py b/python/semantic_kernel/data/text_search/utils.py index eb60f87b3d82..44e432f1ec36 100644 --- a/python/semantic_kernel/data/text_search/utils.py +++ b/python/semantic_kernel/data/text_search/utils.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +import logging +from contextlib import suppress from typing import TYPE_CHECKING, Any, Protocol from pydantic import ValidationError @@ -8,6 +10,8 @@ from semantic_kernel.data.search_options import SearchOptions from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata +logger = logging.getLogger(__name__) + class OptionsUpdateFunctionType(Protocol): """Type definition for the options update function in Text Search.""" @@ -20,7 +24,7 @@ def __call__( **kwargs: Any, ) -> tuple[str, "SearchOptions"]: """Signature of the function.""" - ... + ... 
# pragma: no cover def create_options( @@ -44,30 +48,24 @@ def create_options( SearchOptions: The options. """ + new_options = options_class() if options: if not isinstance(options, options_class): + inputs = None try: - # Validate the options in one go - new_options = options_class.model_validate( - options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True), - ) - except ValidationError: - # if that fails, go one by one - new_options = options_class() - for key, value in options.model_dump( - exclude_none=True, exclude_defaults=True, exclude_unset=True - ).items(): - setattr(new_options, key, value) + inputs = options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True) + except Exception: + logger.warning("Options are not valid. Creating new options.") + if inputs: + new_options = options_class.model_validate(inputs) else: new_options = options for key, value in kwargs.items(): if key in new_options.model_fields: setattr(new_options, key, value) else: - try: + with suppress(ValidationError): new_options = options_class(**kwargs) - except ValidationError: - new_options = options_class() return new_options diff --git a/python/semantic_kernel/data/text_search/vector_store_text_search.py b/python/semantic_kernel/data/text_search/vector_store_text_search.py index dd2242650132..688e6f763e71 100644 --- a/python/semantic_kernel/data/text_search/vector_store_text_search.py +++ b/python/semantic_kernel/data/text_search/vector_store_text_search.py @@ -14,7 +14,7 @@ from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorizable_text_search import VectorizableTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreTextSearchValidationError +from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreInitializationException from semantic_kernel.kernel_pydantic import KernelBaseModel if TYPE_CHECKING: @@ -59,11 +59,11 @@ def _validate_stores(cls, data: dict[str, Any]) -> dict[str, Any]: and not data.get("vectorized_search") and not data.get("vector_text_search") ): - raise VectorStoreTextSearchValidationError( + raise VectorStoreInitializationException( "At least one of vectorizable_text_search, vectorized_search or vector_text_search is required." 
            )
        if data.get("vectorized_search") and not data.get("embedding_service"):
-            raise VectorStoreTextSearchValidationError("embedding_service is required when using vectorized_search.")
+            raise VectorStoreInitializationException("embedding_service is required when using vectorized_search.")
        return data
 
    @classmethod
diff --git a/python/semantic_kernel/data/vector_search/vector_search.py b/python/semantic_kernel/data/vector_search/vector_search.py
index 935a3b02464f..166676136ef9 100644
--- a/python/semantic_kernel/data/vector_search/vector_search.py
+++ b/python/semantic_kernel/data/vector_search/vector_search.py
@@ -10,6 +10,7 @@
 from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
 from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
 from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
+from semantic_kernel.exceptions import VectorStoreModelDeserializationException
 from semantic_kernel.utils.experimental_decorator import experimental_class
 from semantic_kernel.utils.list_handler import desync_list
@@ -57,6 +58,9 @@ async def _inner_search(
         The implementation of this method must deal with the possibility that multiple search contents are provided,
         and should handle them in a way that makes sense for that particular store.
 
+        The public methods will catch and reraise the three exceptions mentioned below; all others are caught
+        and turned into a VectorSearchExecutionException.
+
         Args:
             options: The search options, can be None.
             search_text: The text to search for, optional.
@@ -67,6 +71,11 @@
 
         Returns:
             The search results, wrapped in a KernelSearchResults object.
 
+        Raises:
+            VectorSearchExecutionException: If an error occurs during the search.
+            VectorStoreModelDeserializationException: If an error occurs during deserialization.
+            VectorSearchOptionsException: If the search options are invalid.
+
         """
         ...
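The contract spelled out in this docstring (reraise the three known exception types, wrap everything else in VectorSearchExecutionException) is the pattern each of the search mixins below implements around _inner_search. A minimal sketch of the idea, with stand-in exception classes instead of the ones from semantic_kernel.exceptions:

    class VectorSearchExecutionException(Exception): ...
    class VectorSearchOptionsException(Exception): ...
    class VectorStoreModelDeserializationException(Exception): ...

    async def public_search(inner_search, **kwargs):
        """Hypothetical public wrapper around an _inner_search implementation."""
        try:
            return await inner_search(**kwargs)
        except (
            VectorStoreModelDeserializationException,
            VectorSearchOptionsException,
            VectorSearchExecutionException,
        ):
            # Already one of the documented types: pass through unchanged.
            raise
        except Exception as exc:
            # Everything else is wrapped, keeping the original error as the cause.
            raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc

This keeps the public surface down to a small, documented set of exception types while the __cause__ chain preserves the original store-specific error.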
@@ -106,9 +115,16 @@ async def _get_vector_search_results_from_results( if isinstance(results, Sequence): results = desync_list(results) async for result in results: - record = self.deserialize( - self._get_record_from_result(result), include_vectors=options.include_vectors if options else True - ) + try: + record = self.deserialize( + self._get_record_from_result(result), include_vectors=options.include_vectors if options else True + ) + except VectorStoreModelDeserializationException: + raise + except Exception as exc: + raise VectorStoreModelDeserializationException( + f"An error occurred while deserializing the record: {exc}" + ) from exc score = self._get_score_from_result(result) if record: # single records are always returned as single records by the deserializer diff --git a/python/semantic_kernel/data/vector_search/vector_search_filter.py b/python/semantic_kernel/data/vector_search/vector_search_filter.py index c7bd99d6797e..6944fe69ba4d 100644 --- a/python/semantic_kernel/data/vector_search/vector_search_filter.py +++ b/python/semantic_kernel/data/vector_search/vector_search_filter.py @@ -3,9 +3,9 @@ import sys if sys.version_info >= (3, 11): - from typing import Self + from typing import Self # pragma: no cover else: - from typing_extensions import Self + from typing_extensions import Self # pragma: no cover from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo from semantic_kernel.data.search_filter import SearchFilter diff --git a/python/semantic_kernel/data/vector_search/vector_text_search.py b/python/semantic_kernel/data/vector_search/vector_text_search.py index a5445c62e16c..f2a29b2908b8 100644 --- a/python/semantic_kernel/data/vector_search/vector_text_search.py +++ b/python/semantic_kernel/data/vector_search/vector_text_search.py @@ -5,7 +5,12 @@ from semantic_kernel.data.search_options import SearchOptions from semantic_kernel.data.text_search.utils import create_options -from semantic_kernel.exceptions import VectorStoreMixinException +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreMixinException, + VectorStoreModelDeserializationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -35,8 +40,10 @@ async def text_search( **kwargs: if options are not set, this is used to create them. Raises: - VectorSearchOptionsException: raised when the options given are not correct. - SearchResultEmptyError: raised when there are no results returned. + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase. 
""" from semantic_kernel.data.vector_search.vector_search import VectorSearchBase @@ -44,4 +51,9 @@ async def text_search( if not isinstance(self, VectorSearchBase): raise VectorStoreMixinException("This method can only be used in combination with the VectorSearchBase.") options = create_options(self.options_class, options, **kwargs) - return await self._inner_search(search_text=search_text, options=options) # type: ignore + try: + return await self._inner_search(search_text=search_text, options=options) # type: ignore + except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc diff --git a/python/semantic_kernel/data/vector_search/vectorizable_text_search.py b/python/semantic_kernel/data/vector_search/vectorizable_text_search.py index 76e653a5b053..9c5b882cf6f4 100644 --- a/python/semantic_kernel/data/vector_search/vectorizable_text_search.py +++ b/python/semantic_kernel/data/vector_search/vectorizable_text_search.py @@ -4,7 +4,12 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from semantic_kernel.data.text_search.utils import create_options -from semantic_kernel.exceptions import VectorStoreMixinException +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreMixinException, + VectorStoreModelDeserializationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -40,8 +45,10 @@ async def vectorizable_text_search( **kwargs: if options are not set, this is used to create them. Raises: - VectorSearchOptionsException: raised when the options given are not correct. - SearchResultEmptyError: raised when there are no results returned. + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase. 
""" from semantic_kernel.data.vector_search.vector_search import VectorSearchBase @@ -49,4 +56,9 @@ async def vectorizable_text_search( if not isinstance(self, VectorSearchBase): raise VectorStoreMixinException("This method can only be used in combination with the VectorSearchBase.") options = create_options(self.options_class, options, **kwargs) - return await self._inner_search(vectorizable_text=vectorizable_text, options=options) # type: ignore + try: + return await self._inner_search(vectorizable_text=vectorizable_text, options=options) # type: ignore + except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc diff --git a/python/semantic_kernel/data/vector_search/vectorized_search.py b/python/semantic_kernel/data/vector_search/vectorized_search.py index ae170840eb8a..1b3e5aa25f9e 100644 --- a/python/semantic_kernel/data/vector_search/vectorized_search.py +++ b/python/semantic_kernel/data/vector_search/vectorized_search.py @@ -4,7 +4,12 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from semantic_kernel.data.text_search.utils import create_options -from semantic_kernel.exceptions import VectorStoreMixinException +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreMixinException, + VectorStoreModelDeserializationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -35,8 +40,10 @@ async def vectorized_search( **kwargs: if options are not set, this is used to create them. Raises: - VectorSearchOptionsException: raised when the options given are not correct. - SearchResultEmptyError: raised when there are no results returned. + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase. """ from semantic_kernel.data.vector_search.vector_search import VectorSearchBase @@ -44,4 +51,9 @@ async def vectorized_search( if not isinstance(self, VectorSearchBase): raise VectorStoreMixinException("This method can only be used in combination with the VectorSearchBase.") options = create_options(self.options_class, options, **kwargs) - return await self._inner_search(vector=vector, options=options) # type: ignore + try: + return await self._inner_search(vector=vector, options=options) # type: ignore + except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc diff --git a/python/semantic_kernel/data/vector_storage/vector_store.py b/python/semantic_kernel/data/vector_storage/vector_store.py index 5608e6ec174d..796973a63854 100644 --- a/python/semantic_kernel/data/vector_storage/vector_store.py +++ b/python/semantic_kernel/data/vector_storage/vector_store.py @@ -48,4 +48,4 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: If the client is passed in the constructor, it should not be closed, in that case the managed_client should be set to False. 
""" - pass + pass # pragma: no cover diff --git a/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py b/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py index 5c239db849b6..2774ebfb2fd8 100644 --- a/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py +++ b/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py @@ -7,14 +7,18 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence from typing import Any, ClassVar, Generic, TypeVar -from pydantic import model_validator +from pydantic import BaseModel, model_validator from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, +from semantic_kernel.data.record_definition.vector_store_model_protocols import ( + SerializeMethodProtocol, + ToDictMethodProtocol, +) +from semantic_kernel.exceptions import ( VectorStoreModelDeserializationException, VectorStoreModelSerializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.kernel_types import OneOrMany @@ -96,6 +100,15 @@ async def _inner_upsert( Returns: The keys of the upserted records. + + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. + """ ... # pragma: no cover @@ -109,6 +122,14 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any Returns: The records from the store, not deserialized. + + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. """ ... # pragma: no cover @@ -119,6 +140,14 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: Args: keys: The keys. **kwargs: Additional arguments. + + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. """ ... # pragma: no cover @@ -184,17 +213,39 @@ async def create_collection_if_not_exists(self, **kwargs: Any) -> bool: @abstractmethod async def create_collection(self, **kwargs: Any) -> None: - """Create the collection in the service.""" + """Create the collection in the service. + + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. + + """ ... 
# pragma: no cover @abstractmethod async def does_collection_exist(self, **kwargs: Any) -> bool: - """Check if the collection exists.""" + """Check if the collection exists. + + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. + """ ... # pragma: no cover @abstractmethod async def delete_collection(self, **kwargs: Any) -> None: - """Delete the collection.""" + """Delete the collection. + + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. + """ ... # pragma: no cover # region Public Methods @@ -223,19 +274,24 @@ async def upsert( Returns: The key of the upserted record or a list of keys, when a container type is used. + + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. + VectorStoreOperationException: If an error occurs during upserting. """ if embedding_generation_function: record = await embedding_generation_function(record, self.data_model_type, self.data_model_definition) try: data = self.serialize(record) - except Exception as exc: - raise MemoryConnectorException(f"Error serializing record: {exc}") from exc + # the serialize method will parse any exception into a VectorStoreModelSerializationException + except VectorStoreModelSerializationException: + raise try: results = await self._inner_upsert(data if isinstance(data, Sequence) else [data], **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error upserting record: {exc}") from exc + raise VectorStoreOperationException(f"Error upserting record: {exc}") from exc if self._container_mode: return results @@ -266,19 +322,24 @@ async def upsert_batch( Returns: Sequence[TKey]: The keys of the upserted records, this is always a list, corresponds to the input or the items in the container. + + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. + VectorStoreOperationException: If an error occurs during upserting. """ if embedding_generation_function: records = await embedding_generation_function(records, self.data_model_type, self.data_model_definition) try: data = self.serialize(records) - except Exception as exc: - raise MemoryConnectorException(f"Error serializing records: {exc}") from exc + # the serialize method will parse any exception into a VectorStoreModelSerializationException + except VectorStoreModelSerializationException: + raise try: return await self._inner_upsert(data, **kwargs) # type: ignore except Exception as exc: - raise MemoryConnectorException(f"Error upserting records: {exc}") from exc + raise VectorStoreOperationException(f"Error upserting records: {exc}") from exc async def get(self, key: TKey, include_vectors: bool = True, **kwargs: Any) -> TModel | None: """Get a record if the key exists. @@ -293,19 +354,24 @@ async def get(self, key: TKey, include_vectors: bool = True, **kwargs: Any) -> T Returns: TModel: The record. None if the key does not exist. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. 
""" try: records = await self._inner_get([key], include_vectors=include_vectors, **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error getting record: {exc}") from exc + raise VectorStoreOperationException(f"Error getting record: {exc}") from exc if not records: return None try: model_records = self.deserialize(records[0], **kwargs) - except Exception as exc: - raise MemoryConnectorException(f"Error deserializing record: {exc}") from exc + # the deserialize method will parse any exception into a VectorStoreModelDeserializationException + except VectorStoreModelDeserializationException: + raise # there are many code paths within the deserialize method, some supplied by the developer, # and so depending on what is used, @@ -316,7 +382,9 @@ async def get(self, key: TKey, include_vectors: bool = True, **kwargs: Any) -> T return model_records if len(model_records) == 1: return model_records[0] - raise MemoryConnectorException(f"Error deserializing record, multiple records returned: {model_records}") + raise VectorStoreModelDeserializationException( + f"Error deserializing record, multiple records returned: {model_records}" + ) async def get_batch( self, keys: Sequence[TKey], include_vectors: bool = True, **kwargs: Any @@ -333,19 +401,24 @@ async def get_batch( Returns: The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. """ try: records = await self._inner_get(keys, include_vectors=include_vectors, **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error getting records: {exc}") from exc + raise VectorStoreOperationException(f"Error getting records: {exc}") from exc if not records: return None try: return self.deserialize(records, keys=keys, **kwargs) - except Exception as exc: - raise MemoryConnectorException(f"Error deserializing record: {exc}") from exc + # the deserialize method will parse any exception into a VectorStoreModelDeserializationException + except VectorStoreModelDeserializationException: + raise async def delete(self, key: TKey, **kwargs: Any) -> None: """Delete a record. @@ -354,12 +427,12 @@ async def delete(self, key: TKey, **kwargs: Any) -> None: key: The key. **kwargs: Additional arguments. Exceptions: - MemoryConnectorException: If an error occurs during deletion or the record does not exist. + VectorStoreOperationException: If an error occurs during deletion or the record does not exist. """ try: await self._inner_delete([key], **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error deleting record: {exc}") from exc + raise VectorStoreOperationException(f"Error deleting record: {exc}") from exc async def delete_batch(self, keys: Sequence[TKey], **kwargs: Any) -> None: """Delete a batch of records. @@ -370,14 +443,14 @@ async def delete_batch(self, keys: Sequence[TKey], **kwargs: Any) -> None: keys: The keys. **kwargs: Additional arguments. Exceptions: - MemoryConnectorException: If an error occurs during deletion or a record does not exist. + VectorStoreOperationException: If an error occurs during deletion or a record does not exist. 
""" try: await self._inner_delete(keys, **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error deleting records: {exc}") from exc + raise VectorStoreOperationException(f"Error deleting records: {exc}") from exc - # region Internal Serialization methods + # region Serialization methods def serialize(self, records: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any]: """Serialize the data model to the store model. @@ -391,45 +464,31 @@ def serialize(self, records: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any] If overriding this method, make sure to first try to serialize the data model to the store model, before doing the store specific version, the user supplied version should have precedence. - """ - if serialized := self._serialize_data_model_to_store_model(records): - return serialized - - if isinstance(records, Sequence): - dict_records = [self._serialize_data_model_to_dict(rec) for rec in records] - return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore - - dict_records = self._serialize_data_model_to_dict(records) # type: ignore - if isinstance(dict_records, Sequence): - # most likely this is a container, so we return all records as a list - # can also be a single record, but the to_dict returns a list - # hence we will treat it as a container. - return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore - # this case is single record in, single record out - return self._serialize_dicts_to_store_models([dict_records], **kwargs)[0] - def deserialize(self, records: OneOrMany[Any | dict[str, Any]], **kwargs: Any) -> OneOrMany[TModel] | None: - """Deserialize the store model to the data model. + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. - This method follows the following steps: - 1. Check if the data model has a deserialize method. - Use that method to deserialize and return the result. - 2. Deserialize the store model to a dict, using the store specific method. - 3. Convert the dict to the data model, using the data model specific method. """ - if deserialized := self._deserialize_store_model_to_data_model(records, **kwargs): - return deserialized - - if isinstance(records, Sequence): - dict_records = self._deserialize_store_models_to_dicts(records, **kwargs) - if self._container_mode: - return self._deserialize_dict_to_data_model(dict_records, **kwargs) - return [self._deserialize_dict_to_data_model(rec, **kwargs) for rec in dict_records] - - dict_record = self._deserialize_store_models_to_dicts([records], **kwargs)[0] - if not dict_record: - return None - return self._deserialize_dict_to_data_model(dict_record, **kwargs) + try: + if serialized := self._serialize_data_model_to_store_model(records): + return serialized + + if isinstance(records, Sequence): + dict_records = [self._serialize_data_model_to_dict(rec) for rec in records] + return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore + + dict_records = self._serialize_data_model_to_dict(records) # type: ignore + if isinstance(dict_records, Sequence): + # most likely this is a container, so we return all records as a list + # can also be a single record, but the to_dict returns a list + # hence we will treat it as a container. 
+ return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore + # this case is single record in, single record out + return self._serialize_dicts_to_store_models([dict_records], **kwargs)[0] + except VectorStoreModelSerializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc def _serialize_data_model_to_store_model(self, record: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any] | None: """Serialize the data model to the store model. @@ -445,33 +504,9 @@ def _serialize_data_model_to_store_model(self, record: OneOrMany[TModel], **kwar return None return result if self.data_model_definition.serialize: - return self.data_model_definition.serialize(record, **kwargs) # type: ignore - if hasattr(record, "serialize"): - try: - return record.serialize(**kwargs) - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing record: {exc}") from exc - return None - - def _deserialize_store_model_to_data_model(self, record: OneOrMany[Any], **kwargs: Any) -> OneOrMany[TModel] | None: - """Deserialize the store model to the data model. - - This works when the data model has supplied a deserialize method, specific to a data source. - This uses a method called 'deserialize()' on the data model or part of the vector store record definition. - - The developer is responsible for correctly deserializing for the specific data source. - """ - if self.data_model_definition.deserialize: - if isinstance(record, Sequence): - return self.data_model_definition.deserialize(record, **kwargs) - return self.data_model_definition.deserialize([record], **kwargs) - try: - if hasattr(self.data_model_type, "deserialize"): - if isinstance(record, Sequence): - return [self.data_model_type.deserialize(rec, **kwargs) for rec in record] # type: ignore - return self.data_model_type.deserialize(record, **kwargs) # type: ignore - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error deserializing record: {exc}") from exc + return self.data_model_definition.serialize(record, **kwargs) + if isinstance(record, SerializeMethodProtocol): + return record.serialize(**kwargs) return None def _serialize_data_model_to_dict(self, record: TModel, **kwargs: Any) -> OneOrMany[dict[str, Any]]: @@ -483,44 +518,79 @@ def _serialize_data_model_to_dict(self, record: TModel, **kwargs: Any) -> OneOrM """ if self.data_model_definition.to_dict: return self.data_model_definition.to_dict(record, **kwargs) - if hasattr(record, "model_dump"): - try: - ret = record.model_dump() # type: ignore - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return ret - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - assert field.name is not None # nosec - ret[field.name] = field.serialize_function(ret[field.name]) - return ret - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing record: {exc}") from exc - if hasattr(record, "to_dict"): - try: - ret = record.to_dict() # type: ignore - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return ret - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - assert field.name is not None # nosec - ret[field.name] = field.serialize_function(ret[field.name]) - return ret - except Exception as exc: - raise 
VectorStoreModelSerializationException(f"Error serializing record: {exc}") from exc + if isinstance(record, BaseModel): + return self._serialize_vectors(record.model_dump()) + if isinstance(record, ToDictMethodProtocol): + return self._serialize_vectors(record.to_dict()) store_model = {} for field_name in self.data_model_definition.field_names: - try: - value = record[field_name] if isinstance(record, Mapping) else getattr(record, field_name) - if func := getattr(self.data_model_definition.fields[field_name], "serialize_function", None): - value = func(value) - store_model[field_name] = value - except (AttributeError, KeyError) as exc: - raise VectorStoreModelSerializationException( - f"Error serializing record, not able to get: {field_name}" - ) from exc + value = record[field_name] if isinstance(record, Mapping) else getattr(record, field_name) + if func := getattr(self.data_model_definition.fields[field_name], "serialize_function", None): + value = func(value) + store_model[field_name] = value return store_model + def _serialize_vectors(self, record: dict[str, Any]) -> dict[str, Any]: + for field in self.data_model_definition.vector_fields: + if field.serialize_function: + record[field.name or ""] = field.serialize_function(record[field.name or ""]) + return record + + # region Deserialization methods + + def deserialize(self, records: OneOrMany[Any | dict[str, Any]], **kwargs: Any) -> OneOrMany[TModel] | None: + """Deserialize the store model to the data model. + + This method follows the following steps: + 1. Check if the data model has a deserialize method. + Use that method to deserialize and return the result. + 2. Deserialize the store model to a dict, using the store specific method. + 3. Convert the dict to the data model, using the data model specific method. + + Raises: + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + try: + if not records: + return None + if deserialized := self._deserialize_store_model_to_data_model(records, **kwargs): + return deserialized + + if isinstance(records, Sequence): + dict_records = self._deserialize_store_models_to_dicts(records, **kwargs) + return ( + self._deserialize_dict_to_data_model(dict_records, **kwargs) + if self._container_mode + else [self._deserialize_dict_to_data_model(rec, **kwargs) for rec in dict_records] + ) + + dict_record = self._deserialize_store_models_to_dicts([records], **kwargs)[0] + # regardless of mode, only 1 object is returned. + return self._deserialize_dict_to_data_model(dict_record, **kwargs) + except VectorStoreModelDeserializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelDeserializationException(f"Error deserializing records: {exc}") from exc + + def _deserialize_store_model_to_data_model(self, record: OneOrMany[Any], **kwargs: Any) -> OneOrMany[TModel] | None: + """Deserialize the store model to the data model. + + This works when the data model has supplied a deserialize method, specific to a data source. + This uses a method called 'deserialize()' on the data model or part of the vector store record definition. + + The developer is responsible for correctly deserializing for the specific data source. 
+ """ + if self.data_model_definition.deserialize: + if isinstance(record, Sequence): + return self.data_model_definition.deserialize(record, **kwargs) + return self.data_model_definition.deserialize([record], **kwargs) + if func := getattr(self.data_model_type, "deserialize", None): + if isinstance(record, Sequence): + return [func(rec, **kwargs) for rec in record] + return func(record, **kwargs) + return None + def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **kwargs: Any) -> TModel: """This function is used if no deserialize method is found on the data model. @@ -541,45 +611,32 @@ def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **k "Cannot deserialize multiple records to a single record unless you are using a container." ) record = record[0] - if hasattr(self.data_model_type, "model_validate"): - try: - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return self.data_model_type.model_validate(record) # type: ignore - if include_vectors: - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - record[field.name] = field.serialize_function(record[field.name]) # type: ignore - return self.data_model_type.model_validate(record) # type: ignore - except Exception as exc: - raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc - if hasattr(self.data_model_type, "from_dict"): - try: - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return self.data_model_type.from_dict(record) # type: ignore - if include_vectors: - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - record[field.name] = field.serialize_function(record[field.name]) # type: ignore - return self.data_model_type.from_dict(record) # type: ignore - except Exception as exc: - raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc + if issubclass(self.data_model_type, BaseModel): + if include_vectors: + record = self._deserialize_vector(record) + return self.data_model_type.model_validate(record) # type: ignore + if func := getattr(self.data_model_type, "from_dict", None): + if include_vectors: + record = self._deserialize_vector(record) + return func(record) data_model_dict: dict[str, Any] = {} - for field_name in self.data_model_definition.fields: # type: ignore + for field_name in self.data_model_definition.fields: if not include_vectors and field_name in self.data_model_definition.vector_field_names: continue - try: - value = record[field_name] - if func := getattr(self.data_model_definition.fields[field_name], "deserialize_function", None): - value = func(value) - data_model_dict[field_name] = value - except KeyError as exc: - raise VectorStoreModelDeserializationException( - f"Error deserializing record, not able to get: {field_name}" - ) from exc + value = record[field_name] + if func := getattr(self.data_model_definition.fields[field_name], "deserialize_function", None): + value = func(value) + data_model_dict[field_name] = value if self.data_model_type is dict: return data_model_dict # type: ignore return self.data_model_type(**data_model_dict) + def _deserialize_vector(self, record: dict[str, Any]) -> dict[str, Any]: + for field in self.data_model_definition.vector_fields: + if field.deserialize_function: + record[field.name or ""] = field.deserialize_function(record[field.name or ""]) + return record + # 
region Internal Functions def __del__(self): diff --git a/python/semantic_kernel/exceptions/__init__.py b/python/semantic_kernel/exceptions/__init__.py index 9ed131971525..6667c4570b32 100644 --- a/python/semantic_kernel/exceptions/__init__.py +++ b/python/semantic_kernel/exceptions/__init__.py @@ -11,3 +11,4 @@ from semantic_kernel.exceptions.search_exceptions import * # noqa: F403 from semantic_kernel.exceptions.service_exceptions import * # noqa: F403 from semantic_kernel.exceptions.template_engine_exceptions import * # noqa: F403 +from semantic_kernel.exceptions.vector_store_exceptions import * # noqa: F403 diff --git a/python/semantic_kernel/exceptions/memory_connector_exceptions.py b/python/semantic_kernel/exceptions/memory_connector_exceptions.py index 16848f572025..d8b9bb736023 100644 --- a/python/semantic_kernel/exceptions/memory_connector_exceptions.py +++ b/python/semantic_kernel/exceptions/memory_connector_exceptions.py @@ -16,24 +16,6 @@ class MemoryConnectorConnectionException(MemoryConnectorException): pass -class VectorStoreModelException(MemoryConnectorException): - """Base class for all vector store model exceptions.""" - - pass - - -class VectorStoreModelSerializationException(VectorStoreModelException): - """An error occurred while serializing the vector store model.""" - - pass - - -class VectorStoreModelDeserializationException(VectorStoreModelException): - """An error occurred while deserializing the vector store model.""" - - pass - - class MemoryConnectorInitializationError(MemoryConnectorException): """An error occurred while initializing the memory connector.""" @@ -46,26 +28,9 @@ class MemoryConnectorResourceNotFound(MemoryConnectorException): pass -class VectorStoreModelValidationError(VectorStoreModelException): - """An error occurred while validating the vector store model.""" - - pass - - -class VectorStoreSearchError(MemoryConnectorException): - """An error occurred while searching the vector store model.""" - - pass - - __all__ = [ "MemoryConnectorConnectionException", "MemoryConnectorException", "MemoryConnectorInitializationError", "MemoryConnectorResourceNotFound", - "VectorStoreModelDeserializationException", - "VectorStoreModelException", - "VectorStoreModelSerializationException", - "VectorStoreModelValidationError", - "VectorStoreSearchError", ] diff --git a/python/semantic_kernel/exceptions/search_exceptions.py b/python/semantic_kernel/exceptions/search_exceptions.py index fdb0bae13284..456235e6b008 100644 --- a/python/semantic_kernel/exceptions/search_exceptions.py +++ b/python/semantic_kernel/exceptions/search_exceptions.py @@ -9,36 +9,12 @@ class SearchException(KernelException): pass -class VectorStoreMixinException(SearchException): - """Raised when a mixin is used without the VectorSearchBase Class.""" - - pass - - -class VectorStoreTextSearchValidationError(SearchException): - """An error occurred while validating the vector store text search model.""" - - pass - - class SearchResultEmptyError(SearchException): """Raised when there are no hits in the search results.""" pass -class VectorSearchExecutionException(SearchException): - """Raised when there is an error executing a VectorSearch function.""" - - pass - - -class VectorSearchOptionsException(SearchException): - """Raised when invalid options are given to a VectorSearch function.""" - - pass - - class TextSearchException(SearchException): """An error occurred while executing a text search function.""" @@ -56,8 +32,4 @@ class TextSearchOptionsException(SearchException): 
"SearchResultEmptyError", "TextSearchException", "TextSearchOptionsException", - "VectorSearchExecutionException", - "VectorSearchOptionsException", - "VectorStoreMixinException", - "VectorStoreTextSearchValidationError", ] diff --git a/python/semantic_kernel/exceptions/vector_store_exceptions.py b/python/semantic_kernel/exceptions/vector_store_exceptions.py new file mode 100644 index 000000000000..0f42939c6155 --- /dev/null +++ b/python/semantic_kernel/exceptions/vector_store_exceptions.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft. All rights reserved. + +from semantic_kernel.exceptions.kernel_exceptions import KernelException + + +class VectorStoreException(KernelException): + """Base class for all vector store exceptions.""" + + pass + + +class VectorStoreInitializationException(VectorStoreException): + """Class for all vector store initialization exceptions.""" + + pass + + +class VectorStoreModelException(VectorStoreException): + """Base class for all vector store model exceptions.""" + + pass + + +class VectorStoreModelSerializationException(VectorStoreModelException): + """An error occurred while serializing the vector store model.""" + + pass + + +class VectorStoreModelDeserializationException(VectorStoreModelException): + """An error occurred while deserializing the vector store model.""" + + pass + + +class VectorStoreModelValidationError(VectorStoreModelException): + """An error occurred while validating the vector store model.""" + + pass + + +class VectorStoreMixinException(VectorStoreException): + """Raised when a mixin is used without the VectorSearchBase Class.""" + + pass + + +class VectorStoreOperationException(VectorStoreException): + """An error occurred while performing an operation on the vector store record collection.""" + + pass + + +class VectorSearchExecutionException(VectorStoreOperationException): + """Raised when there is an error executing a VectorSearch function.""" + + pass + + +class VectorSearchOptionsException(VectorStoreOperationException): + """Raised when invalid options are given to a VectorSearch function.""" + + pass + + +__all__ = [ + "VectorSearchExecutionException", + "VectorSearchOptionsException", + "VectorStoreException", + "VectorStoreInitializationException", + "VectorStoreMixinException", + "VectorStoreModelDeserializationException", + "VectorStoreModelException", + "VectorStoreModelException", + "VectorStoreModelSerializationException", + "VectorStoreModelValidationError", + "VectorStoreOperationException", +] diff --git a/python/tests/conftest.py b/python/tests/conftest.py index e6a01549f020..aba80113c333 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -195,44 +195,6 @@ def prompt() -> str: return "test prompt" -# @fixture(autouse=True) -# def enable_debug_mode(): -# """Set `autouse=True` to enable easy debugging for tests. - -# How to debug: -# 1. Ensure [snoop](https://github.com/alexmojaki/snoop) is installed -# (`pip install snoop`). -# 2. If you're doing print based debugging, use `pr` instead of `print`. -# That is, convert `print(some_var)` to `pr(some_var)`. -# 3. If you want a trace of a particular functions calls, just add `ss()` as the first -# line of the function. - -# Note: -# ---- -# It's completely fine to leave `autouse=True` in the fixture. It doesn't affect -# the tests unless you use `pr` or `ss` in any test. - -# Note: -# ---- -# When you use `ss` or `pr` in a test, pylance or mypy will complain. This is -# because they don't know that we're adding these functions to the builtins. 
The -# tests will run fine though. -# """ -# import builtins - -# try: -# import snoop -# except ImportError: -# warnings.warn( -# "Install snoop to enable trace debugging. `pip install snoop`", -# ImportWarning, -# ) -# return - -# builtins.ss = snoop.snoop(depth=4).__enter__ -# builtins.pr = snoop.pp - - @fixture def exclude_list(request): """Fixture that returns a list of environment variables to exclude.""" diff --git a/python/tests/integration/completions/chat_completion_test_base.py b/python/tests/integration/completions/chat_completion_test_base.py index a31882951c9b..d05157e607c5 100644 --- a/python/tests/integration/completions/chat_completion_test_base.py +++ b/python/tests/integration/completions/chat_completion_test_base.py @@ -118,11 +118,12 @@ def services(self) -> dict[str, tuple[ServiceType | None, type[PromptExecutionSe default_headers={"Test-User-X-ID": "test"}, ), ) + assert deployment_name azure_ai_inference_client = AzureAIInferenceChatCompletion( ai_model_id=deployment_name, client=ChatCompletionsClient( endpoint=f"{endpoint.strip('/')}/openai/deployments/{deployment_name}", - credential=DefaultAzureCredential(), + credential=DefaultAzureCredential(), # type: ignore credential_scopes=["https://cognitiveservices.azure.com/.default"], ), ) @@ -190,7 +191,7 @@ def setup(self, kernel: Kernel): async def get_chat_completion_response( self, kernel: Kernel, - service: ChatCompletionClientBase, + service: ServiceType, execution_settings: PromptExecutionSettings, chat_history: ChatHistory, stream: bool, @@ -204,6 +205,7 @@ async def get_chat_completion_response( input (str): Input string. stream (bool): Stream flag. """ + assert isinstance(service, ChatCompletionClientBase) if not stream: return await service.get_chat_message_content( chat_history, diff --git a/python/tests/integration/completions/test_chat_completion_with_function_calling.py b/python/tests/integration/completions/test_chat_completion_with_function_calling.py index b6d83c6d0735..6760eb5a1b02 100644 --- a/python/tests/integration/completions/test_chat_completion_with_function_calling.py +++ b/python/tests/integration/completions/test_chat_completion_with_function_calling.py @@ -971,6 +971,7 @@ async def test_streaming_completion( def evaluate(self, test_target: Any, **kwargs): inputs = kwargs.get("inputs") test_type = kwargs.get("test_type") + assert isinstance(inputs, list) if test_type == FunctionChoiceTestTypes.AUTO: self._evaluate_auto_function_choice(test_target, inputs) @@ -1060,7 +1061,7 @@ async def _test_helper( self.setup(kernel) - cmc = await retry( + cmc: ChatMessageContent | None = await retry( partial( self.get_chat_completion_response, kernel=kernel, @@ -1070,11 +1071,13 @@ async def _test_helper( stream=stream, ), retries=5, + name="function_calling", ) # We need to add the latest message to the history because the connector is # not responsible for updating the history, unless it is related to auto function # calling, when the history is updated after the function calls are invoked. 
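The retry helper from tests.utils now also takes a name for log output, and the callers in these tests treat its result as optional, which is why the history update below is guarded with `if cmc:`. A sketch of a helper consistent with that calling convention; the real implementation lives in python/tests/utils.py and may differ:

    import asyncio
    import logging
    from collections.abc import Awaitable, Callable
    from typing import TypeVar

    T = TypeVar("T")

    async def retry(func: Callable[[], Awaitable[T]], retries: int = 3, name: str = "") -> T | None:
        """Await func up to `retries` times; return None if every attempt fails."""
        for attempt in range(retries):
            try:
                return await func()
            except Exception:
                logging.warning("attempt %d/%d of %s failed", attempt + 1, retries, name)
                await asyncio.sleep(2**attempt)  # simple exponential backoff
        return None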
- history.add_message(cmc) + if cmc: + history.add_message(cmc) self.evaluate(history, inputs=inputs, test_type=test_type) diff --git a/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py b/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py index 3f8eaa10ea82..f45f3367268c 100644 --- a/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py +++ b/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py @@ -263,7 +263,7 @@ async def test_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -282,7 +282,7 @@ async def test_streaming_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -297,6 +297,7 @@ async def test_streaming_completion( @override def evaluate(self, test_target: Any, **kwargs): inputs = kwargs.get("inputs") + assert isinstance(inputs, list) assert len(test_target) == len(inputs) * 2 for i in range(len(inputs)): message = test_target[i * 2 + 1] @@ -323,7 +324,7 @@ async def _test_helper( for message in inputs: history.add_message(message) - cmc = await retry( + cmc: ChatMessageContent | None = await retry( partial( self.get_chat_completion_response, kernel=kernel, @@ -333,7 +334,9 @@ async def _test_helper( stream=stream, ), retries=5, + name="image_input", ) - history.add_message(cmc) + if cmc: + history.add_message(cmc) self.evaluate(history.messages, inputs=inputs) diff --git a/python/tests/integration/completions/test_chat_completions.py b/python/tests/integration/completions/test_chat_completions.py index 810be08fd5e2..e3a77542f0f6 100644 --- a/python/tests/integration/completions/test_chat_completions.py +++ b/python/tests/integration/completions/test_chat_completions.py @@ -6,6 +6,11 @@ import pytest +if sys.version_info >= (3, 12): + from typing import override # pragma: no cover +else: + from typing_extensions import override # pragma: no cover + from semantic_kernel import Kernel from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents import ChatMessageContent, TextContent @@ -24,11 +29,6 @@ from tests.integration.completions.completion_test_base import ServiceType from tests.utils import retry -if sys.version_info >= (3, 12): - from typing import override # pragma: no cover -else: - from typing_extensions import override # pragma: no cover - class Step(KernelBaseModel): explanation: str @@ -265,7 +265,7 @@ async def test_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -284,7 +284,7 @@ async def test_streaming_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | 
list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -299,6 +299,7 @@ async def test_streaming_completion( @override def evaluate(self, test_target: Any, **kwargs): inputs = kwargs.get("inputs") + assert isinstance(inputs, list) assert len(test_target) == len(inputs) * 2 for i in range(len(inputs)): message = test_target[i * 2 + 1] @@ -325,7 +326,7 @@ async def _test_helper( for message in inputs: history.add_message(message) - cmc = await retry( + cmc: ChatMessageContent | None = await retry( partial( self.get_chat_completion_response, kernel=kernel, @@ -335,7 +336,9 @@ async def _test_helper( stream=stream, ), retries=5, + name="get_chat_completion_response", ) - history.add_message(cmc) + if cmc: + history.add_message(cmc) self.evaluate(history.messages, inputs=inputs) diff --git a/python/tests/integration/completions/test_conversation_summary_plugin.py b/python/tests/integration/completions/test_conversation_summary_plugin.py index a89749b2993f..8e466d6d4731 100644 --- a/python/tests/integration/completions/test_conversation_summary_plugin.py +++ b/python/tests/integration/completions/test_conversation_summary_plugin.py @@ -55,7 +55,7 @@ async def test_azure_summarize_conversation_using_plugin(kernel): ) prompt_template_config = PromptTemplateConfig( description="Given a section of a conversation transcript, summarize the part of the conversation.", - execution_settings=execution_settings, + execution_settings={service_id: execution_settings}, ) kernel.add_service(sk_oai.OpenAIChatCompletion(service_id=service_id)) diff --git a/python/tests/integration/completions/test_text_completion.py b/python/tests/integration/completions/test_text_completion.py index 9e9e93084ed0..3e9b34ef76aa 100644 --- a/python/tests/integration/completions/test_text_completion.py +++ b/python/tests/integration/completions/test_text_completion.py @@ -2,7 +2,7 @@ import platform import sys -from functools import partial, reduce +from functools import partial from typing import Any if sys.version_info >= (3, 12): @@ -27,8 +27,7 @@ ) from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase -from semantic_kernel.contents.chat_message_content import ChatMessageContent -from semantic_kernel.contents.text_content import TextContent +from semantic_kernel.contents import StreamingTextContent, TextContent from semantic_kernel.utils.authentication.entra_id_authentication import get_entra_auth_token from tests.integration.completions.completion_test_base import CompletionTestBase, ServiceType from tests.utils import is_service_setup_for_testing, is_test_running_on_supported_platforms, retry @@ -275,7 +274,7 @@ def services(self) -> dict[str, tuple[ServiceType | None, type[PromptExecutionSe async def get_text_completion_response( self, - service: TextCompletionClientBase, + service: ServiceType, execution_settings: PromptExecutionSettings, prompt: str, stream: bool, @@ -289,21 +288,20 @@ async def get_text_completion_response( prompt (str): Input string. stream (bool): Stream flag. 
""" + assert isinstance(service, TextCompletionClientBase) if stream: response = service.get_streaming_text_content( prompt=prompt, settings=execution_settings, ) - parts = [part async for part in response] + parts: list[StreamingTextContent] = [part async for part in response if part is not None] if parts: - response = reduce(lambda p, r: p + r, parts) - else: - raise AssertionError("No response") - else: - response = await service.get_text_content( - prompt=prompt, - settings=execution_settings, - ) + return sum(parts[1:], parts[0]) + raise AssertionError("No response") + return await service.get_text_content( + prompt=prompt, + settings=execution_settings, + ) return response @@ -314,7 +312,7 @@ async def test_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[str], kwargs: dict[str, Any], ) -> None: await self._test_helper(service_id, services, execution_settings_kwargs, inputs, False) @@ -326,7 +324,7 @@ async def test_streaming_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[str], kwargs: dict[str, Any], ): if "streaming" in kwargs and not kwargs["streaming"]: @@ -364,5 +362,6 @@ async def _test_helper( stream=stream, ), retries=5, + name="text completions", ) self.evaluate(response) diff --git a/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py b/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py index 6bc6102215b5..3d87f26d0ccd 100644 --- a/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py +++ b/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py @@ -14,11 +14,11 @@ data_model_definition_to_azure_ai_search_index, get_search_index_client, ) -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + ServiceInitializationError, + VectorStoreInitializationException, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError from semantic_kernel.utils.list_handler import desync_list BASE_PATH_SEARCH_CLIENT = "azure.search.documents.aio.SearchClient" @@ -104,7 +104,7 @@ def test_init_with_type(azure_ai_search_unit_test_env, data_model_type): @mark.parametrize("exclude_list", [["AZURE_AI_SEARCH_ENDPOINT"]], indirect=True) def test_init_endpoint_fail(azure_ai_search_unit_test_env, data_model_definition): - with raises(MemoryConnectorInitializationError): + with raises(VectorStoreInitializationException): AzureAISearchCollection( data_model_type=dict, data_model_definition=data_model_definition, env_file_path="test.env" ) @@ -112,7 +112,7 @@ def test_init_endpoint_fail(azure_ai_search_unit_test_env, data_model_definition @mark.parametrize("exclude_list", [["AZURE_AI_SEARCH_INDEX_NAME"]], indirect=True) def test_init_index_fail(azure_ai_search_unit_test_env, data_model_definition): - with raises(MemoryConnectorInitializationError): + with raises(VectorStoreInitializationException): AzureAISearchCollection( data_model_type=dict, data_model_definition=data_model_definition, env_file_path="test.env" ) @@ -161,7 +161,7 @@ def 
test_init_with_search_index_client(azure_ai_search_unit_test_env, data_model def test_init_with_search_index_client_fail(azure_ai_search_unit_test_env, data_model_definition): search_index_client = MagicMock(spec=SearchIndexClientWrapper) - with raises(MemoryConnectorInitializationError, match="Collection name is required."): + with raises(VectorStoreInitializationException, match="Collection name is required."): AzureAISearchCollection( data_model_type=dict, data_model_definition=data_model_definition, @@ -175,7 +175,7 @@ def test_init_with_clients_fail(azure_ai_search_unit_test_env, data_model_defini search_client._index_name = "test-index-name" with raises( - MemoryConnectorInitializationError, match="Search client and search index client have different index names." + VectorStoreInitializationException, match="Search client and search index client have different index names." ): AzureAISearchCollection( data_model_type=dict, @@ -233,7 +233,7 @@ async def test_create_index_from_definition(collection, mock_create_collection): async def test_create_index_from_index_fail(collection, mock_create_collection): index = Mock() - with raises(MemoryConnectorException): + with raises(VectorStoreOperationException): await collection.create_collection(index=index) @@ -247,7 +247,7 @@ def test_data_model_definition_to_azure_ai_search_index(data_model_definition): @mark.parametrize("exclude_list", [["AZURE_AI_SEARCH_ENDPOINT"]], indirect=True) async def test_vector_store_fail(azure_ai_search_unit_test_env): - with raises(MemoryConnectorInitializationError): + with raises(VectorStoreInitializationException): AzureAISearchStore(env_file_path="test.env") diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py index 6c5dc4d1c2b4..e307e9bab8dc 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py @@ -15,11 +15,10 @@ create_default_indexing_policy, create_default_vector_embedding_policy, ) -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, - MemoryConnectorResourceNotFound, +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, ) +from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelException, VectorStoreOperationException def test_azure_cosmos_db_no_sql_collection_init( @@ -74,7 +73,7 @@ def test_azure_cosmos_db_no_sql_collection_init_no_url( collection_name: str, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLCollection object with missing URL.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): AzureCosmosDBNoSQLCollection( data_model_type=data_model_type, collection_name=collection_name, @@ -90,7 +89,7 @@ def test_azure_cosmos_db_no_sql_collection_init_no_database_name( ) -> None: """Test the initialization of an AzureCosmosDBNoSQLCollection object with missing database name.""" with pytest.raises( - MemoryConnectorInitializationError, match="The name of the Azure Cosmos DB NoSQL database is missing." + VectorStoreInitializationException, match="The name of the Azure Cosmos DB NoSQL database is missing." 
): AzureCosmosDBNoSQLCollection( data_model_type=data_model_type, @@ -105,7 +104,7 @@ def test_azure_cosmos_db_no_sql_collection_invalid_settings( collection_name: str, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLCollection object with invalid settings.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): AzureCosmosDBNoSQLCollection( data_model_type=data_model_type, collection_name=collection_name, @@ -204,7 +203,7 @@ async def test_azure_cosmos_db_no_sql_collection_create_database_raise_if_databa assert vector_collection.create_database is False - with pytest.raises(MemoryConnectorResourceNotFound): + with pytest.raises(VectorStoreOperationException): await vector_collection._get_database_proxy() @@ -325,7 +324,7 @@ async def test_azure_cosmos_db_no_sql_collection_create_collection_unsupported_v mock_database_proxy.create_container_if_not_exists = AsyncMock(return_value=None) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreModelException): await vector_collection.create_collection() @@ -367,7 +366,7 @@ async def test_azure_cosmos_db_no_sql_collection_delete_collection_fail( vector_collection._get_database_proxy = AsyncMock(return_value=mock_database_proxy) mock_database_proxy.delete_container = AsyncMock(side_effect=CosmosHttpResponseError) - with pytest.raises(MemoryConnectorException, match="Container could not be deleted."): + with pytest.raises(VectorStoreOperationException, match="Container could not be deleted."): await vector_collection.delete_collection() diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py index 7f7f53d155b8..0514f9d8dc91 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py @@ -10,7 +10,7 @@ ) from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_store import AzureCosmosDBNoSQLStore from semantic_kernel.connectors.memory.azure_cosmos_db.utils import CosmosClientWrapper -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException def test_azure_cosmos_db_no_sql_store_init( @@ -43,7 +43,7 @@ def test_azure_cosmos_db_no_sql_store_init_no_url( azure_cosmos_db_no_sql_unit_test_env, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLStore object with missing URL.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): AzureCosmosDBNoSQLStore(env_file_path="fake_path") @@ -53,7 +53,7 @@ def test_azure_cosmos_db_no_sql_store_init_no_database_name( ) -> None: """Test the initialization of an AzureCosmosDBNoSQLStore object with missing database name.""" with pytest.raises( - MemoryConnectorInitializationError, match="The name of the Azure Cosmos DB NoSQL database is missing." + VectorStoreInitializationException, match="The name of the Azure Cosmos DB NoSQL database is missing." 
): AzureCosmosDBNoSQLStore(env_file_path="fake_path") @@ -62,7 +62,7 @@ def test_azure_cosmos_db_no_sql_store_invalid_settings( clear_azure_cosmos_db_no_sql_env, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLStore object with invalid settings.""" - with pytest.raises(MemoryConnectorInitializationError, match="Failed to validate Azure Cosmos DB NoSQL settings."): + with pytest.raises(VectorStoreInitializationException, match="Failed to validate Azure Cosmos DB NoSQL settings."): AzureCosmosDBNoSQLStore(url="invalid_url") diff --git a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py index c92571daf238..a9f4c69c97ae 100644 --- a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py +++ b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py @@ -11,12 +11,12 @@ from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorStoreInitializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient" @@ -144,10 +144,10 @@ def test_vector_store_in_memory(qdrant_unit_test_env): def test_vector_store_fail(): - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant settings."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant settings."): QdrantStore(location="localhost", url="localhost", env_file_path="test.env") - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant client."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant client."): QdrantStore(location="localhost", url="http://localhost", env_file_path="test.env") @@ -180,7 +180,7 @@ async def test_collection_init(data_model_definition, qdrant_unit_test_env): def test_collection_init_fail(data_model_definition): - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant settings."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant settings."): QdrantCollection( data_model_type=dict, collection_name="test", @@ -188,7 +188,7 @@ def test_collection_init_fail(data_model_definition): url="localhost", env_file_path="test.env", ) - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant client."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant client."): QdrantCollection( data_model_type=dict, collection_name="test", @@ -274,7 +274,7 @@ async def test_create_index_with_named_vectors(collection_to_use, results, mock_ async def test_create_index_fail(collection_to_use, request): collection = request.getfixturevalue(collection_to_use) collection.data_model_definition.fields["vector"].dimensions = None - with raises(MemoryConnectorException, match="Vector field must have dimensions."): + with raises(VectorStoreOperationException, match="Vector field must have dimensions."): await collection.create_collection() diff --git 
a/python/tests/unit/connectors/memory/redis/test_redis_store.py b/python/tests/unit/connectors/memory/redis/test_redis_store.py index f62a8a34e468..0fb3b06efae4 100644 --- a/python/tests/unit/connectors/memory/redis/test_redis_store.py +++ b/python/tests/unit/connectors/memory/redis/test_redis_store.py @@ -9,9 +9,9 @@ from semantic_kernel.connectors.memory.redis.const import RedisCollectionTypes from semantic_kernel.connectors.memory.redis.redis_collection import RedisHashsetCollection, RedisJsonCollection from semantic_kernel.connectors.memory.redis.redis_store import RedisStore -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, + VectorStoreOperationException, ) BASE_PATH = "redis.asyncio.client.Redis" @@ -153,7 +153,7 @@ def test_vector_store_with_client(redis_unit_test_env): @mark.parametrize("exclude_list", [["REDIS_CONNECTION_STRING"]], indirect=True) def test_vector_store_fail(redis_unit_test_env): - with raises(MemoryConnectorInitializationError, match="Failed to create Redis settings."): + with raises(VectorStoreInitializationException, match="Failed to create Redis settings."): RedisStore(env_file_path="test.env") @@ -223,14 +223,14 @@ def test_init_with_type(redis_unit_test_env, data_model_type, type_): @mark.parametrize("exclude_list", [["REDIS_CONNECTION_STRING"]], indirect=True) def test_collection_fail(redis_unit_test_env, data_model_definition): - with raises(MemoryConnectorInitializationError, match="Failed to create Redis settings."): + with raises(VectorStoreInitializationException, match="Failed to create Redis settings."): RedisHashsetCollection( data_model_type=dict, collection_name="test", data_model_definition=data_model_definition, env_file_path="test.env", ) - with raises(MemoryConnectorInitializationError, match="Failed to create Redis settings."): + with raises(VectorStoreInitializationException, match="Failed to create Redis settings."): RedisJsonCollection( data_model_type=dict, collection_name="test", @@ -326,5 +326,5 @@ async def test_create_index_manual(collection_hash, mock_create_collection): async def test_create_index_fail(collection_hash, mock_create_collection): - with raises(MemoryConnectorException, match="Invalid index type supplied."): + with raises(VectorStoreOperationException, match="Invalid index type supplied."): await collection_hash.create_collection(index_definition="index_definition", fields="fields") diff --git a/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py b/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py index 5ba167d5c6a8..4db7fc3a257e 100644 --- a/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py +++ b/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py @@ -9,11 +9,11 @@ from weaviate.collections.classes.data import DataObject from semantic_kernel.connectors.memory.weaviate.weaviate_collection import WeaviateCollection -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + ServiceInvalidExecutionSettingsError, + VectorStoreInitializationException, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError @patch( @@ -164,7 +164,7 @@ def 
test_weaviate_collection_init_fail_to_create_client( """Test the initialization of a WeaviateCollection object raises an error when failing to create a client.""" collection_name = "TestCollection" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): WeaviateCollection( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -262,7 +262,7 @@ async def test_weaviate_collection_create_collection_fail( env_file_path="fake_env_file_path.env", ) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreOperationException): await collection.create_collection() @@ -312,7 +312,7 @@ async def test_weaviate_collection_delete_collection_fail( env_file_path="fake_env_file_path.env", ) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreOperationException): await collection.delete_collection() @@ -362,7 +362,7 @@ async def test_weaviate_collection_collection_exist_fail( env_file_path="fake_env_file_path.env", ) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreOperationException): await collection.does_collection_exist() diff --git a/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py b/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py index d47f95fa1462..d50f024bf2ce 100644 --- a/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py +++ b/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py @@ -7,8 +7,7 @@ from weaviate import WeaviateAsyncClient from semantic_kernel.connectors.memory.weaviate.weaviate_store import WeaviateStore -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError -from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError +from semantic_kernel.exceptions import ServiceInvalidExecutionSettingsError, VectorStoreInitializationException @patch.object( @@ -119,7 +118,7 @@ def test_weaviate_store_init_fail_to_create_client( data_model_definition, ) -> None: """Test the initialization of a WeaviateStore object raises an error when failing to create a client.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): WeaviateStore( local_host="localhost", env_file_path="fake_env_file_path.env", diff --git a/python/tests/unit/data/conftest.py b/python/tests/unit/data/conftest.py index 61110ebfcab4..8d926ad1d676 100644 --- a/python/tests/unit/data/conftest.py +++ b/python/tests/unit/data/conftest.py @@ -6,7 +6,8 @@ from typing import Annotated, Any import numpy as np -from pydantic import BaseModel, Field +from pandas import DataFrame +from pydantic import BaseModel, ConfigDict, Field from pytest import fixture from semantic_kernel.data import ( @@ -25,7 +26,7 @@ @fixture -def DictVectorStoreRecordCollection(): +def DictVectorStoreRecordCollection() -> type[VectorSearchBase]: class DictVectorStoreRecordCollection( VectorSearchBase[str, Any], VectorizedSearchMixin[Any], @@ -334,6 +335,24 @@ class DataModelClass(BaseModel): return DataModelClass +@fixture +def data_model_type_pydantic_array(): + @vectorstoremodel + class DataModelClass(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + content: Annotated[str, VectorStoreRecordDataField()] + vector: Annotated[ + np.ndarray, + VectorStoreRecordVectorField( + serialize_function=np.ndarray.tolist, + deserialize_function=np.array, + ), + ] + id: Annotated[str, 
VectorStoreRecordKeyField()] + + return DataModelClass + + @fixture def data_model_type_dataclass(): @vectorstoremodel @@ -344,3 +363,53 @@ class DataModelClass: id: Annotated[str, VectorStoreRecordKeyField()] return DataModelClass + + +@fixture(scope="function") +def vector_store_record_collection( + DictVectorStoreRecordCollection, + data_model_definition, + data_model_serialize_definition, + data_model_to_from_dict_definition, + data_model_container_definition, + data_model_container_serialize_definition, + data_model_pandas_definition, + data_model_type_vanilla, + data_model_type_vanilla_serialize, + data_model_type_vanilla_to_from_dict, + data_model_type_pydantic, + data_model_type_dataclass, + data_model_type_vector_array, + request, +) -> VectorSearchBase: + item = request.param if request and hasattr(request, "param") else "definition_basic" + defs = { + "definition_basic": data_model_definition, + "definition_with_serialize": data_model_serialize_definition, + "definition_with_to_from": data_model_to_from_dict_definition, + "definition_container": data_model_container_definition, + "definition_container_serialize": data_model_container_serialize_definition, + "definition_pandas": data_model_pandas_definition, + "type_vanilla": data_model_type_vanilla, + "type_vanilla_with_serialize": data_model_type_vanilla_serialize, + "type_vanilla_with_to_from_dict": data_model_type_vanilla_to_from_dict, + "type_pydantic": data_model_type_pydantic, + "type_dataclass": data_model_type_dataclass, + "type_vector_array": data_model_type_vector_array, + } + if item.endswith("pandas"): + return DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=DataFrame, + data_model_definition=defs[item], + ) + if item.startswith("definition_"): + return DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=defs[item], + ) + return DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=defs[item], + ) diff --git a/python/tests/unit/data/test_text_search.py b/python/tests/unit/data/test_text_search.py index 5b03b67e52e9..74a10909317b 100644 --- a/python/tests/unit/data/test_text_search.py +++ b/python/tests/unit/data/test_text_search.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
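The new shared vector_store_record_collection fixture above selects its data model through request.param, so individual tests pick a variant via pytest's indirect parametrization; the record-collection tests later in this patch use exactly this pattern. A minimal sketch, assuming the collection_name constructor kwarg is exposed as an attribute:

    import pytest

    @pytest.mark.parametrize(
        "vector_store_record_collection",
        ["definition_basic", "definition_pandas", "type_pydantic"],
        indirect=True,  # pytest hands each id to the fixture as request.param
    )
    def test_collection_is_built(vector_store_record_collection):
        # the fixture always constructs the collection with this name
        assert vector_store_record_collection.collection_name == "test"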
+from collections.abc import AsyncGenerator from typing import Any from unittest.mock import patch @@ -21,6 +22,7 @@ ) from semantic_kernel.exceptions import TextSearchException from semantic_kernel.functions import KernelArguments, KernelParameterMetadata +from semantic_kernel.utils.list_handler import desync_list def test_text_search(): @@ -33,7 +35,7 @@ class TestSearch(TextSearch): async def search(self, **kwargs) -> KernelSearchResults[Any]: """Test search function.""" - async def generator() -> str: + async def generator() -> AsyncGenerator[str, None]: yield "test" return KernelSearchResults(results=generator(), metadata=kwargs) @@ -43,7 +45,7 @@ async def get_text_search_results( ) -> KernelSearchResults[TextSearchResult]: """Test get text search result function.""" - async def generator() -> TextSearchResult: + async def generator() -> AsyncGenerator[TextSearchResult, None]: yield TextSearchResult(value="test") return KernelSearchResults(results=generator(), metadata=kwargs) @@ -53,7 +55,7 @@ async def get_search_results( ) -> KernelSearchResults[Any]: """Test get search result function.""" - async def generator() -> str: + async def generator() -> AsyncGenerator[str, None]: yield "test" return KernelSearchResults(results=generator(), metadata=kwargs) @@ -190,12 +192,18 @@ async def test_create_kernel_function_inner_update_options(kernel: Kernel): called = False args = {} - def update_options(**kwargs: Any) -> tuple[str, SearchOptions]: - kwargs["options"].filter.equal_to("address/city", kwargs.get("city")) + def update_options( + query: str, + options: "SearchOptions", + parameters: list["KernelParameterMetadata"] | None = None, + **kwargs: Any, + ) -> tuple[str, SearchOptions]: + options.filter.equal_to("address/city", kwargs.get("city", "")) nonlocal called, args called = True - args = kwargs - return kwargs["query"], kwargs["options"] + args = {"query": query, "options": options, "parameters": parameters} + args.update(kwargs) + return query, options kernel_function = test_search._create_kernel_function( search_function="search", @@ -225,14 +233,29 @@ def update_options(**kwargs: Any) -> tuple[str, SearchOptions]: assert "parameters" in args -def test_default_map_to_string(): +async def test_default_map_to_string(): test_search = TestSearch() - assert test_search._default_map_to_string("test") == "test" + assert (await test_search._map_results(results=KernelSearchResults(results=desync_list(["test"])))) == ["test"] class TestClass(BaseModel): test: str - assert test_search._default_map_to_string(TestClass(test="test")) == '{"test":"test"}' + assert ( + await test_search._map_results(results=KernelSearchResults(results=desync_list([TestClass(test="test")]))) + ) == ['{"test":"test"}'] + + +async def test_custom_map_to_string(): + test_search = TestSearch() + + class TestClass(BaseModel): + test: str + + assert ( + await test_search._map_results( + results=KernelSearchResults(results=desync_list([TestClass(test="test")])), string_mapper=lambda x: x.test + ) + ) == ["test"] def test_create_options(): @@ -253,6 +276,27 @@ def test_create_options_none(): assert new_options.top == 1 +def test_create_options_vector_to_text(): + options = VectorSearchOptions(top=2, skip=1, include_vectors=True) + options_class = TextSearchOptions + new_options = create_options(options_class, options, top=1) + assert new_options is not None + assert isinstance(new_options, options_class) + assert new_options.top == 1 + assert getattr(new_options, "include_vectors", None) is None + + +def 
test_create_options_from_dict(): + options = {"skip": 1} + options_class = TextSearchOptions + new_options = create_options(options_class, options, top=1) # type: ignore + assert new_options is not None + assert isinstance(new_options, options_class) + assert new_options.top == 1 + # if a non SearchOptions object is passed in, it should be ignored + assert new_options.skip == 0 + + def test_default_options_update_function(): options = SearchOptions() params = [ @@ -267,3 +311,36 @@ def test_default_options_update_function(): assert options.filter.filters[0].value == "test" assert options.filter.filters[1].field_name == "test2" assert options.filter.filters[1].value == "test2" + + +def test_public_create_functions_search(): + test_search = TestSearch() + function = test_search.create_search() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 + + +def test_public_create_functions_get_text_search_results(): + test_search = TestSearch() + function = test_search.create_get_text_search_results() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 + + +def test_public_create_functions_get_search_results(): + test_search = TestSearch() + function = test_search.create_get_search_results() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 diff --git a/python/tests/unit/data/test_vector_search_base.py b/python/tests/unit/data/test_vector_search_base.py new file mode 100644 index 000000000000..d35c04e45a5b --- /dev/null +++ b/python/tests/unit/data/test_vector_search_base.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft. All rights reserved. 
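test_create_options_from_dict above pins down a subtle rule: create_options only copies fields from a genuine SearchOptions instance, so a plain dict falls back to the target class defaults plus the explicit keyword overrides. Roughly (import paths here are assumptions based on this patch's file layout):

    from semantic_kernel.data import TextSearchOptions  # assumed re-export
    from semantic_kernel.data.text_search.utils import create_options  # assumed path

    options = create_options(TextSearchOptions, {"skip": 1}, top=1)  # a dict, not SearchOptions
    assert options.skip == 0  # the dict is ignored ...
    assert options.top == 1   # ... while explicit keyword arguments still apply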
+
+
+import pytest
+
+from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
+from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelDeserializationException
+
+
+async def test_search(vector_store_record_collection: VectorSearchBase):
+    record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
+    await vector_store_record_collection.upsert(record)
+    results = await vector_store_record_collection._inner_search(
+        options=VectorSearchOptions(), search_text="test_content"
+    )
+    records = [rec async for rec in results.results]
+    assert records[0].record == record
+
+
+@pytest.mark.parametrize("include_vectors", [True, False])
+async def test_get_vector_search_results(vector_store_record_collection: VectorSearchBase, include_vectors: bool):
+    options = VectorSearchOptions(include_vectors=include_vectors)
+    results = [{"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}]
+    async for result in vector_store_record_collection._get_vector_search_results_from_results(
+        results=results, options=options
+    ):
+        assert result.record == (results[0] if include_vectors else {"id": "test_id", "content": "test_content"})
+        break
+
+
+async def test_get_vector_search_results_fail(vector_store_record_collection: VectorSearchBase):
+    # vector_store_record_collection.data_model_type.serialize = MagicMock(side_effect=Exception)
+    options = VectorSearchOptions(include_vectors=True)
+    results = [{"id": "test_id", "content": "test_content"}]
+    with pytest.raises(VectorStoreModelDeserializationException):
+        async for result in vector_store_record_collection._get_vector_search_results_from_results(
+            results=results, options=options
+        ):
+            assert result.record == results[0]
+            break
diff --git a/python/tests/unit/data/test_vector_search_mixins.py b/python/tests/unit/data/test_vector_search_mixins.py
new file mode 100644
index 000000000000..0590b02bd41c
--- /dev/null
+++ b/python/tests/unit/data/test_vector_search_mixins.py
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from pytest import raises
+
+from semantic_kernel.data import VectorizableTextSearchMixin, VectorizedSearchMixin, VectorTextSearchMixin
+from semantic_kernel.exceptions import VectorStoreMixinException
+
+
+class VectorTextSearchMixinTest(VectorTextSearchMixin):
+    """The mixin for text search, to be used in combination with VectorSearchBase."""
+
+    pass
+
+
+class VectorizableTextSearchMixinTest(VectorizableTextSearchMixin):
+    """The mixin for vectorizable text search, to be used in combination with VectorSearchBase."""
+
+    pass
+
+
+class VectorizedSearchMixinTest(VectorizedSearchMixin):
+    """The mixin for vectorized search, to be used in combination with VectorSearchBase."""
+
+    pass
+
+
+async def test_text_search():
+    test_instance = VectorTextSearchMixinTest()
+    assert test_instance is not None
+    with raises(VectorStoreMixinException):
+        await test_instance.text_search("test")
+
+
+async def test_vectorizable_text_search():
+    test_instance = VectorizableTextSearchMixinTest()
+    assert test_instance is not None
+    with raises(VectorStoreMixinException):
+        await test_instance.vectorizable_text_search("test")
+
+
+async def test_vectorized_text_search():
+    test_instance = VectorizedSearchMixinTest()
+    assert test_instance is not None
+    with raises(VectorStoreMixinException):
+        await test_instance.vectorized_search([1, 2, 3])
diff --git a/python/tests/unit/data/test_vector_store_model_decorator.py b/python/tests/unit/data/test_vector_store_model_decorator.py
index 0a49530ecd4e..4f9707fe7032 100644
--- a/python/tests/unit/data/test_vector_store_model_decorator.py
+++ b/python/tests/unit/data/test_vector_store_model_decorator.py
@@ -206,19 +206,25 @@ class DataModelClass:
         key: Annotated[str, VectorStoreRecordKeyField()]
         list1: Annotated[list[int], VectorStoreRecordDataField()]
         list2: Annotated[list[str], VectorStoreRecordDataField]
+        list3: Annotated[list[str] | None, VectorStoreRecordDataField]
         dict1: Annotated[dict[str, int], VectorStoreRecordDataField()]
         dict2: Annotated[dict[str, str], VectorStoreRecordDataField]
+        dict3: Annotated[dict[str, str] | None, VectorStoreRecordDataField]
 
     assert hasattr(DataModelClass, "__kernel_vectorstoremodel__")
     assert hasattr(DataModelClass, "__kernel_vectorstoremodel_definition__")
     data_model_definition: VectorStoreRecordDefinition = DataModelClass.__kernel_vectorstoremodel_definition__
-    assert len(data_model_definition.fields) == 5
+    assert len(data_model_definition.fields) == 7
     assert data_model_definition.fields["list1"].name == "list1"
     assert data_model_definition.fields["list1"].property_type == "list[int]"
     assert data_model_definition.fields["list2"].name == "list2"
     assert data_model_definition.fields["list2"].property_type == "list[str]"
+    assert data_model_definition.fields["list3"].name == "list3"
+    assert data_model_definition.fields["list3"].property_type == "list[str]"
     assert data_model_definition.fields["dict1"].name == "dict1"
     assert data_model_definition.fields["dict1"].property_type == "dict"
     assert data_model_definition.fields["dict2"].name == "dict2"
     assert data_model_definition.fields["dict2"].property_type == "dict"
+    assert data_model_definition.fields["dict3"].name == "dict3"
+    assert data_model_definition.fields["dict3"].property_type == "dict"
     assert data_model_definition.container_mode is False
diff --git a/python/tests/unit/data/test_vector_store_record_collection.py b/python/tests/unit/data/test_vector_store_record_collection.py
index 3e94e6c5ea7d..dcd198a97dbe 100644
--- a/python/tests/unit/data/test_vector_store_record_collection.py
+++
b/python/tests/unit/data/test_vector_store_record_collection.py @@ -5,67 +5,21 @@ import numpy as np from pandas import DataFrame -from pytest import fixture, mark, raises +from pytest import mark, raises -from semantic_kernel.data import VectorStoreRecordCollection +from semantic_kernel.data.record_definition.vector_store_model_protocols import ( + SerializeMethodProtocol, + ToDictMethodProtocol, +) from semantic_kernel.exceptions import ( - MemoryConnectorException, VectorStoreModelDeserializationException, VectorStoreModelSerializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) -@fixture(scope="function") -def vector_store_record_collection( - DictVectorStoreRecordCollection, - data_model_definition, - data_model_serialize_definition, - data_model_to_from_dict_definition, - data_model_container_definition, - data_model_container_serialize_definition, - data_model_pandas_definition, - data_model_type_vanilla, - data_model_type_vanilla_serialize, - data_model_type_vanilla_to_from_dict, - data_model_type_pydantic, - data_model_type_dataclass, - data_model_type_vector_array, - request, -) -> VectorStoreRecordCollection: - item = request.param if request and hasattr(request, "param") else "definition_basic" - defs = { - "definition_basic": data_model_definition, - "definition_with_serialize": data_model_serialize_definition, - "definition_with_to_from": data_model_to_from_dict_definition, - "definition_container": data_model_container_definition, - "definition_container_serialize": data_model_container_serialize_definition, - "definition_pandas": data_model_pandas_definition, - "type_vanilla": data_model_type_vanilla, - "type_vanilla_with_serialize": data_model_type_vanilla_serialize, - "type_vanilla_with_to_from_dict": data_model_type_vanilla_to_from_dict, - "type_pydantic": data_model_type_pydantic, - "type_dataclass": data_model_type_dataclass, - "type_vector_array": data_model_type_vector_array, - } - if item.endswith("pandas"): - return DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=DataFrame, - data_model_definition=defs[item], - ) - if item.startswith("definition_"): - return DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=dict, - data_model_definition=defs[item], - ) - return DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=defs[item], - ) - - +# region init def test_init(DictVectorStoreRecordCollection, data_model_definition): vsrc = DictVectorStoreRecordCollection( collection_name="test", @@ -79,6 +33,59 @@ def test_init(DictVectorStoreRecordCollection, data_model_definition): assert vsrc._key_field_name == "id" +def test_data_model_validation(data_model_type_vanilla, DictVectorStoreRecordCollection): + DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["str"]) + DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["float"]) + DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla, + ) + + +def test_data_model_validation_key_fail(data_model_type_vanilla, DictVectorStoreRecordCollection): + DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["int"]) + with raises(VectorStoreModelValidationError, match="Key field must be one of"): + DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla, + ) + + +def test_data_model_validation_vector_fail(data_model_type_vanilla, 
DictVectorStoreRecordCollection): + DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["list[int]"]) + with raises(VectorStoreModelValidationError, match="Vector field "): + DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla, + ) + + +# region Collection +async def test_collection_operations(vector_store_record_collection): + await vector_store_record_collection.create_collection() + assert await vector_store_record_collection.does_collection_exist() + record = {"id": "id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + await vector_store_record_collection.upsert(record) + assert len(vector_store_record_collection.inner_storage) == 1 + await vector_store_record_collection.delete_collection() + assert vector_store_record_collection.inner_storage == {} + await vector_store_record_collection.create_collection_if_not_exists() + + +async def test_collection_create_if_not_exists(DictVectorStoreRecordCollection, data_model_definition): + DictVectorStoreRecordCollection.does_collection_exist = AsyncMock(return_value=False) + create_mock = AsyncMock() + DictVectorStoreRecordCollection.create_collection = create_mock + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + await vector_store_record_collection.create_collection_if_not_exists() + create_mock.assert_called_once() + + +# region CRUD @mark.parametrize( "vector_store_record_collection", [ @@ -111,11 +118,13 @@ async def test_crud_operations(vector_store_record_collection): if vector_store_record_collection.data_model_type is dict: assert vector_store_record_collection.inner_storage[id] == record else: + assert not isinstance(record, dict) assert vector_store_record_collection.inner_storage[id]["content"] == record.content record_2 = await vector_store_record_collection.get(id) if vector_store_record_collection.data_model_type is dict: assert record_2 == record else: + assert not isinstance(record, dict) if isinstance(record.vector, list): assert record_2 == record else: @@ -156,6 +165,7 @@ async def test_crud_batch_operations(vector_store_record_collection): if vector_store_record_collection.data_model_type is dict: assert vector_store_record_collection.inner_storage[ids[0]] == batch[0] else: + assert not isinstance(batch[0], dict) assert vector_store_record_collection.inner_storage[ids[0]]["content"] == batch[0].content records = await vector_store_record_collection.get_batch(ids) assert records == batch @@ -248,6 +258,27 @@ async def test_crud_batch_operations_pandas(vector_store_record_collection): assert len(vector_store_record_collection.inner_storage) == 0 +async def test_upsert_with_vectorizing(vector_store_record_collection): + record = {"id": "test_id", "content": "test_content"} + record2 = {"id": "test_id", "content": "test_content"} + + async def embedding_func(record, type, definition): + if isinstance(record, list): + for r in record: + r["vector"] = [1.0, 2.0, 3.0] + return record + record["vector"] = [1.0, 2.0, 3.0] + return record + + await vector_store_record_collection.upsert(record, embedding_generation_function=embedding_func) + assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0] + await vector_store_record_collection.delete("test_id") + assert len(vector_store_record_collection.inner_storage) == 0 + await vector_store_record_collection.upsert_batch([record2], 
embedding_generation_function=embedding_func) + assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0] + + +# region Fails async def test_upsert_fail(DictVectorStoreRecordCollection, data_model_definition): DictVectorStoreRecordCollection._inner_upsert = MagicMock(side_effect=Exception) vector_store_record_collection = DictVectorStoreRecordCollection( @@ -256,9 +287,9 @@ async def test_upsert_fail(DictVectorStoreRecordCollection, data_model_definitio data_model_definition=data_model_definition, ) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - with raises(MemoryConnectorException, match="Error upserting record:"): + with raises(VectorStoreOperationException, match="Error upserting record:"): await vector_store_record_collection.upsert(record) - with raises(MemoryConnectorException, match="Error upserting records:"): + with raises(VectorStoreOperationException, match="Error upserting records:"): await vector_store_record_collection.upsert_batch([record]) assert len(vector_store_record_collection.inner_storage) == 0 @@ -273,9 +304,25 @@ async def test_get_fail(DictVectorStoreRecordCollection, data_model_definition): record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} await vector_store_record_collection.upsert(record) assert len(vector_store_record_collection.inner_storage) == 1 - with raises(MemoryConnectorException, match="Error getting record:"): + with raises(VectorStoreOperationException, match="Error getting record:"): await vector_store_record_collection.get("test_id") - with raises(MemoryConnectorException, match="Error getting records:"): + with raises(VectorStoreOperationException, match="Error getting records:"): + await vector_store_record_collection.get_batch(["test_id"]) + + +async def test_deserialize_in_get_fail(DictVectorStoreRecordCollection, data_model_definition): + data_model_definition.deserialize = MagicMock(side_effect=Exception) + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + await vector_store_record_collection.upsert(record) + assert len(vector_store_record_collection.inner_storage) == 1 + with raises(VectorStoreModelDeserializationException, match="Error deserializing records:"): + await vector_store_record_collection.get("test_id") + with raises(VectorStoreModelDeserializationException, match="Error deserializing records:"): await vector_store_record_collection.get_batch(["test_id"]) @@ -292,7 +339,9 @@ async def test_get_fail_multiple(DictVectorStoreRecordCollection, data_model_def patch( "semantic_kernel.data.vector_storage.vector_store_record_collection.VectorStoreRecordCollection.deserialize" ) as deserialize_mock, - raises(MemoryConnectorException, match="Error deserializing record, multiple records returned:"), + raises( + VectorStoreModelDeserializationException, match="Error deserializing record, multiple records returned:" + ), ): deserialize_mock.return_value = [ {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}, @@ -301,61 +350,49 @@ async def test_get_fail_multiple(DictVectorStoreRecordCollection, data_model_def await vector_store_record_collection.get("test_id") -async def test_serialize_fail(DictVectorStoreRecordCollection, data_model_definition): - DictVectorStoreRecordCollection.serialize = MagicMock(side_effect=Exception) 
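The serialization tests in the hunks that follow build their records as MagicMock(spec=SerializeMethodProtocol) or MagicMock(spec=ToDictMethodProtocol); a data model opts into those protocols simply by providing the method. A rough sketch (the signatures are assumptions; the canonical definitions live in vector_store_model_protocols.py):

    class ProtocolRecord:
        """Satisfies SerializeMethodProtocol by exposing `serialize` (duck typing)."""

        def __init__(self, id: str, content: str, vector: list[float]) -> None:
            self.id = id
            self.content = content
            self.vector = vector

        def serialize(self, **kwargs) -> dict:
            # hand the store a plain dict in its native format
            return {"id": self.id, "content": self.content, "vector": self.vector}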
+async def test_delete_fail(DictVectorStoreRecordCollection, data_model_definition): + DictVectorStoreRecordCollection._inner_delete = MagicMock(side_effect=Exception) vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", data_model_type=dict, data_model_definition=data_model_definition, ) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - with raises(MemoryConnectorException, match="Error serializing record"): - await vector_store_record_collection.upsert(record) - with raises(MemoryConnectorException, match="Error serializing record"): - await vector_store_record_collection.upsert_batch([record]) + await vector_store_record_collection.upsert(record) + assert len(vector_store_record_collection.inner_storage) == 1 + with raises(VectorStoreOperationException, match="Error deleting record:"): + await vector_store_record_collection.delete("test_id") + with raises(VectorStoreOperationException, match="Error deleting records:"): + await vector_store_record_collection.delete_batch(["test_id"]) + assert len(vector_store_record_collection.inner_storage) == 1 -async def test_deserialize_fail(DictVectorStoreRecordCollection, data_model_definition): - DictVectorStoreRecordCollection.deserialize = MagicMock(side_effect=Exception) +# region Serialize +async def test_serialize_in_upsert_fail(DictVectorStoreRecordCollection, data_model_definition): + DictVectorStoreRecordCollection.serialize = MagicMock(side_effect=VectorStoreModelSerializationException) vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", data_model_type=dict, data_model_definition=data_model_definition, ) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - vector_store_record_collection.inner_storage["test_id"] = record - with raises(MemoryConnectorException, match="Error deserializing record"): - await vector_store_record_collection.get("test_id") - with raises(MemoryConnectorException, match="Error deserializing record"): - await vector_store_record_collection.get_batch(["test_id"]) + with raises(VectorStoreModelSerializationException): + await vector_store_record_collection.upsert(record) + with raises(VectorStoreModelSerializationException): + await vector_store_record_collection.upsert_batch([record]) -def test_serialize_custom_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): - data_model_type_vanilla_serialize.serialize = MagicMock(side_effect=Exception) +def test_serialize_data_model_type_serialize_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", data_model_type=data_model_type_vanilla_serialize, ) - record = data_model_type_vanilla_serialize( - content="test_content", - vector=[1.0, 2.0, 3.0], - id="test_id", - ) - with raises(VectorStoreModelSerializationException, match="Error serializing record:"): + record = MagicMock(spec=SerializeMethodProtocol) + record.serialize = MagicMock(side_effect=Exception) + with raises(VectorStoreModelSerializationException, match="Error serializing record"): vector_store_record_collection.serialize(record) -def test_deserialize_custom_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): - data_model_type_vanilla_serialize.deserialize = MagicMock(side_effect=Exception) - vector_store_record_collection = DictVectorStoreRecordCollection( - collection_name="test", - 
-        data_model_type=data_model_type_vanilla_serialize,
-    )
-    record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
-    with raises(VectorStoreModelSerializationException, match="Error deserializing record:"):
-        vector_store_record_collection.deserialize(record)
-
-
 def test_serialize_data_model_to_dict_fail_mapping(DictVectorStoreRecordCollection, data_model_definition):
     vector_store_record_collection = DictVectorStoreRecordCollection(
         collection_name="test",
@@ -363,7 +400,7 @@ def test_serialize_data_model_to_dict_fail_mapping(DictVectorStoreRecordCollecti
         data_model_definition=data_model_definition,
     )
     record = {"content": "test_content", "vector": [1.0, 2.0, 3.0]}
-    with raises(VectorStoreModelSerializationException, match="Error serializing record"):
+    with raises(KeyError):
         vector_store_record_collection._serialize_data_model_to_dict(record)


@@ -373,10 +410,83 @@ def test_serialize_data_model_to_dict_fail_object(DictVectorStoreRecordCollectio
         data_model_type=data_model_type_vanilla,
     )
     record = Mock(spec=data_model_type_vanilla)
-    with raises(VectorStoreModelSerializationException, match="Error serializing record"):
+    with raises(AttributeError):
         vector_store_record_collection._serialize_data_model_to_dict(record)


+@mark.parametrize("vector_store_record_collection", ["type_pydantic"], indirect=True)
+def test_pydantic_serialize_fail(vector_store_record_collection):
+    id = "test_id"
+    model = deepcopy(vector_store_record_collection.data_model_type)
+    model.model_dump = MagicMock(side_effect=Exception)
+    vector_store_record_collection.data_model_type = model
+    dict_record = {"id": id, "content": "test_content", "vector": [1.0, 2.0, 3.0]}
+    record = model(**dict_record)
+    with raises(VectorStoreModelSerializationException, match="Error serializing record"):
+        vector_store_record_collection.serialize(record)
+
+
+@mark.parametrize("vector_store_record_collection", ["type_vanilla_with_to_from_dict"], indirect=True)
+def test_to_dict_fail(vector_store_record_collection):
+    record = MagicMock(spec=ToDictMethodProtocol)
+    record.to_dict = MagicMock(side_effect=Exception)
+    with raises(VectorStoreModelSerializationException, match="Error serializing record"):
+        vector_store_record_collection.serialize(record)
+
+
+def test_serialize_with_array_func(DictVectorStoreRecordCollection, data_model_type_pydantic_array):
+    vector_store_record_collection = DictVectorStoreRecordCollection(
+        collection_name="test",
+        data_model_type=data_model_type_pydantic_array,
+    )
+    record = data_model_type_pydantic_array(**{
+        "id": "test_id",
+        "content": "test_content",
+        "vector": np.array([1.0, 2.0, 3.0]),
+    })
+    serialized_record = vector_store_record_collection.serialize(record)
+    assert serialized_record["vector"] == [1.0, 2.0, 3.0]
+
+
+# region Deserialize
+
+
+async def test_deserialize_definition_fail(DictVectorStoreRecordCollection, data_model_definition):
+    data_model_definition.deserialize = MagicMock(side_effect=Exception)
+    vector_store_record_collection = DictVectorStoreRecordCollection(
+        collection_name="test",
+        data_model_type=dict,
+        data_model_definition=data_model_definition,
+    )
+    record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
+    vector_store_record_collection.inner_storage["test_id"] = record
+    with raises(VectorStoreModelDeserializationException, match="Error deserializing record"):
+        await vector_store_record_collection.get("test_id")
+    with raises(VectorStoreModelDeserializationException, match="Error deserializing record"):
+        await vector_store_record_collection.get_batch(["test_id"])
+
+
+async def test_deserialize_definition_none(DictVectorStoreRecordCollection, data_model_definition):
+    vector_store_record_collection = DictVectorStoreRecordCollection(
+        collection_name="test",
+        data_model_type=dict,
+        data_model_definition=data_model_definition,
+    )
+    assert vector_store_record_collection.deserialize([]) is None
+    assert vector_store_record_collection.deserialize({}) is None
+
+
+def test_deserialize_type_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize):
+    vector_store_record_collection = DictVectorStoreRecordCollection(
+        collection_name="test",
+        data_model_type=data_model_type_vanilla_serialize,
+    )
+    record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
+    vector_store_record_collection.data_model_type.deserialize = MagicMock(side_effect=Exception)
+    with raises(VectorStoreModelDeserializationException, match="Error deserializing record"):
+        vector_store_record_collection.deserialize(record)
+
+
 def test_deserialize_dict_data_model_fail_sequence(DictVectorStoreRecordCollection, data_model_type_vanilla):
     vector_store_record_collection = DictVectorStoreRecordCollection(
         collection_name="test",
@@ -392,7 +502,7 @@ def test_deserialize_dict_data_model_fail(DictVectorStoreRecordCollection, data_
         data_model_type=dict,
         data_model_definition=data_model_definition,
     )
-    with raises(VectorStoreModelDeserializationException, match="Error deserializing record"):
+    with raises(KeyError):
         vector_store_record_collection._deserialize_dict_to_data_model({
             "content": "test_content",
             "vector": [1.0, 2.0, 3.0],
@@ -412,121 +522,35 @@ def test_deserialize_dict_data_model_shortcut(DictVectorStoreRecordCollection, d
 @mark.parametrize("vector_store_record_collection", ["type_pydantic"], indirect=True)
-async def test_pydantic_fail(vector_store_record_collection):
+async def test_pydantic_deserialize_fail(vector_store_record_collection):
     id = "test_id"
-    model = deepcopy(vector_store_record_collection.data_model_type)
     dict_record = {"id": id, "content": "test_content", "vector": [1.0, 2.0, 3.0]}
-    record = model(**dict_record)
-    model.model_dump = MagicMock(side_effect=Exception)
-    with raises(VectorStoreModelSerializationException, match="Error serializing record:"):
-        vector_store_record_collection.serialize(record)
-    with raises(MemoryConnectorException, match="Error serializing record:"):
-        await vector_store_record_collection.upsert(record)
-    model.model_validate = MagicMock(side_effect=Exception)
-    with raises(VectorStoreModelDeserializationException, match="Error deserializing record:"):
+    vector_store_record_collection.data_model_type.model_validate = MagicMock(side_effect=Exception)
+    with raises(VectorStoreModelDeserializationException, match="Error deserializing record"):
         vector_store_record_collection.deserialize(dict_record)


 @mark.parametrize("vector_store_record_collection", ["type_vanilla_with_to_from_dict"], indirect=True)
-def test_to_from_dict_fail(vector_store_record_collection):
+def test_from_dict_fail(vector_store_record_collection):
     id = "test_id"
     model = deepcopy(vector_store_record_collection.data_model_type)
     dict_record = {"id": id, "content": "test_content", "vector": [1.0, 2.0, 3.0]}
-    record = model(**dict_record)
-    model.to_dict = MagicMock(side_effect=Exception)
-    with raises(VectorStoreModelSerializationException, match="Error serializing record:"):
-        vector_store_record_collection.serialize(record)
     model.from_dict = MagicMock(side_effect=Exception)
-    with raises(VectorStoreModelDeserializationException, match="Error deserializing record:"):
+    vector_store_record_collection.data_model_type = model
+    with raises(VectorStoreModelDeserializationException, match="Error deserializing record"):
         vector_store_record_collection.deserialize(dict_record)


-async def test_delete_fail(DictVectorStoreRecordCollection, data_model_definition):
-    DictVectorStoreRecordCollection._inner_delete = MagicMock(side_effect=Exception)
-    vector_store_record_collection = DictVectorStoreRecordCollection(
-        collection_name="test",
-        data_model_type=dict,
-        data_model_definition=data_model_definition,
-    )
-    record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
-    await vector_store_record_collection.upsert(record)
-    assert len(vector_store_record_collection.inner_storage) == 1
-    with raises(MemoryConnectorException, match="Error deleting record:"):
-        await vector_store_record_collection.delete("test_id")
-    with raises(MemoryConnectorException, match="Error deleting records:"):
-        await vector_store_record_collection.delete_batch(["test_id"])
-    assert len(vector_store_record_collection.inner_storage) == 1
-
-
-async def test_collection_operations(vector_store_record_collection):
-    await vector_store_record_collection.create_collection()
-    assert await vector_store_record_collection.does_collection_exist()
-    record = {"id": "id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
-    await vector_store_record_collection.upsert(record)
-    assert len(vector_store_record_collection.inner_storage) == 1
-    await vector_store_record_collection.delete_collection()
-    assert vector_store_record_collection.inner_storage == {}
-    await vector_store_record_collection.create_collection_if_not_exists()
-
-
-async def test_collection_create_if_not_exists(DictVectorStoreRecordCollection, data_model_definition):
-    DictVectorStoreRecordCollection.does_collection_exist = AsyncMock(return_value=False)
-    create_mock = AsyncMock()
-    DictVectorStoreRecordCollection.create_collection = create_mock
+def test_deserialize_with_array_func(DictVectorStoreRecordCollection, data_model_type_pydantic_array):
     vector_store_record_collection = DictVectorStoreRecordCollection(
         collection_name="test",
-        data_model_type=dict,
-        data_model_definition=data_model_definition,
+        data_model_type=data_model_type_pydantic_array,
     )
-    await vector_store_record_collection.create_collection_if_not_exists()
-    create_mock.assert_called_once()
-
-
-def test_data_model_validation(data_model_type_vanilla, DictVectorStoreRecordCollection):
-    DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["str"])
-    DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["float"])
-    DictVectorStoreRecordCollection(
-        collection_name="test",
-        data_model_type=data_model_type_vanilla,
-    )
-
-
-def test_data_model_validation_key_fail(data_model_type_vanilla, DictVectorStoreRecordCollection):
-    DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["int"])
-    with raises(VectorStoreModelValidationError, match="Key field must be one of"):
-        DictVectorStoreRecordCollection(
-            collection_name="test",
-            data_model_type=data_model_type_vanilla,
-        )
-
-
-def test_data_model_validation_vector_fail(data_model_type_vanilla, DictVectorStoreRecordCollection):
-    DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["list[int]"])
-    with raises(VectorStoreModelValidationError, match="Vector field "):
-        DictVectorStoreRecordCollection(
-            collection_name="test",
-            data_model_type=data_model_type_vanilla,
-        )
-
-
-async def test_upsert_with_vectorizing(vector_store_record_collection):
-    record = {"id": "test_id", "content": "test_content"}
-    record2 = {"id": "test_id", "content": "test_content"}
-
-    async def embedding_func(record, type, definition):
-        if isinstance(record, list):
-            for r in record:
-                r["vector"] = [1.0, 2.0, 3.0]
-            return record
-        record["vector"] = [1.0, 2.0, 3.0]
-        return record
-
-    await vector_store_record_collection.upsert(record, embedding_generation_function=embedding_func)
-    assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0]
-    await vector_store_record_collection.delete("test_id")
-    assert len(vector_store_record_collection.inner_storage) == 0
-    await vector_store_record_collection.upsert_batch([record2], embedding_generation_function=embedding_func)
-    assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0]
-
-
-# TODO (eavanvalkenburg): pandas container test
+    record = {
+        "id": "test_id",
+        "content": "test_content",
+        "vector": [1.0, 2.0, 3.0],
+    }
+    deserialized_record = vector_store_record_collection.deserialize(record)
+    assert isinstance(deserialized_record.vector, np.ndarray)
+    assert np.array_equal(deserialized_record.vector, np.array([1.0, 2.0, 3.0]))
diff --git a/python/tests/unit/data/test_vector_store_record_definition.py b/python/tests/unit/data/test_vector_store_record_definition.py
index 897623c95dc7..3f1684748d6c 100644
--- a/python/tests/unit/data/test_vector_store_record_definition.py
+++ b/python/tests/unit/data/test_vector_store_record_definition.py
@@ -7,6 +7,7 @@
     VectorStoreRecordDefinition,
     VectorStoreRecordKeyField,
 )
+from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField
 from semantic_kernel.exceptions import VectorStoreModelException


@@ -55,3 +56,64 @@ def test_no_matching_vector_field_fail():
             "content": VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector"),
         }
     )
+
+
+def test_vector_and_non_vector_field_names():
+    definition = VectorStoreRecordDefinition(
+        fields={
+            "id": VectorStoreRecordKeyField(),
+            "content": VectorStoreRecordDataField(),
+            "vector": VectorStoreRecordVectorField(),
+        }
+    )
+    assert definition.vector_field_names == ["vector"]
+    assert definition.non_vector_field_names == ["id", "content"]
+
+
+def test_try_get_vector_field():
+    definition = VectorStoreRecordDefinition(
+        fields={
+            "id": VectorStoreRecordKeyField(),
+            "content": VectorStoreRecordDataField(),
+            "vector": VectorStoreRecordVectorField(),
+        }
+    )
+    assert definition.try_get_vector_field() == definition.fields["vector"]
+    assert definition.try_get_vector_field("vector") == definition.fields["vector"]
+
+
+def test_try_get_vector_field_none():
+    definition = VectorStoreRecordDefinition(
+        fields={
+            "id": VectorStoreRecordKeyField(),
+            "content": VectorStoreRecordDataField(),
+        }
+    )
+    assert definition.try_get_vector_field() is None
+    with raises(VectorStoreModelException, match="Field vector not found."):
+        definition.try_get_vector_field("vector")
+
+
+def test_try_get_vector_field_wrong_name_fail():
+    definition = VectorStoreRecordDefinition(
+        fields={
+            "id": VectorStoreRecordKeyField(),
+            "content": VectorStoreRecordDataField(),
+        }
+    )
+    with raises(VectorStoreModelException, match="Field content is not a vector field."):
+        definition.try_get_vector_field("content")
+
+
+def test_get_field_names():
+    definition = VectorStoreRecordDefinition(
+        fields={
+            "id": VectorStoreRecordKeyField(),
+            "content": VectorStoreRecordDataField(),
+            "vector": VectorStoreRecordVectorField(),
+        }
+    )
+    assert definition.get_field_names() == ["id", "content", "vector"]
+    assert definition.get_field_names(include_vector_fields=False) == ["id", "content"]
+    assert definition.get_field_names(include_key_field=False) == ["content", "vector"]
+    assert definition.get_field_names(include_vector_fields=False, include_key_field=False) == ["content"]
diff --git a/python/tests/unit/data/test_vector_store_record_utils.py b/python/tests/unit/data/test_vector_store_record_utils.py
index cfb2ea448d64..669ede25376d 100644
--- a/python/tests/unit/data/test_vector_store_record_utils.py
+++ b/python/tests/unit/data/test_vector_store_record_utils.py
@@ -40,3 +40,14 @@ async def test_add_vector_wrong_fields():
     record = {"id": "test_id", "content": "content"}
     with raises(VectorStoreModelException, match="Embedding field"):
         await utils.add_vector_to_records(record, None, data_model)
+
+
+async def test_fail():
+    kernel = MagicMock(spec=Kernel)
+    kernel.add_embedding_to_object = AsyncMock()
+    utils = VectorStoreRecordUtils(kernel)
+    assert utils is not None
+    record = {"id": "test_id", "content": "content"}
+    with raises(VectorStoreModelException, match="Data model definition is required"):
+        await utils.add_vector_to_records(record, dict, None)
+    kernel.add_embedding_to_object.assert_not_called()
diff --git a/python/tests/unit/data/test_vector_store_text_search.py b/python/tests/unit/data/test_vector_store_text_search.py
index 0f485349d098..70358011a592 100644
--- a/python/tests/unit/data/test_vector_store_text_search.py
+++ b/python/tests/unit/data/test_vector_store_text_search.py
@@ -2,11 +2,16 @@

 from unittest.mock import patch

+from pydantic import BaseModel
 from pytest import fixture, raises

 from semantic_kernel.connectors.ai.open_ai import AzureTextEmbedding
 from semantic_kernel.data import VectorStoreTextSearch
-from semantic_kernel.exceptions import VectorStoreTextSearchValidationError
+from semantic_kernel.data.text_search.text_search_result import TextSearchResult
+from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
+from semantic_kernel.exceptions import VectorStoreInitializationException
+from semantic_kernel.utils.list_handler import desync_list


 @fixture
@@ -28,6 +33,7 @@ async def test_from_vectorizable_text_search(vector_collection):
     assert search is not None
     assert text_search_result is not None
     assert search_result is not None
+    assert vsts.options_class is VectorSearchOptions


 async def test_from_vector_text_search(vector_collection):
@@ -60,10 +66,86 @@ async def test_from_vectorized_search(vector_collection, azure_openai_unit_test_


 def test_validation_no_embedder_for_vectorized_search(vector_collection):
-    with raises(VectorStoreTextSearchValidationError):
+    with raises(VectorStoreInitializationException):
         VectorStoreTextSearch(vectorized_search=vector_collection)


 def test_validation_no_collections():
-    with raises(VectorStoreTextSearchValidationError):
+    with raises(VectorStoreInitializationException):
         VectorStoreTextSearch()
+
+
+async def test_get_results_as_string(vector_collection):
+    test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection)
+    results = [
+        res
+        async for res in test_search._get_results_as_strings(results=desync_list([VectorSearchResult(record="test")]))
+    ]
+    assert results == ["test"]
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_strings(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == ['{"test":"test"}']
+
+    test_search = VectorStoreTextSearch.from_vector_text_search(
+        vector_text_search=vector_collection, string_mapper=lambda x: x.test
+    )
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_strings(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == ["test"]
+
+
+async def test_get_results_as_text_search_result(vector_collection):
+    test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection)
+    results = [
+        res
+        async for res in test_search._get_results_as_text_search_result(
+            results=desync_list([VectorSearchResult(record="test")])
+        )
+    ]
+    assert results == [TextSearchResult(value="test")]
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_text_search_result(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == [TextSearchResult(value='{"test":"test"}')]
+
+    test_search = VectorStoreTextSearch.from_vector_text_search(
+        vector_text_search=vector_collection, text_search_results_mapper=lambda x: TextSearchResult(value=x.test)
+    )
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_text_search_result(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == [TextSearchResult(value="test")]
diff --git a/python/tests/utils.py b/python/tests/utils.py
index fc27ce6e1a31..aac2ca65f408 100644
--- a/python/tests/utils.py
+++ b/python/tests/utils.py
@@ -14,6 +14,7 @@ async def retry(
     func: Callable[..., Awaitable[Any]],
     retries: int = 20,
     reset: Callable[..., None] | None = None,
+    name: str | None = None,
 ):
     """Retry the function if it raises an exception.

@@ -23,9 +24,10 @@ async def retry(
         reset (function): Function to reset the state of any variables used in the function
+        name (str): Optional name used in the retry log messages in place of the function's module

     """
-    logger.info(f"Running {retries} retries with func: {func.__module__}")
+    logger.info(f"Running {retries} retries with func: {name or func.__module__}")
     for i in range(retries):
-        logger.info(f"    Try {i + 1} for {func.__module__}")
+        logger.info(f"    Try {i + 1} for {name or func.__module__}")
         try:
             if reset:
                 reset()