diff --git a/python/.vscode/settings.json b/python/.vscode/settings.json index 93b973d6fc73..9d465bfc9a7a 100644 --- a/python/.vscode/settings.json +++ b/python/.vscode/settings.json @@ -13,11 +13,6 @@ "locale": "en-US" } ], - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, "[python]": { "editor.codeActionsOnSave": { "source.organizeImports": "explicit", @@ -29,7 +24,7 @@ "notebook.formatOnSave.enabled": true, "notebook.codeActionsOnSave": { "source.fixAll": true, - "source.organizeImports": true + "source.organizeImports": false }, "python.analysis.extraPaths": [ "${workspaceFolder}/samples/learn_resources" diff --git a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py index 1eec8028a1b8..d103997db200 100644 --- a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py +++ b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py @@ -34,7 +34,11 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions import MemoryConnectorException, MemoryConnectorInitializationError +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorStoreInitializationException, + VectorStoreOperationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class logger: logging.Logger = logging.getLogger(__name__) @@ -91,7 +95,7 @@ def __init__( if not collection_name: collection_name = search_client._index_name elif search_client._index_name != collection_name: - raise MemoryConnectorInitializationError( + raise VectorStoreInitializationException( "Search client and search index client have different index names." 
) super().__init__( @@ -107,7 +111,7 @@ def __init__( if search_index_client: if not collection_name: - raise MemoryConnectorInitializationError("Collection name is required.") + raise VectorStoreInitializationException("Collection name is required.") super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -133,14 +137,14 @@ def __init__( index_name=collection_name, ) except ValidationError as exc: - raise MemoryConnectorInitializationError("Failed to create Azure Cognitive Search settings.") from exc + raise VectorStoreInitializationException("Failed to create Azure Cognitive Search settings.") from exc search_index_client = get_search_index_client( azure_ai_search_settings=azure_ai_search_settings, azure_credential=kwargs.get("azure_credentials"), token_credential=kwargs.get("token_credentials"), ) if not azure_ai_search_settings.index_name: - raise MemoryConnectorInitializationError("Collection name is required.") + raise VectorStoreInitializationException("Collection name is required.") super().__init__( data_model_type=data_model_type, @@ -211,7 +215,7 @@ async def create_collection(self, **kwargs) -> None: if isinstance(index, SearchIndex): await self.search_index_client.create_index(index=index, **kwargs) return - raise MemoryConnectorException("Invalid index type supplied.") + raise VectorStoreOperationException("Invalid index type supplied, should be a SearchIndex object.") await self.search_index_client.create_index( index=data_model_definition_to_azure_ai_search_index( collection_name=self.collection_name, @@ -279,7 +283,10 @@ async def _inner_search( for name, field in self.data_model_definition.fields.items() if not isinstance(field, VectorStoreRecordVectorField) ] - raw_results = await self.search_client.search(**search_args) + try: + raw_results = await self.search_client.search(**search_args) + except Exception as exc: + raise VectorSearchExecutionException("Failed to search the collection.") from exc return KernelSearchResults( results=self._get_vector_search_results_from_results(raw_results, options), total_count=await raw_results.get_count() if options.include_total_count else None, diff --git a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py index fcaecf62f97f..4c4693abb6d7 100644 --- a/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py +++ b/python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_store.py @@ -19,7 +19,7 @@ from semantic_kernel.connectors.memory.azure_ai_search.utils import get_search_client, get_search_index_client from semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage import VectorStore -from semantic_kernel.exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -78,7 +78,7 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as exc: - raise MemoryConnectorInitializationError("Failed to create Azure AI Search settings.") from exc + raise VectorStoreInitializationException("Failed to create Azure AI Search settings.") from exc search_index_client = get_search_index_client( azure_ai_search_settings=azure_ai_search_settings, azure_credential=azure_credentials, diff --git 
a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py index afe8fa75f140..c9a49ba4e546 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_base.py @@ -6,9 +6,9 @@ from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_settings import AzureCosmosDBNoSQLSettings from semantic_kernel.connectors.memory.azure_cosmos_db.utils import CosmosClientWrapper -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorInitializationError, - MemoryConnectorResourceNotFound, +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, + VectorStoreOperationException, ) from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.utils.authentication.async_default_azure_credential_wrapper import ( @@ -63,10 +63,10 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as e: - raise MemoryConnectorInitializationError("Failed to validate Azure Cosmos DB NoSQL settings.") from e + raise VectorStoreInitializationException("Failed to validate Azure Cosmos DB NoSQL settings.") from e if cosmos_db_nosql_settings.database_name is None: - raise MemoryConnectorInitializationError("The name of the Azure Cosmos DB NoSQL database is missing.") + raise VectorStoreInitializationException("The name of the Azure Cosmos DB NoSQL database is missing.") if cosmos_client is None: if cosmos_db_nosql_settings.key is not None: @@ -94,7 +94,7 @@ async def _does_database_exist(self) -> bool: except CosmosResourceNotFoundError: return False except Exception as e: - raise MemoryConnectorResourceNotFound( + raise VectorStoreOperationException( f"Failed to check if database '{self.database_name}' exists, with message {e}" ) from e @@ -106,9 +106,9 @@ async def _get_database_proxy(self, **kwargs) -> DatabaseProxy: if self.create_database: return await self.cosmos_client.create_database(self.database_name, **kwargs) - raise MemoryConnectorResourceNotFound(f"Database '{self.database_name}' does not exist.") + raise VectorStoreOperationException(f"Database '{self.database_name}' does not exist.") except Exception as e: - raise MemoryConnectorResourceNotFound(f"Failed to get database proxy for '{id}'.") from e + raise VectorStoreOperationException(f"Failed to get database proxy for '{self.database_name}'.") from e async def _get_container_proxy(self, container_name: str, **kwargs) -> ContainerProxy: """Gets the container proxy.""" @@ -116,4 +116,4 @@ async def _get_container_proxy(self, container_name: str, **kwargs) -> Container database_proxy = await self._get_database_proxy(**kwargs) return database_proxy.get_container_client(container_name) except Exception as e: - raise MemoryConnectorResourceNotFound(f"Failed to get container proxy for '{container_name}'.") from e + raise VectorStoreOperationException(f"Failed to get container proxy for '{container_name}'.") from e diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py index 00a6813c1802..aa8633ecb54e 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py +++
b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_collection.py @@ -11,7 +11,7 @@ from typing_extensions import override # pragma: no cover from azure.cosmos.aio import CosmosClient -from azure.cosmos.exceptions import CosmosBatchOperationError, CosmosHttpResponseError, CosmosResourceNotFoundError +from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.partition_key import PartitionKey from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_base import AzureCosmosDBNoSQLBase @@ -36,10 +36,10 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorResourceNotFound, +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, VectorStoreModelDeserializationException, + VectorStoreOperationException, ) from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class @@ -120,15 +120,8 @@ async def _inner_upsert( **kwargs: Any, ) -> Sequence[TKey]: container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) - try: - results = await asyncio.gather(*(container_proxy.upsert_item(record) for record in records)) - return [result[COSMOS_ITEM_ID_PROPERTY_NAME] for result in results] - except CosmosResourceNotFoundError as e: - raise MemoryConnectorResourceNotFound( - "The collection does not exist yet. Create the collection first." - ) from e - except (CosmosBatchOperationError, CosmosHttpResponseError) as e: - raise MemoryConnectorException("Failed to upsert items.") from e + results = await asyncio.gather(*(container_proxy.upsert_item(record) for record in records)) + return [result[COSMOS_ITEM_ID_PROPERTY_NAME] for result in results] @override async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any] | None: @@ -140,14 +133,7 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any parameters: list[dict[str, Any]] = [{"name": f"@id{i}", "value": get_key(key)} for i, key in enumerate(keys)] container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) - try: - return [item async for item in container_proxy.query_items(query=query, parameters=parameters)] - except CosmosResourceNotFoundError as e: - raise MemoryConnectorResourceNotFound( - "The collection does not exist yet. Create the collection first." 
- ) from e - except Exception as e: - raise MemoryConnectorException("Failed to read items.") from e + return [item async for item in container_proxy.query_items(query=query, parameters=parameters)] @override async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: @@ -158,7 +144,7 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: ) exceptions = [result for result in results if isinstance(result, Exception)] if exceptions: - raise MemoryConnectorException("Failed to delete item(s).", exceptions) + raise VectorStoreOperationException("Failed to delete item(s).", exceptions) @override async def _inner_search( @@ -177,12 +163,12 @@ async def _inner_search( query = self._build_vector_query(options) params.append({"name": "@vector", "value": vector}) else: - raise ValueError("Either search_text or vector must be provided.") + raise VectorSearchExecutionException("Either search_text or vector must be provided.") container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) try: results = container_proxy.query_items(query, parameters=params) - except Exception as e: - raise MemoryConnectorException("Failed to search items.") from e + except Exception as exc: + raise VectorSearchExecutionException("Failed to search items.") from exc return KernelSearchResults( results=self._get_vector_search_results_from_results(results, options), total_count=None, @@ -286,26 +272,26 @@ def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: A @override async def create_collection(self, **kwargs) -> None: + indexing_policy = kwargs.pop("indexing_policy", create_default_indexing_policy(self.data_model_definition)) + vector_embedding_policy = kwargs.pop( + "vector_embedding_policy", create_default_vector_embedding_policy(self.data_model_definition) + ) + database_proxy = await self._get_database_proxy(**kwargs) try: - database_proxy = await self._get_database_proxy(**kwargs) await database_proxy.create_container_if_not_exists( id=self.collection_name, partition_key=self.partition_key, - indexing_policy=kwargs.pop( - "indexing_policy", create_default_indexing_policy(self.data_model_definition) - ), - vector_embedding_policy=kwargs.pop( - "vector_embedding_policy", create_default_vector_embedding_policy(self.data_model_definition) - ), + indexing_policy=indexing_policy, + vector_embedding_policy=vector_embedding_policy, **kwargs, ) except CosmosHttpResponseError as e: - raise MemoryConnectorException("Failed to create container.") from e + raise VectorStoreOperationException("Failed to create container.") from e @override async def does_collection_exist(self, **kwargs) -> bool: + container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) try: - container_proxy = await self._get_container_proxy(self.collection_name, **kwargs) await container_proxy.read(**kwargs) return True except CosmosHttpResponseError: @@ -313,11 +299,11 @@ async def does_collection_exist(self, **kwargs) -> bool: @override async def delete_collection(self, **kwargs) -> None: + database_proxy = await self._get_database_proxy(**kwargs) try: - database_proxy = await self._get_database_proxy(**kwargs) await database_proxy.delete_container(self.collection_name) - except CosmosHttpResponseError as e: - raise MemoryConnectorException("Container could not be deleted.") from e + except Exception as e: + raise VectorStoreOperationException("Container could not be deleted.") from e @override async def __aexit__(self, exc_type, exc_value, traceback) 
-> None: diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py index 9cf7b0629d30..45ca18b58c55 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db/azure_cosmos_db_no_sql_store.py @@ -18,7 +18,7 @@ from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage.vector_store import VectorStore from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException +from semantic_kernel.exceptions import VectorStoreOperationException from semantic_kernel.utils.experimental_decorator import experimental_class TModel = TypeVar("TModel") @@ -93,7 +93,7 @@ async def list_collection_names(self, **kwargs) -> Sequence[str]: containers = database.list_containers() return [container["id"] async for container in containers] except Exception as e: - raise MemoryConnectorException("Failed to list collection names.") from e + raise VectorStoreOperationException("Failed to list collection names.") from e async def __aexit__(self, exc_type, exc_value, traceback) -> None: """Exit the context manager.""" diff --git a/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py b/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py index 7955153a8135..40115cb647ef 100644 --- a/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py +++ b/python/semantic_kernel/connectors/memory/azure_cosmos_db/utils.py @@ -20,7 +20,7 @@ VectorStoreRecordDataField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException def to_vector_index_policy_type(index_kind: IndexKind | None) -> str: @@ -34,6 +34,9 @@ def to_vector_index_policy_type(index_kind: IndexKind | None) -> str: Returns: str: The vector index policy type. + + Raises: + VectorStoreModelException: If the index kind is not supported by Azure Cosmos DB NoSQL container. """ if index_kind is None: # Use IndexKind.FLAT as the default index kind. @@ -46,7 +49,18 @@ def to_vector_index_policy_type(index_kind: IndexKind | None) -> str: def to_distance_function(distance_function: DistanceFunction | None) -> str: - """Converts the distance function to the distance function for Azure Cosmos DB NoSQL container.""" + """Converts the distance function to the distance function for Azure Cosmos DB NoSQL container. + + Args: + distance_function: The distance function. + + Returns: + str: The distance function as defined by Azure Cosmos DB NoSQL container. + + Raises: + VectorStoreModelException: If the distance function is not supported by Azure Cosmos DB NoSQL container. + + """ if distance_function is None: # Use DistanceFunction.COSINE_SIMILARITY as the default distance function. 
return DISTANCE_FUNCTION_MAPPING[DistanceFunction.COSINE_SIMILARITY] @@ -60,7 +74,18 @@ def to_distance_function(distance_function: DistanceFunction | None) -> str: def to_datatype(property_type: str | None) -> str: - """Converts the property type to the data type for Azure Cosmos DB NoSQL container.""" + """Converts the property type to the data type for Azure Cosmos DB NoSQL container. + + Args: + property_type: The property type. + + Returns: + str: The data type as defined by Azure Cosmos DB NoSQL container. + + Raises: + VectorStoreModelException: If the property type is not supported by Azure Cosmos DB NoSQL container + + """ if property_type is None: # Use the default data type. return DATATYPES_MAPPING["default"] @@ -83,8 +108,11 @@ def create_default_indexing_policy(data_model_definition: VectorStoreRecordDefin Returns: dict[str, Any]: The indexing policy. + + Raises: + VectorStoreModelException: If the field is not full text searchable and not filterable. """ - indexing_policy = { + indexing_policy: dict[str, Any] = { "automatic": True, "includedPaths": [ { @@ -103,15 +131,15 @@ def create_default_indexing_policy(data_model_definition: VectorStoreRecordDefin if isinstance(field, VectorStoreRecordDataField) and ( not field.is_full_text_searchable and not field.is_filterable ): - indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) # type: ignore + indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) if isinstance(field, VectorStoreRecordVectorField): - indexing_policy["vectorIndexes"].append({ # type: ignore + indexing_policy["vectorIndexes"].append({ "path": f'/"{field.name}"', "type": to_vector_index_policy_type(field.index_kind), }) # Exclude the vector field from the index for performance optimization. - indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) # type: ignore + indexing_policy["excludedPaths"].append({"path": f'/"{field.name}"/*'}) return indexing_policy @@ -126,6 +154,10 @@ def create_default_vector_embedding_policy(data_model_definition: VectorStoreRec Returns: dict[str, Any]: The vector embedding policy. + + Raises: + VectorStoreModelException: If the datatype or distance function is not supported by Azure Cosmos DB NoSQL. + """ vector_embedding_policy: dict[str, Any] = {"vectorEmbeddings": []} diff --git a/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py b/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py index a833efcc855a..6de863646dc3 100644 --- a/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py +++ b/python/semantic_kernel/connectors/memory/postgres/postgres_collection.py @@ -11,7 +11,6 @@ from typing_extensions import override # pragma: no cover from psycopg import sql -from psycopg.errors import DatabaseError from psycopg_pool import AsyncConnectionPool from pydantic import PrivateAttr @@ -30,9 +29,9 @@ VectorStoreRecordVectorField, ) from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, +from semantic_kernel.exceptions import ( VectorStoreModelValidationError, + VectorStoreOperationException, ) from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class @@ -138,46 +137,42 @@ async def _inner_upsert( The keys of the upserted records. 
""" if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) keys = [] - try: - async with ( - self.connection_pool.connection() as conn, - conn.transaction(), - conn.cursor() as cur, - ): - # Split the records into batches - max_rows_per_transaction = self._settings.max_rows_per_transaction - for i in range(0, len(records), max_rows_per_transaction): - record_batch = records[i : i + max_rows_per_transaction] - - fields = list(self.data_model_definition.fields.items()) - - row_values = [convert_dict_to_row(record, fields) for record in record_batch] - - # Execute the INSERT statement for each batch - await cur.executemany( - sql.SQL("INSERT INTO {}.{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}").format( - sql.Identifier(self.db_schema), - sql.Identifier(self.collection_name), - sql.SQL(", ").join(sql.Identifier(field.name) for _, field in fields), - sql.SQL(", ").join(sql.Placeholder() * len(fields)), - sql.Identifier(self.data_model_definition.key_field.name), - sql.SQL(", ").join( - sql.SQL("{field} = EXCLUDED.{field}").format(field=sql.Identifier(field.name)) - for _, field in fields - if field.name != self.data_model_definition.key_field.name - ), + async with ( + self.connection_pool.connection() as conn, + conn.transaction(), + conn.cursor() as cur, + ): + # Split the records into batches + max_rows_per_transaction = self._settings.max_rows_per_transaction + for i in range(0, len(records), max_rows_per_transaction): + record_batch = records[i : i + max_rows_per_transaction] + + fields = list(self.data_model_definition.fields.items()) + + row_values = [convert_dict_to_row(record, fields) for record in record_batch] + + # Execute the INSERT statement for each batch + await cur.executemany( + sql.SQL("INSERT INTO {}.{} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {}").format( + sql.Identifier(self.db_schema), + sql.Identifier(self.collection_name), + sql.SQL(", ").join(sql.Identifier(field.name) for _, field in fields), + sql.SQL(", ").join(sql.Placeholder() * len(fields)), + sql.Identifier(self.data_model_definition.key_field.name), + sql.SQL(", ").join( + sql.SQL("{field} = EXCLUDED.{field}").format(field=sql.Identifier(field.name)) + for _, field in fields + if field.name != self.data_model_definition.key_field.name ), - row_values, - ) - keys.extend(record.get(self.data_model_definition.key_field.name) for record in record_batch) - - except DatabaseError as error: - # Rollback happens automatically if an exception occurs within the transaction block - raise MemoryConnectorException(f"Error upserting records: {error}") from error - + ), + row_values, + ) + keys.extend(record.get(self.data_model_definition.key_field.name) for record in record_batch) return keys @override @@ -192,27 +187,25 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[dic The records from the store, not deserialized. """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." 
+ ) fields = [(field.name, field) for field in self.data_model_definition.fields.values()] - try: - async with self.connection_pool.connection() as conn, conn.cursor() as cur: - await cur.execute( - sql.SQL("SELECT {} FROM {}.{} WHERE {} IN ({})").format( - sql.SQL(", ").join(sql.Identifier(name) for (name, _) in fields), - sql.Identifier(self.db_schema), - sql.Identifier(self.collection_name), - sql.Identifier(self.data_model_definition.key_field.name), - sql.SQL(", ").join(sql.Literal(key) for key in keys), - ) + async with self.connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute( + sql.SQL("SELECT {} FROM {}.{} WHERE {} IN ({})").format( + sql.SQL(", ").join(sql.Identifier(name) for (name, _) in fields), + sql.Identifier(self.db_schema), + sql.Identifier(self.collection_name), + sql.Identifier(self.data_model_definition.key_field.name), + sql.SQL(", ").join(sql.Literal(key) for key in keys), ) - rows = await cur.fetchall() - if not rows: - return None - return [convert_row_to_dict(row, fields) for row in rows] - - except DatabaseError as error: - raise MemoryConnectorException(f"Error getting records: {error}") from error + ) + rows = await cur.fetchall() + if not rows: + return None + return [convert_row_to_dict(row, fields) for row in rows] @override async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: @@ -223,32 +216,29 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: **kwargs: Additional arguments. """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") - - try: - async with ( - self.connection_pool.connection() as conn, - conn.transaction(), - conn.cursor() as cur, - ): - # Split the keys into batches - max_rows_per_transaction = self._settings.max_rows_per_transaction - for i in range(0, len(keys), max_rows_per_transaction): - key_batch = keys[i : i + max_rows_per_transaction] - - # Execute the DELETE statement for each batch - await cur.execute( - sql.SQL("DELETE FROM {}.{} WHERE {} IN ({})").format( - sql.Identifier(self.db_schema), - sql.Identifier(self.collection_name), - sql.Identifier(self.data_model_definition.key_field.name), - sql.SQL(", ").join(sql.Literal(key) for key in key_batch), - ) - ) + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." 
+ ) - except DatabaseError as error: - # Rollback happens automatically if an exception occurs within the transaction block - raise MemoryConnectorException(f"Error deleting records: {error}") from error + async with ( + self.connection_pool.connection() as conn, + conn.transaction(), + conn.cursor() as cur, + ): + # Split the keys into batches + max_rows_per_transaction = self._settings.max_rows_per_transaction + for i in range(0, len(keys), max_rows_per_transaction): + key_batch = keys[i : i + max_rows_per_transaction] + + # Execute the DELETE statement for each batch + await cur.execute( + sql.SQL("DELETE FROM {}.{} WHERE {} IN ({})").format( + sql.Identifier(self.db_schema), + sql.Identifier(self.collection_name), + sql.Identifier(self.data_model_definition.key_field.name), + sql.SQL(", ").join(sql.Literal(key) for key in key_batch), + ) + ) @override def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: @@ -276,7 +266,9 @@ async def create_collection(self, **kwargs: Any) -> None: **kwargs: Additional arguments """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) column_definitions = [] table_name = self.collection_name @@ -310,20 +302,16 @@ async def create_collection(self, **kwargs: Any) -> None: sql.Identifier(self.db_schema), sql.Identifier(table_name), columns_str ) - try: - async with self.connection_pool.connection() as conn, conn.cursor() as cur: - await cur.execute(create_table_query) - await conn.commit() - - logger.info(f"Postgres table '{table_name}' created successfully.") + async with self.connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute(create_table_query) + await conn.commit() - # If the vector field defines an index, apply it - for vector_field in self.data_model_definition.vector_fields: - if vector_field.index_kind: - await self._create_index(table_name, vector_field) + logger.info(f"Postgres table '{table_name}' created successfully.") - except DatabaseError as error: - raise MemoryConnectorException(f"Error creating table: {error}") from error + # If the vector field defines an index, apply it + for vector_field in self.data_model_definition.vector_fields: + if vector_field.index_kind: + await self._create_index(table_name, vector_field) async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVectorField) -> None: """Create an index on a column in the table. @@ -333,14 +321,16 @@ async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVe vector_field: The vector field definition that the index is based on. """ if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) column_name = vector_field.name index_name = f"{table_name}_{column_name}_idx" # Only support creating HNSW indexes through the vector store if vector_field.index_kind != IndexKind.HNSW: - raise MemoryConnectorException( + raise VectorStoreOperationException( f"Unsupported index kind: {vector_field.index_kind}. " "If you need to create an index of this type, please do so manually. " "Only HNSW indexes are supported through the vector store." 
@@ -348,37 +338,35 @@ async def _create_index(self, table_name: str, vector_field: VectorStoreRecordVe # Require the distance function to be set for HNSW indexes if not vector_field.distance_function: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Distance function must be set for HNSW indexes. " "Please set the distance function in the vector field definition." ) ops_str = get_vector_index_ops_str(vector_field.distance_function) - try: - async with self.connection_pool.connection() as conn, conn.cursor() as cur: - await cur.execute( - sql.SQL("CREATE INDEX {} ON {}.{} USING {} ({} {})").format( - sql.Identifier(index_name), - sql.Identifier(self.db_schema), - sql.Identifier(table_name), - sql.SQL(vector_field.index_kind), - sql.Identifier(column_name), - sql.SQL(ops_str), - ) + async with self.connection_pool.connection() as conn, conn.cursor() as cur: + await cur.execute( + sql.SQL("CREATE INDEX {} ON {}.{} USING {} ({} {})").format( + sql.Identifier(index_name), + sql.Identifier(self.db_schema), + sql.Identifier(table_name), + sql.SQL(vector_field.index_kind), + sql.Identifier(column_name), + sql.SQL(ops_str), ) - await conn.commit() - - logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.") + ) + await conn.commit() - except DatabaseError as error: - raise MemoryConnectorException(f"Error creating index: {error}") from error + logger.info(f"Index '{index_name}' created successfully on column '{column_name}'.") @override async def does_collection_exist(self, **kwargs: Any) -> bool: """Check if the collection exists.""" if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." + ) async with self.connection_pool.connection() as conn, conn.cursor() as cur: await cur.execute( @@ -396,7 +384,9 @@ async def does_collection_exist(self, **kwargs: Any) -> bool: async def delete_collection(self, **kwargs: Any) -> None: """Delete the collection.""" if self.connection_pool is None: - raise MemoryConnectorException("Connection pool is not available, use the collection as a context manager.") + raise VectorStoreOperationException( + "Connection pool is not available, use the collection as a context manager." 
+ ) async with self.connection_pool.connection() as conn, conn.cursor() as cur: await cur.execute( diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py index cb30fa0cdc76..b0cfd8244299 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py @@ -23,11 +23,11 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin from semantic_kernel.exceptions import ( - MemoryConnectorInitializationError, + VectorSearchExecutionException, + VectorStoreInitializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException -from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent @@ -125,14 +125,14 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant settings.", ex) from ex if APP_INFO: kwargs.setdefault("metadata", {}) kwargs["metadata"] = prepend_semantic_kernel_to_user_agent(kwargs["metadata"]) try: client = AsyncQdrantClientWrapper(**settings.model_dump(exclude_none=True), **kwargs) except ValueError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant client.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant client.", ex) from ex super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -274,7 +274,7 @@ async def create_collection(self, **kwargs) -> None: vector = self.data_model_definition.fields[field] assert isinstance(vector, VectorStoreRecordVectorField) # nosec if not vector.dimensions: - raise MemoryConnectorException("Vector field must have dimensions.") + raise VectorStoreOperationException("Vector field must have dimensions.") vectors_config[field] = VectorParams( size=vector.dimensions, distance=DISTANCE_FUNCTION_MAP[vector.distance_function or "default"], @@ -284,7 +284,7 @@ async def create_collection(self, **kwargs) -> None: vector = self.data_model_definition.fields[self.data_model_definition.vector_field_names[0]] assert isinstance(vector, VectorStoreRecordVectorField) # nosec if not vector.dimensions: - raise MemoryConnectorException("Vector field must have dimensions.") + raise VectorStoreOperationException("Vector field must have dimensions.") vectors_config = VectorParams( size=vector.dimensions, distance=DISTANCE_FUNCTION_MAP[vector.distance_function or "default"], diff --git a/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py b/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py index f0fe5000cd8b..0fd00bc59532 100644 --- a/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py +++ b/python/semantic_kernel/connectors/memory/qdrant/qdrant_store.py @@ -16,7 +16,7 @@ from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection from 
semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage import VectorStore -from semantic_kernel.exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent @@ -94,14 +94,14 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant settings.", ex) from ex if APP_INFO: kwargs.setdefault("metadata", {}) kwargs["metadata"] = prepend_semantic_kernel_to_user_agent(kwargs["metadata"]) try: client = AsyncQdrantClient(**settings.model_dump(exclude_none=True), **kwargs) except ValueError as ex: - raise MemoryConnectorInitializationError("Failed to create Qdrant client.", ex) from ex + raise VectorStoreInitializationException("Failed to create Qdrant client.", ex) from ex super().__init__(qdrant_client=client) def get_collection( diff --git a/python/semantic_kernel/connectors/memory/redis/redis_collection.py b/python/semantic_kernel/connectors/memory/redis/redis_collection.py index f4f59a1ecded..73f3a2bd4dea 100644 --- a/python/semantic_kernel/connectors/memory/redis/redis_collection.py +++ b/python/semantic_kernel/connectors/memory/redis/redis_collection.py @@ -47,11 +47,12 @@ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreInitializationException, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException, VectorSearchOptionsException from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.list_handler import desync_list @@ -111,7 +112,7 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Redis settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Redis settings.", ex) from ex super().__init__( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -149,7 +150,7 @@ async def create_collection(self, **kwargs) -> None: fields, definition=index_definition, **kwargs ) return - raise MemoryConnectorException("Invalid index type supplied.") + raise VectorStoreOperationException("Invalid index type supplied.") fields = data_model_definition_to_redis_fields(self.data_model_definition, self.collection_type) index_definition = IndexDefinition( prefix=f"{self.collection_name}:", index_type=INDEX_TYPE_MAP[self.collection_type] diff --git a/python/semantic_kernel/connectors/memory/redis/redis_store.py b/python/semantic_kernel/connectors/memory/redis/redis_store.py index 4e0629ba4271..8764027e0cd8 100644 --- a/python/semantic_kernel/connectors/memory/redis/redis_store.py +++ 
b/python/semantic_kernel/connectors/memory/redis/redis_store.py @@ -18,7 +18,7 @@ from semantic_kernel.connectors.memory.redis.utils import RedisWrapper from semantic_kernel.data.record_definition import VectorStoreRecordDefinition from semantic_kernel.data.vector_storage import VectorStore, VectorStoreRecordCollection -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException from semantic_kernel.utils.experimental_decorator import experimental_class logger: logging.Logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def __init__( env_file_encoding=env_file_encoding, ) except ValidationError as ex: - raise MemoryConnectorInitializationError("Failed to create Redis settings.", ex) from ex + raise VectorStoreInitializationException("Failed to create Redis settings.", ex) from ex super().__init__(redis_database=RedisWrapper.from_url(redis_settings.connection_string.get_secret_value())) @override diff --git a/python/semantic_kernel/connectors/memory/redis/utils.py b/python/semantic_kernel/connectors/memory/redis/utils.py index 68054281d156..35305c4118d8 100644 --- a/python/semantic_kernel/connectors/memory/redis/utils.py +++ b/python/semantic_kernel/connectors/memory/redis/utils.py @@ -27,7 +27,7 @@ VectorStoreRecordVectorField, ) from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter -from semantic_kernel.exceptions.search_exceptions import VectorSearchOptionsException +from semantic_kernel.exceptions import VectorSearchOptionsException from semantic_kernel.memory.memory_record import MemoryRecord diff --git a/python/semantic_kernel/connectors/memory/weaviate/utils.py b/python/semantic_kernel/connectors/memory/weaviate/utils.py index 363d4dae49fe..3f295d1076b7 100644 --- a/python/semantic_kernel/connectors/memory/weaviate/utils.py +++ b/python/semantic_kernel/connectors/memory/weaviate/utils.py @@ -18,7 +18,7 @@ VectorStoreRecordVectorField, ) from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter -from semantic_kernel.exceptions.memory_connector_exceptions import ( +from semantic_kernel.exceptions import ( VectorStoreModelDeserializationException, ) diff --git a/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py b/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py index 869d80e4df22..947188a3c819 100644 --- a/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py +++ b/python/semantic_kernel/connectors/memory/weaviate/weaviate_collection.py @@ -41,10 +41,11 @@ from semantic_kernel.data.vector_search.vectorizable_text_search import VectorizableTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin from semantic_kernel.exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, + VectorSearchExecutionException, + VectorStoreInitializationException, + VectorStoreModelValidationError, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelValidationError from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class @@ -140,7 +141,7 @@ def __init__( " a local Weaviate instance, or the client embedding options.", ) except Exception as e: - raise MemoryConnectorInitializationError(f"Failed to initialize Weaviate client: {e}") + raise 
VectorStoreInitializationException(f"Failed to initialize Weaviate client: {e}") super().__init__( data_model_type=data_model_type, @@ -170,47 +171,24 @@ async def _inner_upsert( **kwargs: Any, ) -> Sequence[TKey]: assert all([isinstance(record, DataObject) for record in records]) # nosec - - try: - collection: CollectionAsync = self.async_client.collections.get(self.collection_name) - response = await collection.data.insert_many(records) - except WeaviateClosedClientError as ex: - raise MemoryConnectorException( - "Client is closed, please use the context manager or self.async_client.connect." - ) from ex - except Exception as ex: - raise MemoryConnectorException(f"Failed to upsert records: {ex}") - + collection: CollectionAsync = self.async_client.collections.get(self.collection_name) + response = await collection.data.insert_many(records) return [str(v) for _, v in response.uuids.items()] @override async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any] | None: - try: - collection: CollectionAsync = self.async_client.collections.get(self.collection_name) - result = await collection.query.fetch_objects( - filters=Filter.any_of([Filter.by_id().equal(key) for key in keys]), - include_vector=kwargs.get("include_vectors", False), - ) + collection: CollectionAsync = self.async_client.collections.get(self.collection_name) + result = await collection.query.fetch_objects( + filters=Filter.any_of([Filter.by_id().equal(key) for key in keys]), + include_vector=kwargs.get("include_vectors", False), + ) - return result.objects - except WeaviateClosedClientError as ex: - raise MemoryConnectorException( - "Client is closed, please use the context manager or self.async_client.connect." - ) from ex - except Exception as ex: - raise MemoryConnectorException(f"Failed to get records: {ex}") + return result.objects @override async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: - try: - collection: CollectionAsync = self.async_client.collections.get(self.collection_name) - await collection.data.delete_many(where=Filter.any_of([Filter.by_id().equal(key) for key in keys])) - except WeaviateClosedClientError as ex: - raise MemoryConnectorException( - "Client is closed, please use the context manager or self.async_client.connect." 
- ) from ex - except Exception as ex: - raise MemoryConnectorException(f"Failed to delete records: {ex}") + collection: CollectionAsync = self.async_client.collections.get(self.collection_name) + await collection.data.delete_many(where=Filter.any_of([Filter.by_id().equal(key) for key in keys])) @@ -236,7 +214,7 @@ async def _inner_search( elif vector: results = await self._inner_vectorized_search(collection, vector, vector_field, args) else: - raise MemoryConnectorException("No search criteria provided.") + raise VectorSearchExecutionException("No search criteria provided.") return KernelSearchResults( results=self._get_vector_search_results_from_results(results.objects), total_count=len(results.objects) @@ -252,7 +230,7 @@ async def _inner_vector_text_search( **args, ) except Exception as ex: - raise MemoryConnectorException(f"Failed searching using a text: {ex}") from ex + raise VectorSearchExecutionException(f"Failed searching using a text: {ex}") from ex async def _inner_vectorizable_text_search( self, @@ -262,7 +240,7 @@ async def _inner_vectorizable_text_search( args: dict[str, Any], ) -> Any: if self.named_vectors and not vector_field: - raise MemoryConnectorException( + raise VectorSearchExecutionException( "Vectorizable text search requires a vector field to be specified in the options." ) try: @@ -281,7 +259,7 @@ async def _inner_vectorizable_text_search( "Alternatively you could use an existing collection that has a vectorizer setup." "See also: https://weaviate.io/developers/weaviate/manage-data/collections#create-a-collection" ) - raise MemoryConnectorException(f"Failed searching using a vectorizable text: {ex}") from ex + raise VectorSearchExecutionException(f"Failed searching using a vectorizable text: {ex}") from ex async def _inner_vectorized_search( self, @@ -291,7 +269,7 @@ async def _inner_vectorized_search( args: dict[str, Any], ) -> Any: if self.named_vectors and not vector_field: - raise MemoryConnectorException( + raise VectorSearchExecutionException( "Vectorizable text search requires a vector field to be specified in the options." ) try: @@ -302,11 +280,11 @@ async def _inner_vectorized_search( **args, ) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorSearchExecutionException( "Client is closed, please use the context manager or self.async_client.connect." ) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed searching using a vector: {ex}") from ex + raise VectorSearchExecutionException(f"Failed searching using a vector: {ex}") from ex def _get_record_from_result(self, result: Any) -> Any: """Get the record from the returned search result.""" @@ -361,18 +339,18 @@ async def create_collection(self, **kwargs) -> None: Make sure to check the arguments of that method for the specifications. """ if not self.named_vectors and len(self.data_model_definition.vector_field_names) != 1: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Named vectors must be enabled if there is not exactly one vector field in the data model definition." ) if kwargs: try: await self.async_client.collections.create(**kwargs) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect."
) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex + raise VectorStoreOperationException(f"Failed to create collection: {ex}") from ex try: await self.async_client.collections.create( name=self.collection_name, @@ -387,11 +365,11 @@ async def create_collection(self, **kwargs) -> None: else None, ) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." ) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex + raise VectorStoreOperationException(f"Failed to create collection: {ex}") from ex @override async def does_collection_exist(self, **kwargs) -> bool: @@ -406,11 +384,11 @@ async def does_collection_exist(self, **kwargs) -> bool: try: return await self.async_client.collections.exists(self.collection_name) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." ) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to check if collection exists: {ex}") from ex + raise VectorStoreOperationException(f"Failed to check if collection exists: {ex}") from ex @override async def delete_collection(self, **kwargs) -> None: @@ -422,11 +400,11 @@ async def delete_collection(self, **kwargs) -> None: try: await self.async_client.collections.delete(self.collection_name) except WeaviateClosedClientError as ex: - raise MemoryConnectorException( + raise VectorStoreOperationException( "Client is closed, please use the context manager or self.async_client.connect." 
) from ex except Exception as ex: - raise MemoryConnectorException(f"Failed to delete collection: {ex}") from ex + raise VectorStoreOperationException(f"Failed to delete collection: {ex}") from ex @override async def __aenter__(self) -> "WeaviateCollection": diff --git a/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py b/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py index 82c05c769361..9d57a6e588c2 100644 --- a/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py +++ b/python/semantic_kernel/connectors/memory/weaviate/weaviate_store.py @@ -19,9 +19,9 @@ from semantic_kernel.data.vector_storage.vector_store import VectorStore from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection from semantic_kernel.exceptions import ( - MemoryConnectorConnectionException, - MemoryConnectorException, - MemoryConnectorInitializationError, + VectorStoreException, + VectorStoreInitializationException, + VectorStoreOperationException, ) from semantic_kernel.utils.experimental_decorator import experimental_class @@ -95,7 +95,7 @@ def __init__( " a local Weaviate instance, or the client embedding options.", ) except Exception as e: - raise MemoryConnectorInitializationError(f"Failed to initialize Weaviate client: {e}") + raise VectorStoreInitializationException(f"Failed to initialize Weaviate client: {e}") super().__init__(async_client=async_client, managed_client=managed_client) @@ -124,7 +124,7 @@ async def list_collection_names(self, **kwargs) -> Sequence[str]: collections = await self.async_client.collections.list_all() return [collection.name for collection in collections] except Exception as e: - raise MemoryConnectorException(f"Failed to list Weaviate collections: {e}") + raise VectorStoreOperationException(f"Failed to list Weaviate collections: {e}") @override async def __aenter__(self) -> "VectorStore": @@ -133,7 +133,7 @@ async def __aenter__(self) -> "VectorStore": try: await self.async_client.connect() except WeaviateConnectionError as exc: - raise MemoryConnectorConnectionException("Weaviate client cannot connect.") from exc + raise VectorStoreException("Weaviate client cannot connect.") from exc return self @override diff --git a/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py b/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py index 6d5a89719292..d9244076bcbb 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py +++ b/python/semantic_kernel/data/record_definition/vector_store_model_decorator.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. import logging -from inspect import _empty, signature -from types import NoneType +from inspect import Parameter, _empty, signature +from types import MappingProxyType, NoneType from typing import Any from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition @@ -10,7 +10,7 @@ VectorStoreRecordField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException from semantic_kernel.utils.experimental_decorator import experimental_function logger = logging.getLogger(__name__) @@ -27,6 +27,7 @@ def vectorstoremodel( - The class must have at least one field with an annotation of type VectorStoreRecordKeyField, VectorStoreRecordDataField or VectorStoreRecordVectorField.
- The class must have exactly one field with the VectorStoreRecordKeyField annotation. + - If a field has multiple VectorStoreRecordKeyField annotations, the first one found is used. Optionally, when there are VectorStoreRecordDataFields that specify an embedding property name, there must be a corresponding VectorStoreRecordVectorField with the same name. @@ -59,63 +60,66 @@ def wrap(cls: Any): return wrap(cls) -def _parse_signature_to_definition(parameters) -> VectorStoreRecordDefinition: +def _parse_signature_to_definition(parameters: MappingProxyType[str, Parameter]) -> VectorStoreRecordDefinition: if len(parameters) == 0: raise VectorStoreModelException( "There must be at least one field in the datamodel. If you are using this with a @dataclass, " "you might have inverted the order of the decorators, the vectorstoremodel decorator should be the top one." ) - fields: dict[str, VectorStoreRecordField] = {} - for field in parameters.values(): - annotation = field.annotation - # check first if there are any annotations - if not hasattr(annotation, "__metadata__"): - if field._default is _empty: - raise VectorStoreModelException( - "Fields that do not have a VectorStoreRecord* annotation must have a default value." - ) - logger.info( - f'Field "{field.name}" does not have a VectorStoreRecord* annotation, will not be part of the record.' - ) - continue - property_type = annotation.__origin__ + return VectorStoreRecordDefinition( + fields={ + field.name: field for field in [_parse_parameter_to_field(field) for field in parameters.values()] if field + } + ) + + +def _parse_parameter_to_field(field: Parameter) -> VectorStoreRecordField | None: + for field_annotation in getattr(field.annotation, "__metadata__", []): + if isinstance(field_annotation, VectorStoreRecordField): + return _parse_vector_store_record_field_instance(field_annotation, field) + if isinstance(field_annotation, type(VectorStoreRecordField)): + return _parse_vector_store_record_field_class(field_annotation, field) + + # This means there are no annotations, or only non-VectorStoreRecordField annotations + if field.default is _empty: + raise VectorStoreModelException( + "Fields that do not have a VectorStoreRecord* annotation must have a default value." + ) + logger.debug( + f'Field "{field.name}" does not have a VectorStoreRecordField annotation, will not be part of the record.'
+ ) + return None + + +def _parse_vector_store_record_field_instance( + record_field: VectorStoreRecordField, field: Parameter +) -> VectorStoreRecordField: + if not record_field.name or record_field.name != field.name: + record_field.name = field.name + if not record_field.property_type: + property_type = field.annotation.__origin__ if (args := getattr(property_type, "__args__", None)) and NoneType in args and len(args) == 2: property_type = args[0] - metadata = annotation.__metadata__ - field_type = None - for item in metadata: - if isinstance(item, VectorStoreRecordField): - field_type = item - if not field_type.name or field_type.name != field.name: - field_type.name = field.name - if not field_type.property_type: - if hasattr(property_type, "__args__"): - if isinstance(item, VectorStoreRecordVectorField): - field_type.property_type = property_type.__args__[0].__name__ - elif property_type.__name__ == "list": - field_type.property_type = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" - else: - field_type.property_type = property_type.__name__ - - else: - field_type.property_type = property_type.__name__ - elif isinstance(item, type(VectorStoreRecordField)): - if hasattr(property_type, "__args__") and property_type.__name__ == "list": - property_type_name = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" - else: - property_type_name = property_type.__name__ - field_type = item(name=field.name, property_type=property_type_name) - if not field_type: - if field._default is _empty: - raise VectorStoreModelException( - "Fields that do not have a VectorStoreRecord* annotation must have a default value." - ) - logger.debug( - f'Field "{field.name}" does not have a VectorStoreRecordField ' - "annotation, will not be part of the record." 
- ) - continue - # field name is set either when not None or by instantiating a new field - assert field_type.name is not None # nosec - fields[field_type.name] = field_type - return VectorStoreRecordDefinition(fields=fields) + if hasattr(property_type, "__args__"): + if isinstance(record_field, VectorStoreRecordVectorField): + record_field.property_type = property_type.__args__[0].__name__ + elif property_type.__name__ == "list": + record_field.property_type = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" + else: + record_field.property_type = property_type.__name__ + else: + record_field.property_type = property_type.__name__ + return record_field + + +def _parse_vector_store_record_field_class( + field_type: type[VectorStoreRecordField], field: Parameter +) -> VectorStoreRecordField: + property_type = field.annotation.__origin__ + if (args := getattr(property_type, "__args__", None)) and NoneType in args and len(args) == 2: + property_type = args[0] + if hasattr(property_type, "__args__") and property_type.__name__ == "list": + property_type_name = f"{property_type.__name__}[{property_type.__args__[0].__name__}]" + else: + property_type_name = property_type.__name__ + return field_type(name=field.name, property_type=property_type_name) diff --git a/python/semantic_kernel/data/record_definition/vector_store_model_definition.py b/python/semantic_kernel/data/record_definition/vector_store_model_definition.py index 5d7e01faa42b..adc993ff22e3 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_model_definition.py +++ b/python/semantic_kernel/data/record_definition/vector_store_model_definition.py @@ -4,10 +4,10 @@ from typing import TypeVar from semantic_kernel.data.record_definition.vector_store_model_protocols import ( - DeserializeProtocol, - FromDictProtocol, - SerializeProtocol, - ToDictProtocol, + DeserializeFunctionProtocol, + FromDictFunctionProtocol, + SerializeFunctionProtocol, + ToDictFunctionProtocol, ) from semantic_kernel.data.record_definition.vector_store_record_fields import ( VectorStoreRecordDataField, @@ -15,7 +15,7 @@ VectorStoreRecordKeyField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException from semantic_kernel.utils.experimental_decorator import experimental_class VectorStoreRecordFields = TypeVar("VectorStoreRecordFields", bound=VectorStoreRecordField) @@ -40,10 +40,10 @@ class VectorStoreRecordDefinition: key_field_name: str = field(init=False) fields: FieldsType container_mode: bool = False - to_dict: ToDictProtocol | None = None - from_dict: FromDictProtocol | None = None - serialize: SerializeProtocol | None = None - deserialize: DeserializeProtocol | None = None + to_dict: ToDictFunctionProtocol | None = None + from_dict: FromDictFunctionProtocol | None = None + serialize: SerializeFunctionProtocol | None = None + deserialize: DeserializeFunctionProtocol | None = None @property def field_names(self) -> list[str]: @@ -126,7 +126,7 @@ def __post_init__(self): for name, value in self.fields.items(): if not name: raise VectorStoreModelException("Fields must have a name.") - if value.name is None: + if not value.name: value.name = name if ( isinstance(value, VectorStoreRecordDataField) diff --git a/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py b/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py index 452827857b48..0c83a131c965 
100644 --- a/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py +++ b/python/semantic_kernel/data/record_definition/vector_store_model_protocols.py @@ -10,8 +10,8 @@ @experimental_class @runtime_checkable -class VectorStoreModelFunctionSerdeProtocol(Protocol): - """Data model serialization and deserialization protocol. +class SerializeMethodProtocol(Protocol): + """Data model serialization protocol. This can optionally be implemented to allow single step serialization and deserialization for using your data model with a specific datastore. @@ -21,46 +21,21 @@ def serialize(self, **kwargs: Any) -> Any: """Serialize the object to the format required by the data store.""" ... # pragma: no cover - @classmethod - def deserialize(cls: type[TModel], obj: Any, **kwargs: Any) -> TModel: - """Deserialize the output of the data store to an object.""" - ... # pragma: no cover - @experimental_class @runtime_checkable -class VectorStoreModelPydanticProtocol(Protocol): - """Class used internally to make sure a datamodel has model_dump and model_validate.""" - - def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - """Serialize the object to the format required by the data store.""" - ... # pragma: no cover - - @classmethod - def model_validate(cls: type[TModel], *args: Any, **kwargs: Any) -> TModel: - """Deserialize the output of the data store to an object.""" - ... # pragma: no cover - - -@experimental_class -@runtime_checkable -class VectorStoreModelToDictFromDictProtocol(Protocol): - """Class used internally to check if a model has to_dict and from_dict methods.""" +class ToDictMethodProtocol(Protocol): + """Class used internally to check if a model has a to_dict method.""" def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Serialize the object to the format required by the data store.""" ... # pragma: no cover - @classmethod - def from_dict(cls: type[TModel], *args: Any, **kwargs: Any) -> TModel: - """Deserialize the output of the data store to an object.""" - ... # pragma: no cover - @experimental_class @runtime_checkable -class ToDictProtocol(Protocol): - """Protocol for to_dict method. +class ToDictFunctionProtocol(Protocol): + """Protocol for to_dict function. Args: record: The record to be serialized. @@ -75,8 +50,8 @@ def __call__(self, record: Any, **kwargs: Any) -> Sequence[dict[str, Any]]: ... @experimental_class @runtime_checkable -class FromDictProtocol(Protocol): - """Protocol for from_dict method. +class FromDictFunctionProtocol(Protocol): + """Protocol for from_dict function. Args: records: A list of dictionaries. @@ -91,8 +66,8 @@ def __call__(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Any: ... @experimental_class @runtime_checkable -class SerializeProtocol(Protocol): - """Protocol for serialize method. +class SerializeFunctionProtocol(Protocol): + """Protocol for serialize function. Args: record: The record to be serialized. @@ -108,8 +83,8 @@ def __call__(self, record: Any, **kwargs: Any) -> Any: ... # noqa: D102 @experimental_class @runtime_checkable -class DeserializeProtocol(Protocol): - """Protocol for deserialize method. +class DeserializeFunctionProtocol(Protocol): + """Protocol for deserialize function. Args: records: The serialized record directly from the store. 
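As a quick aside on the renames above: the `*MethodProtocol` classes describe optional methods on the data model itself, while the `*FunctionProtocol` classes describe standalone functions supplied through the record definition, and the reworked `_serialize_data_model_to_store_model` later in this diff dispatches on them with plain `isinstance` checks. A minimal sketch of that behavior (not part of the diff; the `BookRecord` class is hypothetical):

from typing import Any

from semantic_kernel.data.record_definition.vector_store_model_protocols import (
    SerializeMethodProtocol,
    ToDictMethodProtocol,
)


class BookRecord:
    """Hypothetical data model exposing both optional serde hooks."""

    def __init__(self, id: str, title: str) -> None:
        self.id = id
        self.title = title

    def serialize(self, **kwargs: Any) -> Any:
        # Single-step serialization straight into the store format.
        return {"id": self.id, "title": self.title}

    def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        # Generic dict form; vector (de)serialization stays with the collection.
        return {"id": self.id, "title": self.title}


book = BookRecord("1", "Semantic Kernel")
# Both checks pass because the protocols are @runtime_checkable and only
# test for the presence of the method, not its exact signature.
assert isinstance(book, SerializeMethodProtocol)
assert isinstance(book, ToDictMethodProtocol)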
diff --git a/python/semantic_kernel/data/record_definition/vector_store_record_fields.py b/python/semantic_kernel/data/record_definition/vector_store_record_fields.py index d3b782c49912..536482b1069d 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_record_fields.py +++ b/python/semantic_kernel/data/record_definition/vector_store_record_fields.py @@ -17,7 +17,7 @@ class VectorStoreRecordField(ABC): """Base class for all Vector Store Record Fields.""" - name: str | None = None + name: str = "" property_type: str | None = None diff --git a/python/semantic_kernel/data/record_definition/vector_store_record_utils.py b/python/semantic_kernel/data/record_definition/vector_store_record_utils.py index 96494ff6239b..00436fe8e199 100644 --- a/python/semantic_kernel/data/record_definition/vector_store_record_utils.py +++ b/python/semantic_kernel/data/record_definition/vector_store_record_utils.py @@ -8,7 +8,7 @@ VectorStoreRecordDataField, VectorStoreRecordVectorField, ) -from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelException +from semantic_kernel.exceptions import VectorStoreModelException from semantic_kernel.kernel_types import OneOrMany from semantic_kernel.utils.experimental_decorator import experimental_class diff --git a/python/semantic_kernel/data/search_filter.py b/python/semantic_kernel/data/search_filter.py index a6b844e141f8..4d0d84b5a7b9 100644 --- a/python/semantic_kernel/data/search_filter.py +++ b/python/semantic_kernel/data/search_filter.py @@ -4,9 +4,9 @@ from typing import TypeVar if sys.version_info >= (3, 11): - from typing import Self + from typing import Self # pragma: no cover else: - from typing_extensions import Self + from typing_extensions import Self # pragma: no cover from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo from semantic_kernel.data.filter_clauses.filter_clause_base import FilterClauseBase diff --git a/python/semantic_kernel/data/text_search/text_search.py b/python/semantic_kernel/data/text_search/text_search.py index ff7b6c416435..d40f7169786c 100644 --- a/python/semantic_kernel/data/text_search/text_search.py +++ b/python/semantic_kernel/data/text_search/text_search.py @@ -294,7 +294,7 @@ async def _map_results( return [self._default_map_to_string(result) async for result in results.results] @staticmethod - def _default_map_to_string(result: Any) -> str: + def _default_map_to_string(result: BaseModel | object) -> str: """Default mapping function for text search results.""" if isinstance(result, BaseModel): return result.model_dump_json() diff --git a/python/semantic_kernel/data/text_search/utils.py b/python/semantic_kernel/data/text_search/utils.py index eb60f87b3d82..44e432f1ec36 100644 --- a/python/semantic_kernel/data/text_search/utils.py +++ b/python/semantic_kernel/data/text_search/utils.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +import logging +from contextlib import suppress from typing import TYPE_CHECKING, Any, Protocol from pydantic import ValidationError @@ -8,6 +10,8 @@ from semantic_kernel.data.search_options import SearchOptions from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata +logger = logging.getLogger(__name__) + class OptionsUpdateFunctionType(Protocol): """Type definition for the options update function in Text Search.""" @@ -20,7 +24,7 @@ def __call__( **kwargs: Any, ) -> tuple[str, "SearchOptions"]: """Signature of the function.""" - ... + ... 
# pragma: no cover def create_options( @@ -44,30 +48,24 @@ def create_options( SearchOptions: The options. """ + new_options = options_class() if options: if not isinstance(options, options_class): + inputs = None try: - # Validate the options in one go - new_options = options_class.model_validate( - options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True), - ) - except ValidationError: - # if that fails, go one by one - new_options = options_class() - for key, value in options.model_dump( - exclude_none=True, exclude_defaults=True, exclude_unset=True - ).items(): - setattr(new_options, key, value) + inputs = options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True) + except Exception: + logger.warning("Options are not valid. Creating new options.") + if inputs: + new_options = options_class.model_validate(inputs) else: new_options = options for key, value in kwargs.items(): if key in new_options.model_fields: setattr(new_options, key, value) else: - try: + with suppress(ValidationError): new_options = options_class(**kwargs) - except ValidationError: - new_options = options_class() return new_options diff --git a/python/semantic_kernel/data/text_search/vector_store_text_search.py b/python/semantic_kernel/data/text_search/vector_store_text_search.py index dd2242650132..688e6f763e71 100644 --- a/python/semantic_kernel/data/text_search/vector_store_text_search.py +++ b/python/semantic_kernel/data/text_search/vector_store_text_search.py @@ -14,7 +14,7 @@ from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin from semantic_kernel.data.vector_search.vectorizable_text_search import VectorizableTextSearchMixin from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin -from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreTextSearchValidationError +from semantic_kernel.exceptions import VectorSearchExecutionException, VectorStoreInitializationException from semantic_kernel.kernel_pydantic import KernelBaseModel if TYPE_CHECKING: @@ -59,11 +59,11 @@ def _validate_stores(cls, data: dict[str, Any]) -> dict[str, Any]: and not data.get("vectorized_search") and not data.get("vector_text_search") ): - raise VectorStoreTextSearchValidationError( + raise VectorStoreInitializationException( "At least one of vectorizable_text_search, vectorized_search or vector_text_search is required." 
) if data.get("vectorized_search") and not data.get("embedding_service"): - raise VectorStoreTextSearchValidationError("embedding_service is required when using vectorized_search.") + raise VectorStoreInitializationException("embedding_service is required when using vectorized_search.") return data @classmethod diff --git a/python/semantic_kernel/data/vector_search/vector_search.py b/python/semantic_kernel/data/vector_search/vector_search.py index 935a3b02464f..166676136ef9 100644 --- a/python/semantic_kernel/data/vector_search/vector_search.py +++ b/python/semantic_kernel/data/vector_search/vector_search.py @@ -10,6 +10,7 @@ from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection +from semantic_kernel.exceptions import VectorStoreModelDeserializationException from semantic_kernel.utils.experimental_decorator import experimental_class from semantic_kernel.utils.list_handler import desync_list @@ -57,6 +58,9 @@ async def _inner_search( The implementation of this method must deal with the possibility that multiple search contents are provided, and should handle them in a way that makes sense for that particular store. + The public methods will catch and reraise the three exceptions mentioned below; others are caught and turned + into a VectorSearchExecutionException. + Args: options: The search options, can be None. search_text: The text to search for, optional. @@ -67,6 +71,11 @@ Returns: The search results, wrapped in a KernelSearchResults object. + Raises: + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. +
@@ -106,9 +115,16 @@ async def _get_vector_search_results_from_results( if isinstance(results, Sequence): results = desync_list(results) async for result in results: - record = self.deserialize( - self._get_record_from_result(result), include_vectors=options.include_vectors if options else True - ) + try: + record = self.deserialize( + self._get_record_from_result(result), include_vectors=options.include_vectors if options else True + ) + except VectorStoreModelDeserializationException: + raise + except Exception as exc: + raise VectorStoreModelDeserializationException( + f"An error occurred while deserializing the record: {exc}" + ) from exc score = self._get_score_from_result(result) if record: # single records are always returned as single records by the deserializer diff --git a/python/semantic_kernel/data/vector_search/vector_search_filter.py b/python/semantic_kernel/data/vector_search/vector_search_filter.py index c7bd99d6797e..6944fe69ba4d 100644 --- a/python/semantic_kernel/data/vector_search/vector_search_filter.py +++ b/python/semantic_kernel/data/vector_search/vector_search_filter.py @@ -3,9 +3,9 @@ import sys if sys.version_info >= (3, 11): - from typing import Self + from typing import Self # pragma: no cover else: - from typing_extensions import Self + from typing_extensions import Self # pragma: no cover from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo from semantic_kernel.data.search_filter import SearchFilter diff --git a/python/semantic_kernel/data/vector_search/vector_text_search.py b/python/semantic_kernel/data/vector_search/vector_text_search.py index a5445c62e16c..f2a29b2908b8 100644 --- a/python/semantic_kernel/data/vector_search/vector_text_search.py +++ b/python/semantic_kernel/data/vector_search/vector_text_search.py @@ -5,7 +5,12 @@ from semantic_kernel.data.search_options import SearchOptions from semantic_kernel.data.text_search.utils import create_options -from semantic_kernel.exceptions import VectorStoreMixinException +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreMixinException, + VectorStoreModelDeserializationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -35,8 +40,10 @@ async def text_search( **kwargs: if options are not set, this is used to create them. Raises: - VectorSearchOptionsException: raised when the options given are not correct. - SearchResultEmptyError: raised when there are no results returned. + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase. 
""" from semantic_kernel.data.vector_search.vector_search import VectorSearchBase @@ -44,4 +51,9 @@ async def text_search( if not isinstance(self, VectorSearchBase): raise VectorStoreMixinException("This method can only be used in combination with the VectorSearchBase.") options = create_options(self.options_class, options, **kwargs) - return await self._inner_search(search_text=search_text, options=options) # type: ignore + try: + return await self._inner_search(search_text=search_text, options=options) # type: ignore + except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc diff --git a/python/semantic_kernel/data/vector_search/vectorizable_text_search.py b/python/semantic_kernel/data/vector_search/vectorizable_text_search.py index 76e653a5b053..9c5b882cf6f4 100644 --- a/python/semantic_kernel/data/vector_search/vectorizable_text_search.py +++ b/python/semantic_kernel/data/vector_search/vectorizable_text_search.py @@ -4,7 +4,12 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from semantic_kernel.data.text_search.utils import create_options -from semantic_kernel.exceptions import VectorStoreMixinException +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreMixinException, + VectorStoreModelDeserializationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -40,8 +45,10 @@ async def vectorizable_text_search( **kwargs: if options are not set, this is used to create them. Raises: - VectorSearchOptionsException: raised when the options given are not correct. - SearchResultEmptyError: raised when there are no results returned. + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase. 
""" from semantic_kernel.data.vector_search.vector_search import VectorSearchBase @@ -49,4 +56,9 @@ async def vectorizable_text_search( if not isinstance(self, VectorSearchBase): raise VectorStoreMixinException("This method can only be used in combination with the VectorSearchBase.") options = create_options(self.options_class, options, **kwargs) - return await self._inner_search(vectorizable_text=vectorizable_text, options=options) # type: ignore + try: + return await self._inner_search(vectorizable_text=vectorizable_text, options=options) # type: ignore + except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc diff --git a/python/semantic_kernel/data/vector_search/vectorized_search.py b/python/semantic_kernel/data/vector_search/vectorized_search.py index ae170840eb8a..1b3e5aa25f9e 100644 --- a/python/semantic_kernel/data/vector_search/vectorized_search.py +++ b/python/semantic_kernel/data/vector_search/vectorized_search.py @@ -4,7 +4,12 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from semantic_kernel.data.text_search.utils import create_options -from semantic_kernel.exceptions import VectorStoreMixinException +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorSearchOptionsException, + VectorStoreMixinException, + VectorStoreModelDeserializationException, +) from semantic_kernel.utils.experimental_decorator import experimental_class if TYPE_CHECKING: @@ -35,8 +40,10 @@ async def vectorized_search( **kwargs: if options are not set, this is used to create them. Raises: - VectorSearchOptionsException: raised when the options given are not correct. - SearchResultEmptyError: raised when there are no results returned. + VectorSearchExecutionException: If an error occurs during the search. + VectorStoreModelDeserializationException: If an error occurs during deserialization. + VectorSearchOptionsException: If the search options are invalid. + VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase. """ from semantic_kernel.data.vector_search.vector_search import VectorSearchBase @@ -44,4 +51,9 @@ async def vectorized_search( if not isinstance(self, VectorSearchBase): raise VectorStoreMixinException("This method can only be used in combination with the VectorSearchBase.") options = create_options(self.options_class, options, **kwargs) - return await self._inner_search(vector=vector, options=options) # type: ignore + try: + return await self._inner_search(vector=vector, options=options) # type: ignore + except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException): + raise # pragma: no cover + except Exception as exc: + raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc diff --git a/python/semantic_kernel/data/vector_storage/vector_store.py b/python/semantic_kernel/data/vector_storage/vector_store.py index 5608e6ec174d..796973a63854 100644 --- a/python/semantic_kernel/data/vector_storage/vector_store.py +++ b/python/semantic_kernel/data/vector_storage/vector_store.py @@ -48,4 +48,4 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: If the client is passed in the constructor, it should not be closed, in that case the managed_client should be set to False. 
""" - pass + pass # pragma: no cover diff --git a/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py b/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py index 5c239db849b6..2774ebfb2fd8 100644 --- a/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py +++ b/python/semantic_kernel/data/vector_storage/vector_store_record_collection.py @@ -7,14 +7,18 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence from typing import Any, ClassVar, Generic, TypeVar -from pydantic import model_validator +from pydantic import BaseModel, model_validator from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, +from semantic_kernel.data.record_definition.vector_store_model_protocols import ( + SerializeMethodProtocol, + ToDictMethodProtocol, +) +from semantic_kernel.exceptions import ( VectorStoreModelDeserializationException, VectorStoreModelSerializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) from semantic_kernel.kernel_pydantic import KernelBaseModel from semantic_kernel.kernel_types import OneOrMany @@ -96,6 +100,15 @@ async def _inner_upsert( Returns: The keys of the upserted records. + + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. + """ ... # pragma: no cover @@ -109,6 +122,14 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any Returns: The records from the store, not deserialized. + + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. """ ... # pragma: no cover @@ -119,6 +140,14 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: Args: keys: The keys. **kwargs: Additional arguments. + + Raises: + Exception: If an error occurs during the upsert. + There is no need to catch and parse exceptions in the inner functions, + they are handled by the public methods. + The only exception is raises exceptions yourself, such as a ValueError. + This is then caught and turned into the relevant exception by the public method. + This setup promotes a limited depth of the stack trace. """ ... # pragma: no cover @@ -184,17 +213,39 @@ async def create_collection_if_not_exists(self, **kwargs: Any) -> bool: @abstractmethod async def create_collection(self, **kwargs: Any) -> None: - """Create the collection in the service.""" + """Create the collection in the service. + + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different then the `_inner_x` methods, as this is a public method. + + """ ... 
# pragma: no cover @abstractmethod async def does_collection_exist(self, **kwargs: Any) -> bool: - """Check if the collection exists.""" + """Check if the collection exists. + + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different from the `_inner_x` methods, as this is a public method. + """ ... # pragma: no cover @abstractmethod async def delete_collection(self, **kwargs: Any) -> None: - """Delete the collection.""" + """Delete the collection. + + This should be overridden by the child class. + + Raises: + Make sure the implementation of this function raises relevant exceptions with good descriptions. + This is different from the `_inner_x` methods, as this is a public method. + """ ... # pragma: no cover # region Public Methods @@ -223,19 +274,24 @@ async def upsert( Returns: The key of the upserted record or a list of keys, when a container type is used. + + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. + VectorStoreOperationException: If an error occurs during upserting. """ if embedding_generation_function: record = await embedding_generation_function(record, self.data_model_type, self.data_model_definition) try: data = self.serialize(record) - except Exception as exc: - raise MemoryConnectorException(f"Error serializing record: {exc}") from exc + # the serialize method will parse any exception into a VectorStoreModelSerializationException + except VectorStoreModelSerializationException: + raise try: results = await self._inner_upsert(data if isinstance(data, Sequence) else [data], **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error upserting record: {exc}") from exc + raise VectorStoreOperationException(f"Error upserting record: {exc}") from exc if self._container_mode: return results @@ -266,19 +322,24 @@ async def upsert_batch( Returns: Sequence[TKey]: The keys of the upserted records, this is always a list, corresponds to the input or the items in the container. + + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. + VectorStoreOperationException: If an error occurs during upserting. """ if embedding_generation_function: records = await embedding_generation_function(records, self.data_model_type, self.data_model_definition) try: data = self.serialize(records) - except Exception as exc: - raise MemoryConnectorException(f"Error serializing records: {exc}") from exc + # the serialize method will parse any exception into a VectorStoreModelSerializationException + except VectorStoreModelSerializationException: + raise try: return await self._inner_upsert(data, **kwargs) # type: ignore except Exception as exc: - raise MemoryConnectorException(f"Error upserting records: {exc}") from exc + raise VectorStoreOperationException(f"Error upserting records: {exc}") from exc async def get(self, key: TKey, include_vectors: bool = True, **kwargs: Any) -> TModel | None: """Get a record if the key exists. @@ -293,19 +354,24 @@ async def get(self, key: TKey, include_vectors: bool = True, **kwargs: Any) -> T Returns: TModel: The record. None if the key does not exist. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization.
""" try: records = await self._inner_get([key], include_vectors=include_vectors, **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error getting record: {exc}") from exc + raise VectorStoreOperationException(f"Error getting record: {exc}") from exc if not records: return None try: model_records = self.deserialize(records[0], **kwargs) - except Exception as exc: - raise MemoryConnectorException(f"Error deserializing record: {exc}") from exc + # the deserialize method will parse any exception into a VectorStoreModelDeserializationException + except VectorStoreModelDeserializationException: + raise # there are many code paths within the deserialize method, some supplied by the developer, # and so depending on what is used, @@ -316,7 +382,9 @@ async def get(self, key: TKey, include_vectors: bool = True, **kwargs: Any) -> T return model_records if len(model_records) == 1: return model_records[0] - raise MemoryConnectorException(f"Error deserializing record, multiple records returned: {model_records}") + raise VectorStoreModelDeserializationException( + f"Error deserializing record, multiple records returned: {model_records}" + ) async def get_batch( self, keys: Sequence[TKey], include_vectors: bool = True, **kwargs: Any @@ -333,19 +401,24 @@ async def get_batch( Returns: The records, either a list of TModel or the container type. + + Raises: + VectorStoreOperationException: If an error occurs during the get. + VectorStoreModelDeserializationException: If an error occurs during deserialization. """ try: records = await self._inner_get(keys, include_vectors=include_vectors, **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error getting records: {exc}") from exc + raise VectorStoreOperationException(f"Error getting records: {exc}") from exc if not records: return None try: return self.deserialize(records, keys=keys, **kwargs) - except Exception as exc: - raise MemoryConnectorException(f"Error deserializing record: {exc}") from exc + # the deserialize method will parse any exception into a VectorStoreModelDeserializationException + except VectorStoreModelDeserializationException: + raise async def delete(self, key: TKey, **kwargs: Any) -> None: """Delete a record. @@ -354,12 +427,12 @@ async def delete(self, key: TKey, **kwargs: Any) -> None: key: The key. **kwargs: Additional arguments. Exceptions: - MemoryConnectorException: If an error occurs during deletion or the record does not exist. + VectorStoreOperationException: If an error occurs during deletion or the record does not exist. """ try: await self._inner_delete([key], **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error deleting record: {exc}") from exc + raise VectorStoreOperationException(f"Error deleting record: {exc}") from exc async def delete_batch(self, keys: Sequence[TKey], **kwargs: Any) -> None: """Delete a batch of records. @@ -370,14 +443,14 @@ async def delete_batch(self, keys: Sequence[TKey], **kwargs: Any) -> None: keys: The keys. **kwargs: Additional arguments. Exceptions: - MemoryConnectorException: If an error occurs during deletion or a record does not exist. + VectorStoreOperationException: If an error occurs during deletion or a record does not exist. 
""" try: await self._inner_delete(keys, **kwargs) except Exception as exc: - raise MemoryConnectorException(f"Error deleting records: {exc}") from exc + raise VectorStoreOperationException(f"Error deleting records: {exc}") from exc - # region Internal Serialization methods + # region Serialization methods def serialize(self, records: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any]: """Serialize the data model to the store model. @@ -391,45 +464,31 @@ def serialize(self, records: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any] If overriding this method, make sure to first try to serialize the data model to the store model, before doing the store specific version, the user supplied version should have precedence. - """ - if serialized := self._serialize_data_model_to_store_model(records): - return serialized - - if isinstance(records, Sequence): - dict_records = [self._serialize_data_model_to_dict(rec) for rec in records] - return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore - - dict_records = self._serialize_data_model_to_dict(records) # type: ignore - if isinstance(dict_records, Sequence): - # most likely this is a container, so we return all records as a list - # can also be a single record, but the to_dict returns a list - # hence we will treat it as a container. - return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore - # this case is single record in, single record out - return self._serialize_dicts_to_store_models([dict_records], **kwargs)[0] - def deserialize(self, records: OneOrMany[Any | dict[str, Any]], **kwargs: Any) -> OneOrMany[TModel] | None: - """Deserialize the store model to the data model. + Raises: + VectorStoreModelSerializationException: If an error occurs during serialization. - This method follows the following steps: - 1. Check if the data model has a deserialize method. - Use that method to deserialize and return the result. - 2. Deserialize the store model to a dict, using the store specific method. - 3. Convert the dict to the data model, using the data model specific method. """ - if deserialized := self._deserialize_store_model_to_data_model(records, **kwargs): - return deserialized - - if isinstance(records, Sequence): - dict_records = self._deserialize_store_models_to_dicts(records, **kwargs) - if self._container_mode: - return self._deserialize_dict_to_data_model(dict_records, **kwargs) - return [self._deserialize_dict_to_data_model(rec, **kwargs) for rec in dict_records] - - dict_record = self._deserialize_store_models_to_dicts([records], **kwargs)[0] - if not dict_record: - return None - return self._deserialize_dict_to_data_model(dict_record, **kwargs) + try: + if serialized := self._serialize_data_model_to_store_model(records): + return serialized + + if isinstance(records, Sequence): + dict_records = [self._serialize_data_model_to_dict(rec) for rec in records] + return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore + + dict_records = self._serialize_data_model_to_dict(records) # type: ignore + if isinstance(dict_records, Sequence): + # most likely this is a container, so we return all records as a list + # can also be a single record, but the to_dict returns a list + # hence we will treat it as a container. 
+ return self._serialize_dicts_to_store_models(dict_records, **kwargs) # type: ignore + # this case is single record in, single record out + return self._serialize_dicts_to_store_models([dict_records], **kwargs)[0] + except VectorStoreModelSerializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelSerializationException(f"Error serializing records: {exc}") from exc def _serialize_data_model_to_store_model(self, record: OneOrMany[TModel], **kwargs: Any) -> OneOrMany[Any] | None: """Serialize the data model to the store model. @@ -445,33 +504,9 @@ def _serialize_data_model_to_store_model(self, record: OneOrMany[TModel], **kwar return None return result if self.data_model_definition.serialize: - return self.data_model_definition.serialize(record, **kwargs) # type: ignore - if hasattr(record, "serialize"): - try: - return record.serialize(**kwargs) - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing record: {exc}") from exc - return None - - def _deserialize_store_model_to_data_model(self, record: OneOrMany[Any], **kwargs: Any) -> OneOrMany[TModel] | None: - """Deserialize the store model to the data model. - - This works when the data model has supplied a deserialize method, specific to a data source. - This uses a method called 'deserialize()' on the data model or part of the vector store record definition. - - The developer is responsible for correctly deserializing for the specific data source. - """ - if self.data_model_definition.deserialize: - if isinstance(record, Sequence): - return self.data_model_definition.deserialize(record, **kwargs) - return self.data_model_definition.deserialize([record], **kwargs) - try: - if hasattr(self.data_model_type, "deserialize"): - if isinstance(record, Sequence): - return [self.data_model_type.deserialize(rec, **kwargs) for rec in record] # type: ignore - return self.data_model_type.deserialize(record, **kwargs) # type: ignore - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error deserializing record: {exc}") from exc + return self.data_model_definition.serialize(record, **kwargs) + if isinstance(record, SerializeMethodProtocol): + return record.serialize(**kwargs) return None def _serialize_data_model_to_dict(self, record: TModel, **kwargs: Any) -> OneOrMany[dict[str, Any]]: @@ -483,44 +518,79 @@ def _serialize_data_model_to_dict(self, record: TModel, **kwargs: Any) -> OneOrM """ if self.data_model_definition.to_dict: return self.data_model_definition.to_dict(record, **kwargs) - if hasattr(record, "model_dump"): - try: - ret = record.model_dump() # type: ignore - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return ret - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - assert field.name is not None # nosec - ret[field.name] = field.serialize_function(ret[field.name]) - return ret - except Exception as exc: - raise VectorStoreModelSerializationException(f"Error serializing record: {exc}") from exc - if hasattr(record, "to_dict"): - try: - ret = record.to_dict() # type: ignore - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return ret - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - assert field.name is not None # nosec - ret[field.name] = field.serialize_function(ret[field.name]) - return ret - except Exception as exc: - raise 
VectorStoreModelSerializationException(f"Error serializing record: {exc}") from exc + if isinstance(record, BaseModel): + return self._serialize_vectors(record.model_dump()) + if isinstance(record, ToDictMethodProtocol): + return self._serialize_vectors(record.to_dict()) store_model = {} for field_name in self.data_model_definition.field_names: - try: - value = record[field_name] if isinstance(record, Mapping) else getattr(record, field_name) - if func := getattr(self.data_model_definition.fields[field_name], "serialize_function", None): - value = func(value) - store_model[field_name] = value - except (AttributeError, KeyError) as exc: - raise VectorStoreModelSerializationException( - f"Error serializing record, not able to get: {field_name}" - ) from exc + value = record[field_name] if isinstance(record, Mapping) else getattr(record, field_name) + if func := getattr(self.data_model_definition.fields[field_name], "serialize_function", None): + value = func(value) + store_model[field_name] = value return store_model + def _serialize_vectors(self, record: dict[str, Any]) -> dict[str, Any]: + for field in self.data_model_definition.vector_fields: + if field.serialize_function: + record[field.name or ""] = field.serialize_function(record[field.name or ""]) + return record + + # region Deserialization methods + + def deserialize(self, records: OneOrMany[Any | dict[str, Any]], **kwargs: Any) -> OneOrMany[TModel] | None: + """Deserialize the store model to the data model. + + This method follows the following steps: + 1. Check if the data model has a deserialize method. + Use that method to deserialize and return the result. + 2. Deserialize the store model to a dict, using the store specific method. + 3. Convert the dict to the data model, using the data model specific method. + + Raises: + VectorStoreModelDeserializationException: If an error occurs during deserialization. + """ + try: + if not records: + return None + if deserialized := self._deserialize_store_model_to_data_model(records, **kwargs): + return deserialized + + if isinstance(records, Sequence): + dict_records = self._deserialize_store_models_to_dicts(records, **kwargs) + return ( + self._deserialize_dict_to_data_model(dict_records, **kwargs) + if self._container_mode + else [self._deserialize_dict_to_data_model(rec, **kwargs) for rec in dict_records] + ) + + dict_record = self._deserialize_store_models_to_dicts([records], **kwargs)[0] + # regardless of mode, only 1 object is returned. + return self._deserialize_dict_to_data_model(dict_record, **kwargs) + except VectorStoreModelDeserializationException: + raise # pragma: no cover + except Exception as exc: + raise VectorStoreModelDeserializationException(f"Error deserializing records: {exc}") from exc + + def _deserialize_store_model_to_data_model(self, record: OneOrMany[Any], **kwargs: Any) -> OneOrMany[TModel] | None: + """Deserialize the store model to the data model. + + This works when the data model has supplied a deserialize method, specific to a data source. + This uses a method called 'deserialize()' on the data model or part of the vector store record definition. + + The developer is responsible for correctly deserializing for the specific data source. 
+ """ + if self.data_model_definition.deserialize: + if isinstance(record, Sequence): + return self.data_model_definition.deserialize(record, **kwargs) + return self.data_model_definition.deserialize([record], **kwargs) + if func := getattr(self.data_model_type, "deserialize", None): + if isinstance(record, Sequence): + return [func(rec, **kwargs) for rec in record] + return func(record, **kwargs) + return None + def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **kwargs: Any) -> TModel: """This function is used if no deserialize method is found on the data model. @@ -541,45 +611,32 @@ def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **k "Cannot deserialize multiple records to a single record unless you are using a container." ) record = record[0] - if hasattr(self.data_model_type, "model_validate"): - try: - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return self.data_model_type.model_validate(record) # type: ignore - if include_vectors: - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - record[field.name] = field.serialize_function(record[field.name]) # type: ignore - return self.data_model_type.model_validate(record) # type: ignore - except Exception as exc: - raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc - if hasattr(self.data_model_type, "from_dict"): - try: - if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields): - return self.data_model_type.from_dict(record) # type: ignore - if include_vectors: - for field in self.data_model_definition.vector_fields: - if field.serialize_function: - record[field.name] = field.serialize_function(record[field.name]) # type: ignore - return self.data_model_type.from_dict(record) # type: ignore - except Exception as exc: - raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc + if issubclass(self.data_model_type, BaseModel): + if include_vectors: + record = self._deserialize_vector(record) + return self.data_model_type.model_validate(record) # type: ignore + if func := getattr(self.data_model_type, "from_dict", None): + if include_vectors: + record = self._deserialize_vector(record) + return func(record) data_model_dict: dict[str, Any] = {} - for field_name in self.data_model_definition.fields: # type: ignore + for field_name in self.data_model_definition.fields: if not include_vectors and field_name in self.data_model_definition.vector_field_names: continue - try: - value = record[field_name] - if func := getattr(self.data_model_definition.fields[field_name], "deserialize_function", None): - value = func(value) - data_model_dict[field_name] = value - except KeyError as exc: - raise VectorStoreModelDeserializationException( - f"Error deserializing record, not able to get: {field_name}" - ) from exc + value = record[field_name] + if func := getattr(self.data_model_definition.fields[field_name], "deserialize_function", None): + value = func(value) + data_model_dict[field_name] = value if self.data_model_type is dict: return data_model_dict # type: ignore return self.data_model_type(**data_model_dict) + def _deserialize_vector(self, record: dict[str, Any]) -> dict[str, Any]: + for field in self.data_model_definition.vector_fields: + if field.deserialize_function: + record[field.name or ""] = field.deserialize_function(record[field.name or ""]) + return record + # 
region Internal Functions def __del__(self): diff --git a/python/semantic_kernel/exceptions/__init__.py b/python/semantic_kernel/exceptions/__init__.py index 9ed131971525..6667c4570b32 100644 --- a/python/semantic_kernel/exceptions/__init__.py +++ b/python/semantic_kernel/exceptions/__init__.py @@ -11,3 +11,4 @@ from semantic_kernel.exceptions.search_exceptions import * # noqa: F403 from semantic_kernel.exceptions.service_exceptions import * # noqa: F403 from semantic_kernel.exceptions.template_engine_exceptions import * # noqa: F403 +from semantic_kernel.exceptions.vector_store_exceptions import * # noqa: F403 diff --git a/python/semantic_kernel/exceptions/memory_connector_exceptions.py b/python/semantic_kernel/exceptions/memory_connector_exceptions.py index 16848f572025..d8b9bb736023 100644 --- a/python/semantic_kernel/exceptions/memory_connector_exceptions.py +++ b/python/semantic_kernel/exceptions/memory_connector_exceptions.py @@ -16,24 +16,6 @@ class MemoryConnectorConnectionException(MemoryConnectorException): pass -class VectorStoreModelException(MemoryConnectorException): - """Base class for all vector store model exceptions.""" - - pass - - -class VectorStoreModelSerializationException(VectorStoreModelException): - """An error occurred while serializing the vector store model.""" - - pass - - -class VectorStoreModelDeserializationException(VectorStoreModelException): - """An error occurred while deserializing the vector store model.""" - - pass - - class MemoryConnectorInitializationError(MemoryConnectorException): """An error occurred while initializing the memory connector.""" @@ -46,26 +28,9 @@ class MemoryConnectorResourceNotFound(MemoryConnectorException): pass -class VectorStoreModelValidationError(VectorStoreModelException): - """An error occurred while validating the vector store model.""" - - pass - - -class VectorStoreSearchError(MemoryConnectorException): - """An error occurred while searching the vector store model.""" - - pass - - __all__ = [ "MemoryConnectorConnectionException", "MemoryConnectorException", "MemoryConnectorInitializationError", "MemoryConnectorResourceNotFound", - "VectorStoreModelDeserializationException", - "VectorStoreModelException", - "VectorStoreModelSerializationException", - "VectorStoreModelValidationError", - "VectorStoreSearchError", ] diff --git a/python/semantic_kernel/exceptions/search_exceptions.py b/python/semantic_kernel/exceptions/search_exceptions.py index fdb0bae13284..456235e6b008 100644 --- a/python/semantic_kernel/exceptions/search_exceptions.py +++ b/python/semantic_kernel/exceptions/search_exceptions.py @@ -9,36 +9,12 @@ class SearchException(KernelException): pass -class VectorStoreMixinException(SearchException): - """Raised when a mixin is used without the VectorSearchBase Class.""" - - pass - - -class VectorStoreTextSearchValidationError(SearchException): - """An error occurred while validating the vector store text search model.""" - - pass - - class SearchResultEmptyError(SearchException): """Raised when there are no hits in the search results.""" pass -class VectorSearchExecutionException(SearchException): - """Raised when there is an error executing a VectorSearch function.""" - - pass - - -class VectorSearchOptionsException(SearchException): - """Raised when invalid options are given to a VectorSearch function.""" - - pass - - class TextSearchException(SearchException): """An error occurred while executing a text search function.""" @@ -56,8 +32,4 @@ class TextSearchOptionsException(SearchException): 
"SearchResultEmptyError", "TextSearchException", "TextSearchOptionsException", - "VectorSearchExecutionException", - "VectorSearchOptionsException", - "VectorStoreMixinException", - "VectorStoreTextSearchValidationError", ] diff --git a/python/semantic_kernel/exceptions/vector_store_exceptions.py b/python/semantic_kernel/exceptions/vector_store_exceptions.py new file mode 100644 index 000000000000..0f42939c6155 --- /dev/null +++ b/python/semantic_kernel/exceptions/vector_store_exceptions.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft. All rights reserved. + +from semantic_kernel.exceptions.kernel_exceptions import KernelException + + +class VectorStoreException(KernelException): + """Base class for all vector store exceptions.""" + + pass + + +class VectorStoreInitializationException(VectorStoreException): + """Class for all vector store initialization exceptions.""" + + pass + + +class VectorStoreModelException(VectorStoreException): + """Base class for all vector store model exceptions.""" + + pass + + +class VectorStoreModelSerializationException(VectorStoreModelException): + """An error occurred while serializing the vector store model.""" + + pass + + +class VectorStoreModelDeserializationException(VectorStoreModelException): + """An error occurred while deserializing the vector store model.""" + + pass + + +class VectorStoreModelValidationError(VectorStoreModelException): + """An error occurred while validating the vector store model.""" + + pass + + +class VectorStoreMixinException(VectorStoreException): + """Raised when a mixin is used without the VectorSearchBase Class.""" + + pass + + +class VectorStoreOperationException(VectorStoreException): + """An error occurred while performing an operation on the vector store record collection.""" + + pass + + +class VectorSearchExecutionException(VectorStoreOperationException): + """Raised when there is an error executing a VectorSearch function.""" + + pass + + +class VectorSearchOptionsException(VectorStoreOperationException): + """Raised when invalid options are given to a VectorSearch function.""" + + pass + + +__all__ = [ + "VectorSearchExecutionException", + "VectorSearchOptionsException", + "VectorStoreException", + "VectorStoreInitializationException", + "VectorStoreMixinException", + "VectorStoreModelDeserializationException", + "VectorStoreModelException", + "VectorStoreModelException", + "VectorStoreModelSerializationException", + "VectorStoreModelValidationError", + "VectorStoreOperationException", +] diff --git a/python/tests/conftest.py b/python/tests/conftest.py index e6a01549f020..aba80113c333 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -195,44 +195,6 @@ def prompt() -> str: return "test prompt" -# @fixture(autouse=True) -# def enable_debug_mode(): -# """Set `autouse=True` to enable easy debugging for tests. - -# How to debug: -# 1. Ensure [snoop](https://github.com/alexmojaki/snoop) is installed -# (`pip install snoop`). -# 2. If you're doing print based debugging, use `pr` instead of `print`. -# That is, convert `print(some_var)` to `pr(some_var)`. -# 3. If you want a trace of a particular functions calls, just add `ss()` as the first -# line of the function. - -# Note: -# ---- -# It's completely fine to leave `autouse=True` in the fixture. It doesn't affect -# the tests unless you use `pr` or `ss` in any test. - -# Note: -# ---- -# When you use `ss` or `pr` in a test, pylance or mypy will complain. This is -# because they don't know that we're adding these functions to the builtins. 
The -# tests will run fine though. -# """ -# import builtins - -# try: -# import snoop -# except ImportError: -# warnings.warn( -# "Install snoop to enable trace debugging. `pip install snoop`", -# ImportWarning, -# ) -# return - -# builtins.ss = snoop.snoop(depth=4).__enter__ -# builtins.pr = snoop.pp - - @fixture def exclude_list(request): """Fixture that returns a list of environment variables to exclude.""" diff --git a/python/tests/integration/completions/chat_completion_test_base.py b/python/tests/integration/completions/chat_completion_test_base.py index a31882951c9b..d05157e607c5 100644 --- a/python/tests/integration/completions/chat_completion_test_base.py +++ b/python/tests/integration/completions/chat_completion_test_base.py @@ -118,11 +118,12 @@ def services(self) -> dict[str, tuple[ServiceType | None, type[PromptExecutionSe default_headers={"Test-User-X-ID": "test"}, ), ) + assert deployment_name azure_ai_inference_client = AzureAIInferenceChatCompletion( ai_model_id=deployment_name, client=ChatCompletionsClient( endpoint=f"{endpoint.strip('/')}/openai/deployments/{deployment_name}", - credential=DefaultAzureCredential(), + credential=DefaultAzureCredential(), # type: ignore credential_scopes=["https://cognitiveservices.azure.com/.default"], ), ) @@ -190,7 +191,7 @@ def setup(self, kernel: Kernel): async def get_chat_completion_response( self, kernel: Kernel, - service: ChatCompletionClientBase, + service: ServiceType, execution_settings: PromptExecutionSettings, chat_history: ChatHistory, stream: bool, @@ -204,6 +205,7 @@ async def get_chat_completion_response( input (str): Input string. stream (bool): Stream flag. """ + assert isinstance(service, ChatCompletionClientBase) if not stream: return await service.get_chat_message_content( chat_history, diff --git a/python/tests/integration/completions/test_chat_completion_with_function_calling.py b/python/tests/integration/completions/test_chat_completion_with_function_calling.py index b6d83c6d0735..6760eb5a1b02 100644 --- a/python/tests/integration/completions/test_chat_completion_with_function_calling.py +++ b/python/tests/integration/completions/test_chat_completion_with_function_calling.py @@ -971,6 +971,7 @@ async def test_streaming_completion( def evaluate(self, test_target: Any, **kwargs): inputs = kwargs.get("inputs") test_type = kwargs.get("test_type") + assert isinstance(inputs, list) if test_type == FunctionChoiceTestTypes.AUTO: self._evaluate_auto_function_choice(test_target, inputs) @@ -1060,7 +1061,7 @@ async def _test_helper( self.setup(kernel) - cmc = await retry( + cmc: ChatMessageContent | None = await retry( partial( self.get_chat_completion_response, kernel=kernel, @@ -1070,11 +1071,13 @@ async def _test_helper( stream=stream, ), retries=5, + name="function_calling", ) # We need to add the latest message to the history because the connector is # not responsible for updating the history, unless it is related to auto function # calling, when the history is updated after the function calls are invoked. 
- history.add_message(cmc) + if cmc: + history.add_message(cmc) self.evaluate(history, inputs=inputs, test_type=test_type) diff --git a/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py b/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py index 3f8eaa10ea82..f45f3367268c 100644 --- a/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py +++ b/python/tests/integration/completions/test_chat_completion_with_image_input_text_output.py @@ -263,7 +263,7 @@ async def test_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -282,7 +282,7 @@ async def test_streaming_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -297,6 +297,7 @@ async def test_streaming_completion( @override def evaluate(self, test_target: Any, **kwargs): inputs = kwargs.get("inputs") + assert isinstance(inputs, list) assert len(test_target) == len(inputs) * 2 for i in range(len(inputs)): message = test_target[i * 2 + 1] @@ -323,7 +324,7 @@ async def _test_helper( for message in inputs: history.add_message(message) - cmc = await retry( + cmc: ChatMessageContent | None = await retry( partial( self.get_chat_completion_response, kernel=kernel, @@ -333,7 +334,9 @@ async def _test_helper( stream=stream, ), retries=5, + name="image_input", ) - history.add_message(cmc) + if cmc: + history.add_message(cmc) self.evaluate(history.messages, inputs=inputs) diff --git a/python/tests/integration/completions/test_chat_completions.py b/python/tests/integration/completions/test_chat_completions.py index 810be08fd5e2..e3a77542f0f6 100644 --- a/python/tests/integration/completions/test_chat_completions.py +++ b/python/tests/integration/completions/test_chat_completions.py @@ -6,6 +6,11 @@ import pytest +if sys.version_info >= (3, 12): + from typing import override # pragma: no cover +else: + from typing_extensions import override # pragma: no cover + from semantic_kernel import Kernel from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents import ChatMessageContent, TextContent @@ -24,11 +29,6 @@ from tests.integration.completions.completion_test_base import ServiceType from tests.utils import retry -if sys.version_info >= (3, 12): - from typing import override # pragma: no cover -else: - from typing_extensions import override # pragma: no cover - class Step(KernelBaseModel): explanation: str @@ -265,7 +265,7 @@ async def test_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -284,7 +284,7 @@ async def test_streaming_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | 
list[ChatMessageContent]], + inputs: list[ChatMessageContent], kwargs: dict[str, Any], ): await self._test_helper( @@ -299,6 +299,7 @@ async def test_streaming_completion( @override def evaluate(self, test_target: Any, **kwargs): inputs = kwargs.get("inputs") + assert isinstance(inputs, list) assert len(test_target) == len(inputs) * 2 for i in range(len(inputs)): message = test_target[i * 2 + 1] @@ -325,7 +326,7 @@ async def _test_helper( for message in inputs: history.add_message(message) - cmc = await retry( + cmc: ChatMessageContent | None = await retry( partial( self.get_chat_completion_response, kernel=kernel, @@ -335,7 +336,9 @@ async def _test_helper( stream=stream, ), retries=5, + name="get_chat_completion_response", ) - history.add_message(cmc) + if cmc: + history.add_message(cmc) self.evaluate(history.messages, inputs=inputs) diff --git a/python/tests/integration/completions/test_conversation_summary_plugin.py b/python/tests/integration/completions/test_conversation_summary_plugin.py index a89749b2993f..8e466d6d4731 100644 --- a/python/tests/integration/completions/test_conversation_summary_plugin.py +++ b/python/tests/integration/completions/test_conversation_summary_plugin.py @@ -55,7 +55,7 @@ async def test_azure_summarize_conversation_using_plugin(kernel): ) prompt_template_config = PromptTemplateConfig( description="Given a section of a conversation transcript, summarize the part of the conversation.", - execution_settings=execution_settings, + execution_settings={service_id: execution_settings}, ) kernel.add_service(sk_oai.OpenAIChatCompletion(service_id=service_id)) diff --git a/python/tests/integration/completions/test_text_completion.py b/python/tests/integration/completions/test_text_completion.py index 9e9e93084ed0..3e9b34ef76aa 100644 --- a/python/tests/integration/completions/test_text_completion.py +++ b/python/tests/integration/completions/test_text_completion.py @@ -2,7 +2,7 @@ import platform import sys -from functools import partial, reduce +from functools import partial from typing import Any if sys.version_info >= (3, 12): @@ -27,8 +27,7 @@ ) from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase -from semantic_kernel.contents.chat_message_content import ChatMessageContent -from semantic_kernel.contents.text_content import TextContent +from semantic_kernel.contents import StreamingTextContent, TextContent from semantic_kernel.utils.authentication.entra_id_authentication import get_entra_auth_token from tests.integration.completions.completion_test_base import CompletionTestBase, ServiceType from tests.utils import is_service_setup_for_testing, is_test_running_on_supported_platforms, retry @@ -275,7 +274,7 @@ def services(self) -> dict[str, tuple[ServiceType | None, type[PromptExecutionSe async def get_text_completion_response( self, - service: TextCompletionClientBase, + service: ServiceType, execution_settings: PromptExecutionSettings, prompt: str, stream: bool, @@ -289,21 +288,20 @@ async def get_text_completion_response( prompt (str): Input string. stream (bool): Stream flag. 
""" + assert isinstance(service, TextCompletionClientBase) if stream: response = service.get_streaming_text_content( prompt=prompt, settings=execution_settings, ) - parts = [part async for part in response] + parts: list[StreamingTextContent] = [part async for part in response if part is not None] if parts: - response = reduce(lambda p, r: p + r, parts) - else: - raise AssertionError("No response") - else: - response = await service.get_text_content( - prompt=prompt, - settings=execution_settings, - ) + return sum(parts[1:], parts[0]) + raise AssertionError("No response") + return await service.get_text_content( + prompt=prompt, + settings=execution_settings, + ) return response @@ -314,7 +312,7 @@ async def test_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[str], kwargs: dict[str, Any], ) -> None: await self._test_helper(service_id, services, execution_settings_kwargs, inputs, False) @@ -326,7 +324,7 @@ async def test_streaming_completion( service_id: str, services: dict[str, tuple[ServiceType, type[PromptExecutionSettings]]], execution_settings_kwargs: dict[str, Any], - inputs: list[str | ChatMessageContent | list[ChatMessageContent]], + inputs: list[str], kwargs: dict[str, Any], ): if "streaming" in kwargs and not kwargs["streaming"]: @@ -364,5 +362,6 @@ async def _test_helper( stream=stream, ), retries=5, + name="text completions", ) self.evaluate(response) diff --git a/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py b/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py index 6bc6102215b5..3d87f26d0ccd 100644 --- a/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py +++ b/python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py @@ -14,11 +14,11 @@ data_model_definition_to_azure_ai_search_index, get_search_index_client, ) -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + ServiceInitializationError, + VectorStoreInitializationException, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError from semantic_kernel.utils.list_handler import desync_list BASE_PATH_SEARCH_CLIENT = "azure.search.documents.aio.SearchClient" @@ -104,7 +104,7 @@ def test_init_with_type(azure_ai_search_unit_test_env, data_model_type): @mark.parametrize("exclude_list", [["AZURE_AI_SEARCH_ENDPOINT"]], indirect=True) def test_init_endpoint_fail(azure_ai_search_unit_test_env, data_model_definition): - with raises(MemoryConnectorInitializationError): + with raises(VectorStoreInitializationException): AzureAISearchCollection( data_model_type=dict, data_model_definition=data_model_definition, env_file_path="test.env" ) @@ -112,7 +112,7 @@ def test_init_endpoint_fail(azure_ai_search_unit_test_env, data_model_definition @mark.parametrize("exclude_list", [["AZURE_AI_SEARCH_INDEX_NAME"]], indirect=True) def test_init_index_fail(azure_ai_search_unit_test_env, data_model_definition): - with raises(MemoryConnectorInitializationError): + with raises(VectorStoreInitializationException): AzureAISearchCollection( data_model_type=dict, data_model_definition=data_model_definition, env_file_path="test.env" ) @@ -161,7 +161,7 @@ def 
test_init_with_search_index_client(azure_ai_search_unit_test_env, data_model def test_init_with_search_index_client_fail(azure_ai_search_unit_test_env, data_model_definition): search_index_client = MagicMock(spec=SearchIndexClientWrapper) - with raises(MemoryConnectorInitializationError, match="Collection name is required."): + with raises(VectorStoreInitializationException, match="Collection name is required."): AzureAISearchCollection( data_model_type=dict, data_model_definition=data_model_definition, @@ -175,7 +175,7 @@ def test_init_with_clients_fail(azure_ai_search_unit_test_env, data_model_defini search_client._index_name = "test-index-name" with raises( - MemoryConnectorInitializationError, match="Search client and search index client have different index names." + VectorStoreInitializationException, match="Search client and search index client have different index names." ): AzureAISearchCollection( data_model_type=dict, @@ -233,7 +233,7 @@ async def test_create_index_from_definition(collection, mock_create_collection): async def test_create_index_from_index_fail(collection, mock_create_collection): index = Mock() - with raises(MemoryConnectorException): + with raises(VectorStoreOperationException): await collection.create_collection(index=index) @@ -247,7 +247,7 @@ def test_data_model_definition_to_azure_ai_search_index(data_model_definition): @mark.parametrize("exclude_list", [["AZURE_AI_SEARCH_ENDPOINT"]], indirect=True) async def test_vector_store_fail(azure_ai_search_unit_test_env): - with raises(MemoryConnectorInitializationError): + with raises(VectorStoreInitializationException): AzureAISearchStore(env_file_path="test.env") diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py index 6c5dc4d1c2b4..e307e9bab8dc 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py @@ -15,11 +15,10 @@ create_default_indexing_policy, create_default_vector_embedding_policy, ) -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, - MemoryConnectorResourceNotFound, +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, ) +from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelException, VectorStoreOperationException def test_azure_cosmos_db_no_sql_collection_init( @@ -74,7 +73,7 @@ def test_azure_cosmos_db_no_sql_collection_init_no_url( collection_name: str, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLCollection object with missing URL.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): AzureCosmosDBNoSQLCollection( data_model_type=data_model_type, collection_name=collection_name, @@ -90,7 +89,7 @@ def test_azure_cosmos_db_no_sql_collection_init_no_database_name( ) -> None: """Test the initialization of an AzureCosmosDBNoSQLCollection object with missing database name.""" with pytest.raises( - MemoryConnectorInitializationError, match="The name of the Azure Cosmos DB NoSQL database is missing." + VectorStoreInitializationException, match="The name of the Azure Cosmos DB NoSQL database is missing." 
): AzureCosmosDBNoSQLCollection( data_model_type=data_model_type, @@ -105,7 +104,7 @@ def test_azure_cosmos_db_no_sql_collection_invalid_settings( collection_name: str, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLCollection object with invalid settings.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): AzureCosmosDBNoSQLCollection( data_model_type=data_model_type, collection_name=collection_name, @@ -204,7 +203,7 @@ async def test_azure_cosmos_db_no_sql_collection_create_database_raise_if_databa assert vector_collection.create_database is False - with pytest.raises(MemoryConnectorResourceNotFound): + with pytest.raises(VectorStoreOperationException): await vector_collection._get_database_proxy() @@ -325,7 +324,7 @@ async def test_azure_cosmos_db_no_sql_collection_create_collection_unsupported_v mock_database_proxy.create_container_if_not_exists = AsyncMock(return_value=None) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreModelException): await vector_collection.create_collection() @@ -367,7 +366,7 @@ async def test_azure_cosmos_db_no_sql_collection_delete_collection_fail( vector_collection._get_database_proxy = AsyncMock(return_value=mock_database_proxy) mock_database_proxy.delete_container = AsyncMock(side_effect=CosmosHttpResponseError) - with pytest.raises(MemoryConnectorException, match="Container could not be deleted."): + with pytest.raises(VectorStoreOperationException, match="Container could not be deleted."): await vector_collection.delete_collection() diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py index 7f7f53d155b8..0514f9d8dc91 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_store.py @@ -10,7 +10,7 @@ ) from semantic_kernel.connectors.memory.azure_cosmos_db.azure_cosmos_db_no_sql_store import AzureCosmosDBNoSQLStore from semantic_kernel.connectors.memory.azure_cosmos_db.utils import CosmosClientWrapper -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError +from semantic_kernel.exceptions import VectorStoreInitializationException def test_azure_cosmos_db_no_sql_store_init( @@ -43,7 +43,7 @@ def test_azure_cosmos_db_no_sql_store_init_no_url( azure_cosmos_db_no_sql_unit_test_env, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLStore object with missing URL.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): AzureCosmosDBNoSQLStore(env_file_path="fake_path") @@ -53,7 +53,7 @@ def test_azure_cosmos_db_no_sql_store_init_no_database_name( ) -> None: """Test the initialization of an AzureCosmosDBNoSQLStore object with missing database name.""" with pytest.raises( - MemoryConnectorInitializationError, match="The name of the Azure Cosmos DB NoSQL database is missing." + VectorStoreInitializationException, match="The name of the Azure Cosmos DB NoSQL database is missing." 
): AzureCosmosDBNoSQLStore(env_file_path="fake_path") @@ -62,7 +62,7 @@ def test_azure_cosmos_db_no_sql_store_invalid_settings( clear_azure_cosmos_db_no_sql_env, ) -> None: """Test the initialization of an AzureCosmosDBNoSQLStore object with invalid settings.""" - with pytest.raises(MemoryConnectorInitializationError, match="Failed to validate Azure Cosmos DB NoSQL settings."): + with pytest.raises(VectorStoreInitializationException, match="Failed to validate Azure Cosmos DB NoSQL settings."): AzureCosmosDBNoSQLStore(url="invalid_url") diff --git a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py index c92571daf238..a9f4c69c97ae 100644 --- a/python/tests/unit/connectors/memory/qdrant/test_qdrant.py +++ b/python/tests/unit/connectors/memory/qdrant/test_qdrant.py @@ -11,12 +11,12 @@ from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorStoreInitializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient" @@ -144,10 +144,10 @@ def test_vector_store_in_memory(qdrant_unit_test_env): def test_vector_store_fail(): - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant settings."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant settings."): QdrantStore(location="localhost", url="localhost", env_file_path="test.env") - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant client."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant client."): QdrantStore(location="localhost", url="http://localhost", env_file_path="test.env") @@ -180,7 +180,7 @@ async def test_collection_init(data_model_definition, qdrant_unit_test_env): def test_collection_init_fail(data_model_definition): - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant settings."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant settings."): QdrantCollection( data_model_type=dict, collection_name="test", @@ -188,7 +188,7 @@ def test_collection_init_fail(data_model_definition): url="localhost", env_file_path="test.env", ) - with raises(MemoryConnectorInitializationError, match="Failed to create Qdrant client."): + with raises(VectorStoreInitializationException, match="Failed to create Qdrant client."): QdrantCollection( data_model_type=dict, collection_name="test", @@ -274,7 +274,7 @@ async def test_create_index_with_named_vectors(collection_to_use, results, mock_ async def test_create_index_fail(collection_to_use, request): collection = request.getfixturevalue(collection_to_use) collection.data_model_definition.fields["vector"].dimensions = None - with raises(MemoryConnectorException, match="Vector field must have dimensions."): + with raises(VectorStoreOperationException, match="Vector field must have dimensions."): await collection.create_collection() diff --git 
a/python/tests/unit/connectors/memory/redis/test_redis_store.py b/python/tests/unit/connectors/memory/redis/test_redis_store.py index f62a8a34e468..0fb3b06efae4 100644 --- a/python/tests/unit/connectors/memory/redis/test_redis_store.py +++ b/python/tests/unit/connectors/memory/redis/test_redis_store.py @@ -9,9 +9,9 @@ from semantic_kernel.connectors.memory.redis.const import RedisCollectionTypes from semantic_kernel.connectors.memory.redis.redis_collection import RedisHashsetCollection, RedisJsonCollection from semantic_kernel.connectors.memory.redis.redis_store import RedisStore -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + VectorStoreInitializationException, + VectorStoreOperationException, ) BASE_PATH = "redis.asyncio.client.Redis" @@ -153,7 +153,7 @@ def test_vector_store_with_client(redis_unit_test_env): @mark.parametrize("exclude_list", [["REDIS_CONNECTION_STRING"]], indirect=True) def test_vector_store_fail(redis_unit_test_env): - with raises(MemoryConnectorInitializationError, match="Failed to create Redis settings."): + with raises(VectorStoreInitializationException, match="Failed to create Redis settings."): RedisStore(env_file_path="test.env") @@ -223,14 +223,14 @@ def test_init_with_type(redis_unit_test_env, data_model_type, type_): @mark.parametrize("exclude_list", [["REDIS_CONNECTION_STRING"]], indirect=True) def test_collection_fail(redis_unit_test_env, data_model_definition): - with raises(MemoryConnectorInitializationError, match="Failed to create Redis settings."): + with raises(VectorStoreInitializationException, match="Failed to create Redis settings."): RedisHashsetCollection( data_model_type=dict, collection_name="test", data_model_definition=data_model_definition, env_file_path="test.env", ) - with raises(MemoryConnectorInitializationError, match="Failed to create Redis settings."): + with raises(VectorStoreInitializationException, match="Failed to create Redis settings."): RedisJsonCollection( data_model_type=dict, collection_name="test", @@ -326,5 +326,5 @@ async def test_create_index_manual(collection_hash, mock_create_collection): async def test_create_index_fail(collection_hash, mock_create_collection): - with raises(MemoryConnectorException, match="Invalid index type supplied."): + with raises(VectorStoreOperationException, match="Invalid index type supplied."): await collection_hash.create_collection(index_definition="index_definition", fields="fields") diff --git a/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py b/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py index 5ba167d5c6a8..4db7fc3a257e 100644 --- a/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py +++ b/python/tests/unit/connectors/memory/weaviate/test_weaviate_collection.py @@ -9,11 +9,11 @@ from weaviate.collections.classes.data import DataObject from semantic_kernel.connectors.memory.weaviate.weaviate_collection import WeaviateCollection -from semantic_kernel.exceptions.memory_connector_exceptions import ( - MemoryConnectorException, - MemoryConnectorInitializationError, +from semantic_kernel.exceptions import ( + ServiceInvalidExecutionSettingsError, + VectorStoreInitializationException, + VectorStoreOperationException, ) -from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError @patch( @@ -164,7 +164,7 @@ def 
test_weaviate_collection_init_fail_to_create_client( """Test the initialization of a WeaviateCollection object raises an error when failing to create a client.""" collection_name = "TestCollection" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): WeaviateCollection( data_model_type=data_model_type, data_model_definition=data_model_definition, @@ -262,7 +262,7 @@ async def test_weaviate_collection_create_collection_fail( env_file_path="fake_env_file_path.env", ) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreOperationException): await collection.create_collection() @@ -312,7 +312,7 @@ async def test_weaviate_collection_delete_collection_fail( env_file_path="fake_env_file_path.env", ) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreOperationException): await collection.delete_collection() @@ -362,7 +362,7 @@ async def test_weaviate_collection_collection_exist_fail( env_file_path="fake_env_file_path.env", ) - with pytest.raises(MemoryConnectorException): + with pytest.raises(VectorStoreOperationException): await collection.does_collection_exist() diff --git a/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py b/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py index d47f95fa1462..d50f024bf2ce 100644 --- a/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py +++ b/python/tests/unit/connectors/memory/weaviate/test_weaviate_store.py @@ -7,8 +7,7 @@ from weaviate import WeaviateAsyncClient from semantic_kernel.connectors.memory.weaviate.weaviate_store import WeaviateStore -from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorInitializationError -from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError +from semantic_kernel.exceptions import ServiceInvalidExecutionSettingsError, VectorStoreInitializationException @patch.object( @@ -119,7 +118,7 @@ def test_weaviate_store_init_fail_to_create_client( data_model_definition, ) -> None: """Test the initialization of a WeaviateStore object raises an error when failing to create a client.""" - with pytest.raises(MemoryConnectorInitializationError): + with pytest.raises(VectorStoreInitializationException): WeaviateStore( local_host="localhost", env_file_path="fake_env_file_path.env", diff --git a/python/tests/unit/data/conftest.py b/python/tests/unit/data/conftest.py index 61110ebfcab4..8d926ad1d676 100644 --- a/python/tests/unit/data/conftest.py +++ b/python/tests/unit/data/conftest.py @@ -6,7 +6,8 @@ from typing import Annotated, Any import numpy as np -from pydantic import BaseModel, Field +from pandas import DataFrame +from pydantic import BaseModel, ConfigDict, Field from pytest import fixture from semantic_kernel.data import ( @@ -25,7 +26,7 @@ @fixture -def DictVectorStoreRecordCollection(): +def DictVectorStoreRecordCollection() -> type[VectorSearchBase]: class DictVectorStoreRecordCollection( VectorSearchBase[str, Any], VectorizedSearchMixin[Any], @@ -334,6 +335,24 @@ class DataModelClass(BaseModel): return DataModelClass +@fixture +def data_model_type_pydantic_array(): + @vectorstoremodel + class DataModelClass(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + content: Annotated[str, VectorStoreRecordDataField()] + vector: Annotated[ + np.ndarray, + VectorStoreRecordVectorField( + serialize_function=np.ndarray.tolist, + deserialize_function=np.array, + ), + ] + id: Annotated[str, 
VectorStoreRecordKeyField()] + + return DataModelClass + + @fixture def data_model_type_dataclass(): @vectorstoremodel @@ -344,3 +363,53 @@ class DataModelClass: id: Annotated[str, VectorStoreRecordKeyField()] return DataModelClass + + +@fixture(scope="function") +def vector_store_record_collection( + DictVectorStoreRecordCollection, + data_model_definition, + data_model_serialize_definition, + data_model_to_from_dict_definition, + data_model_container_definition, + data_model_container_serialize_definition, + data_model_pandas_definition, + data_model_type_vanilla, + data_model_type_vanilla_serialize, + data_model_type_vanilla_to_from_dict, + data_model_type_pydantic, + data_model_type_dataclass, + data_model_type_vector_array, + request, +) -> VectorSearchBase: + item = request.param if request and hasattr(request, "param") else "definition_basic" + defs = { + "definition_basic": data_model_definition, + "definition_with_serialize": data_model_serialize_definition, + "definition_with_to_from": data_model_to_from_dict_definition, + "definition_container": data_model_container_definition, + "definition_container_serialize": data_model_container_serialize_definition, + "definition_pandas": data_model_pandas_definition, + "type_vanilla": data_model_type_vanilla, + "type_vanilla_with_serialize": data_model_type_vanilla_serialize, + "type_vanilla_with_to_from_dict": data_model_type_vanilla_to_from_dict, + "type_pydantic": data_model_type_pydantic, + "type_dataclass": data_model_type_dataclass, + "type_vector_array": data_model_type_vector_array, + } + if item.endswith("pandas"): + return DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=DataFrame, + data_model_definition=defs[item], + ) + if item.startswith("definition_"): + return DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=defs[item], + ) + return DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=defs[item], + ) diff --git a/python/tests/unit/data/test_text_search.py b/python/tests/unit/data/test_text_search.py index 5b03b67e52e9..74a10909317b 100644 --- a/python/tests/unit/data/test_text_search.py +++ b/python/tests/unit/data/test_text_search.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
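# Aside on the return-type fixes in this file: an async generator function is
# annotated with the generator it returns, AsyncGenerator[YieldType, SendType],
# not with the type it yields. A minimal, self-contained illustration:

from collections.abc import AsyncGenerator


async def generator() -> AsyncGenerator[str, None]:
    # each `yield` produces a str; nothing is ever sent back in, hence None
    yield "test"

# consumed with an async comprehension, e.g. `[item async for item in generator()]`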
+from collections.abc import AsyncGenerator from typing import Any from unittest.mock import patch @@ -21,6 +22,7 @@ ) from semantic_kernel.exceptions import TextSearchException from semantic_kernel.functions import KernelArguments, KernelParameterMetadata +from semantic_kernel.utils.list_handler import desync_list def test_text_search(): @@ -33,7 +35,7 @@ class TestSearch(TextSearch): async def search(self, **kwargs) -> KernelSearchResults[Any]: """Test search function.""" - async def generator() -> str: + async def generator() -> AsyncGenerator[str, None]: yield "test" return KernelSearchResults(results=generator(), metadata=kwargs) @@ -43,7 +45,7 @@ async def get_text_search_results( ) -> KernelSearchResults[TextSearchResult]: """Test get text search result function.""" - async def generator() -> TextSearchResult: + async def generator() -> AsyncGenerator[TextSearchResult, None]: yield TextSearchResult(value="test") return KernelSearchResults(results=generator(), metadata=kwargs) @@ -53,7 +55,7 @@ async def get_search_results( ) -> KernelSearchResults[Any]: """Test get search result function.""" - async def generator() -> str: + async def generator() -> AsyncGenerator[str, None]: yield "test" return KernelSearchResults(results=generator(), metadata=kwargs) @@ -190,12 +192,18 @@ async def test_create_kernel_function_inner_update_options(kernel: Kernel): called = False args = {} - def update_options(**kwargs: Any) -> tuple[str, SearchOptions]: - kwargs["options"].filter.equal_to("address/city", kwargs.get("city")) + def update_options( + query: str, + options: "SearchOptions", + parameters: list["KernelParameterMetadata"] | None = None, + **kwargs: Any, + ) -> tuple[str, SearchOptions]: + options.filter.equal_to("address/city", kwargs.get("city", "")) nonlocal called, args called = True - args = kwargs - return kwargs["query"], kwargs["options"] + args = {"query": query, "options": options, "parameters": parameters} + args.update(kwargs) + return query, options kernel_function = test_search._create_kernel_function( search_function="search", @@ -225,14 +233,29 @@ def update_options(**kwargs: Any) -> tuple[str, SearchOptions]: assert "parameters" in args -def test_default_map_to_string(): +async def test_default_map_to_string(): test_search = TestSearch() - assert test_search._default_map_to_string("test") == "test" + assert (await test_search._map_results(results=KernelSearchResults(results=desync_list(["test"])))) == ["test"] class TestClass(BaseModel): test: str - assert test_search._default_map_to_string(TestClass(test="test")) == '{"test":"test"}' + assert ( + await test_search._map_results(results=KernelSearchResults(results=desync_list([TestClass(test="test")]))) + ) == ['{"test":"test"}'] + + +async def test_custom_map_to_string(): + test_search = TestSearch() + + class TestClass(BaseModel): + test: str + + assert ( + await test_search._map_results( + results=KernelSearchResults(results=desync_list([TestClass(test="test")])), string_mapper=lambda x: x.test + ) + ) == ["test"] def test_create_options(): @@ -253,6 +276,27 @@ def test_create_options_none(): assert new_options.top == 1 +def test_create_options_vector_to_text(): + options = VectorSearchOptions(top=2, skip=1, include_vectors=True) + options_class = TextSearchOptions + new_options = create_options(options_class, options, top=1) + assert new_options is not None + assert isinstance(new_options, options_class) + assert new_options.top == 1 + assert getattr(new_options, "include_vectors", None) is None + + +def 
test_create_options_from_dict(): + options = {"skip": 1} + options_class = TextSearchOptions + new_options = create_options(options_class, options, top=1) # type: ignore + assert new_options is not None + assert isinstance(new_options, options_class) + assert new_options.top == 1 + # if a non-SearchOptions object is passed in, it should be ignored + assert new_options.skip == 0 + + def test_default_options_update_function(): options = SearchOptions() params = [ @@ -267,3 +311,36 @@ def test_default_options_update_function(): assert options.filter.filters[0].value == "test" assert options.filter.filters[1].field_name == "test2" assert options.filter.filters[1].value == "test2" + + +def test_public_create_functions_search(): + test_search = TestSearch() + function = test_search.create_search() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 + + +def test_public_create_functions_get_text_search_results(): + test_search = TestSearch() + function = test_search.create_get_text_search_results() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 + + +def test_public_create_functions_get_search_results(): + test_search = TestSearch() + function = test_search.create_get_search_results() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 diff --git a/python/tests/unit/data/test_vector_search_base.py b/python/tests/unit/data/test_vector_search_base.py new file mode 100644 index 000000000000..d35c04e45a5b --- /dev/null +++ b/python/tests/unit/data/test_vector_search_base.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft. All rights reserved.
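# Aside on the conditional asserts in this new module: in Python,
# `assert a == b if cond else c` parses as `assert (a == b) if cond else c`,
# so the `else` arm only checks that `c` is truthy and never compares it.
# The assertions below therefore parenthesize the conditional expression.
# A runnable demonstration of the pitfall (names here are illustrative only):

record = {"id": "test_id"}
include_vectors = False
# passes even though the dicts differ, because a non-empty dict is truthy:
assert record == {"id": "test_id"} if include_vectors else {"id": "other"}
# the intended comparison, with the conditional expression parenthesized:
assert record == ({"id": "other"} if include_vectors else {"id": "test_id"})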
+ + +import pytest + +from semantic_kernel.data.vector_search.vector_search import VectorSearchBase +from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions +from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreModelDeserializationException + + +async def test_search(vector_store_record_collection: VectorSearchBase): + record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + await vector_store_record_collection.upsert(record) + results = await vector_store_record_collection._inner_search( + options=VectorSearchOptions(), search_text="test_content" + ) + records = [rec async for rec in results.results] + assert records[0].record == record + + +@pytest.mark.parametrize("include_vectors", [True, False]) +async def test_get_vector_search_results(vector_store_record_collection: VectorSearchBase, include_vectors: bool): + options = VectorSearchOptions(include_vectors=include_vectors) + results = [{"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}] + async for result in vector_store_record_collection._get_vector_search_results_from_results( + results=results, options=options + ): + assert result.record == (results[0] if include_vectors else {"id": "test_id", "content": "test_content"}) + break + + +async def test_get_vector_search_results_fail(vector_store_record_collection: VectorSearchBase): + # vector_store_record_collection.data_model_type.serialize = MagicMock(side_effect=Exception) + options = VectorSearchOptions(include_vectors=True) + results = [{"id": "test_id", "content": "test_content"}] + with pytest.raises(VectorStoreModelDeserializationException): + async for result in vector_store_record_collection._get_vector_search_results_from_results( + results=results, options=options + ): + assert result.record == results[0] + break diff --git a/python/tests/unit/data/test_vector_search_mixins.py b/python/tests/unit/data/test_vector_search_mixins.py new file mode 100644 index 000000000000..0590b02bd41c --- /dev/null +++ b/python/tests/unit/data/test_vector_search_mixins.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft. All rights reserved.
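# Aside on what these tests pin down: each search mixin only forwards to
# machinery that VectorSearchBase supplies, so instantiating a mixin on a bare
# class, as done below, raises VectorStoreMixinException. The working
# composition is the one DictVectorStoreRecordCollection uses in the unit-test
# conftest, roughly (shape only; generics and internals elided as assumptions):
#
#     class DictVectorStoreRecordCollection(
#         VectorSearchBase[str, Any],
#         VectorizedSearchMixin[Any],
#         VectorTextSearchMixin[Any],
#     ):
#         ...  # implement the abstract storage and _inner_search methods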
+ +from pytest import raises + +from semantic_kernel.data import VectorizableTextSearchMixin, VectorizedSearchMixin, VectorTextSearchMixin +from semantic_kernel.exceptions import VectorStoreMixinException + + +class VectorTextSearchMixinTest(VectorTextSearchMixin): + """The mixin for text search, to be used in combination with VectorSearchBase.""" + + pass + + +class VectorizableTextSearchMixinTest(VectorizableTextSearchMixin): + """The mixin for vectorizable text search, to be used in combination with VectorSearchBase.""" + + pass + + +class VectorizedSearchMixinTest(VectorizedSearchMixin): + """The mixin for vectorized search, to be used in combination with VectorSearchBase.""" + + pass + + +async def test_text_search(): + test_instance = VectorTextSearchMixinTest() + assert test_instance is not None + with raises(VectorStoreMixinException): + await test_instance.text_search("test") + + +async def test_vectorizable_text_search(): + test_instance = VectorizableTextSearchMixinTest() + assert test_instance is not None + with raises(VectorStoreMixinException): + await test_instance.vectorizable_text_search("test") + + +async def test_vectorized_text_search(): + test_instance = VectorizedSearchMixinTest() + assert test_instance is not None + with raises(VectorStoreMixinException): + await test_instance.vectorized_search([1, 2, 3]) diff --git a/python/tests/unit/data/test_vector_store_model_decorator.py b/python/tests/unit/data/test_vector_store_model_decorator.py index 0a49530ecd4e..4f9707fe7032 100644 --- a/python/tests/unit/data/test_vector_store_model_decorator.py +++ b/python/tests/unit/data/test_vector_store_model_decorator.py @@ -206,19 +206,25 @@ class DataModelClass: key: Annotated[str, VectorStoreRecordKeyField()] list1: Annotated[list[int], VectorStoreRecordDataField()] list2: Annotated[list[str], VectorStoreRecordDataField] + list3: Annotated[list[str] | None, VectorStoreRecordDataField] dict1: Annotated[dict[str, int], VectorStoreRecordDataField()] dict2: Annotated[dict[str, str], VectorStoreRecordDataField] + dict3: Annotated[dict[str, str] | None, VectorStoreRecordDataField] assert hasattr(DataModelClass, "__kernel_vectorstoremodel__") assert hasattr(DataModelClass, "__kernel_vectorstoremodel_definition__") data_model_definition: VectorStoreRecordDefinition = DataModelClass.__kernel_vectorstoremodel_definition__ - assert len(data_model_definition.fields) == 5 + assert len(data_model_definition.fields) == 7 assert data_model_definition.fields["list1"].name == "list1" assert data_model_definition.fields["list1"].property_type == "list[int]" assert data_model_definition.fields["list2"].name == "list2" assert data_model_definition.fields["list2"].property_type == "list[str]" + assert data_model_definition.fields["list3"].name == "list3" + assert data_model_definition.fields["list3"].property_type == "list[str]" assert data_model_definition.fields["dict1"].name == "dict1" assert data_model_definition.fields["dict1"].property_type == "dict" assert data_model_definition.fields["dict2"].name == "dict2" assert data_model_definition.fields["dict2"].property_type == "dict" + assert data_model_definition.fields["dict3"].name == "dict3" + assert data_model_definition.fields["dict3"].property_type == "dict" assert data_model_definition.container_mode is False diff --git a/python/tests/unit/data/test_vector_store_record_collection.py b/python/tests/unit/data/test_vector_store_record_collection.py index 3e94e6c5ea7d..dcd198a97dbe 100644 --- a/python/tests/unit/data/test_vector_store_record_collection.py +++
b/python/tests/unit/data/test_vector_store_record_collection.py @@ -5,67 +5,21 @@ import numpy as np from pandas import DataFrame -from pytest import fixture, mark, raises +from pytest import mark, raises -from semantic_kernel.data import VectorStoreRecordCollection +from semantic_kernel.data.record_definition.vector_store_model_protocols import ( + SerializeMethodProtocol, + ToDictMethodProtocol, +) from semantic_kernel.exceptions import ( - MemoryConnectorException, VectorStoreModelDeserializationException, VectorStoreModelSerializationException, VectorStoreModelValidationError, + VectorStoreOperationException, ) -@fixture(scope="function") -def vector_store_record_collection( - DictVectorStoreRecordCollection, - data_model_definition, - data_model_serialize_definition, - data_model_to_from_dict_definition, - data_model_container_definition, - data_model_container_serialize_definition, - data_model_pandas_definition, - data_model_type_vanilla, - data_model_type_vanilla_serialize, - data_model_type_vanilla_to_from_dict, - data_model_type_pydantic, - data_model_type_dataclass, - data_model_type_vector_array, - request, -) -> VectorStoreRecordCollection: - item = request.param if request and hasattr(request, "param") else "definition_basic" - defs = { - "definition_basic": data_model_definition, - "definition_with_serialize": data_model_serialize_definition, - "definition_with_to_from": data_model_to_from_dict_definition, - "definition_container": data_model_container_definition, - "definition_container_serialize": data_model_container_serialize_definition, - "definition_pandas": data_model_pandas_definition, - "type_vanilla": data_model_type_vanilla, - "type_vanilla_with_serialize": data_model_type_vanilla_serialize, - "type_vanilla_with_to_from_dict": data_model_type_vanilla_to_from_dict, - "type_pydantic": data_model_type_pydantic, - "type_dataclass": data_model_type_dataclass, - "type_vector_array": data_model_type_vector_array, - } - if item.endswith("pandas"): - return DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=DataFrame, - data_model_definition=defs[item], - ) - if item.startswith("definition_"): - return DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=dict, - data_model_definition=defs[item], - ) - return DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=defs[item], - ) - - +# region init def test_init(DictVectorStoreRecordCollection, data_model_definition): vsrc = DictVectorStoreRecordCollection( collection_name="test", @@ -79,6 +33,59 @@ def test_init(DictVectorStoreRecordCollection, data_model_definition): assert vsrc._key_field_name == "id" +def test_data_model_validation(data_model_type_vanilla, DictVectorStoreRecordCollection): + DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["str"]) + DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["float"]) + DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla, + ) + + +def test_data_model_validation_key_fail(data_model_type_vanilla, DictVectorStoreRecordCollection): + DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["int"]) + with raises(VectorStoreModelValidationError, match="Key field must be one of"): + DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla, + ) + + +def test_data_model_validation_vector_fail(data_model_type_vanilla, 
DictVectorStoreRecordCollection): + DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["list[int]"]) + with raises(VectorStoreModelValidationError, match="Vector field "): + DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla, + ) + + +# region Collection +async def test_collection_operations(vector_store_record_collection): + await vector_store_record_collection.create_collection() + assert await vector_store_record_collection.does_collection_exist() + record = {"id": "id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + await vector_store_record_collection.upsert(record) + assert len(vector_store_record_collection.inner_storage) == 1 + await vector_store_record_collection.delete_collection() + assert vector_store_record_collection.inner_storage == {} + await vector_store_record_collection.create_collection_if_not_exists() + + +async def test_collection_create_if_not_exists(DictVectorStoreRecordCollection, data_model_definition): + DictVectorStoreRecordCollection.does_collection_exist = AsyncMock(return_value=False) + create_mock = AsyncMock() + DictVectorStoreRecordCollection.create_collection = create_mock + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + await vector_store_record_collection.create_collection_if_not_exists() + create_mock.assert_called_once() + + +# region CRUD @mark.parametrize( "vector_store_record_collection", [ @@ -111,11 +118,13 @@ async def test_crud_operations(vector_store_record_collection): if vector_store_record_collection.data_model_type is dict: assert vector_store_record_collection.inner_storage[id] == record else: + assert not isinstance(record, dict) assert vector_store_record_collection.inner_storage[id]["content"] == record.content record_2 = await vector_store_record_collection.get(id) if vector_store_record_collection.data_model_type is dict: assert record_2 == record else: + assert not isinstance(record, dict) if isinstance(record.vector, list): assert record_2 == record else: @@ -156,6 +165,7 @@ async def test_crud_batch_operations(vector_store_record_collection): if vector_store_record_collection.data_model_type is dict: assert vector_store_record_collection.inner_storage[ids[0]] == batch[0] else: + assert not isinstance(batch[0], dict) assert vector_store_record_collection.inner_storage[ids[0]]["content"] == batch[0].content records = await vector_store_record_collection.get_batch(ids) assert records == batch @@ -248,6 +258,27 @@ async def test_crud_batch_operations_pandas(vector_store_record_collection): assert len(vector_store_record_collection.inner_storage) == 0 +async def test_upsert_with_vectorizing(vector_store_record_collection): + record = {"id": "test_id", "content": "test_content"} + record2 = {"id": "test_id", "content": "test_content"} + + async def embedding_func(record, type, definition): + if isinstance(record, list): + for r in record: + r["vector"] = [1.0, 2.0, 3.0] + return record + record["vector"] = [1.0, 2.0, 3.0] + return record + + await vector_store_record_collection.upsert(record, embedding_generation_function=embedding_func) + assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0] + await vector_store_record_collection.delete("test_id") + assert len(vector_store_record_collection.inner_storage) == 0 + await vector_store_record_collection.upsert_batch([record2], 
embedding_generation_function=embedding_func) + assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0] + + +# region Fails async def test_upsert_fail(DictVectorStoreRecordCollection, data_model_definition): DictVectorStoreRecordCollection._inner_upsert = MagicMock(side_effect=Exception) vector_store_record_collection = DictVectorStoreRecordCollection( @@ -256,9 +287,9 @@ async def test_upsert_fail(DictVectorStoreRecordCollection, data_model_definitio data_model_definition=data_model_definition, ) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - with raises(MemoryConnectorException, match="Error upserting record:"): + with raises(VectorStoreOperationException, match="Error upserting record:"): await vector_store_record_collection.upsert(record) - with raises(MemoryConnectorException, match="Error upserting records:"): + with raises(VectorStoreOperationException, match="Error upserting records:"): await vector_store_record_collection.upsert_batch([record]) assert len(vector_store_record_collection.inner_storage) == 0 @@ -273,9 +304,25 @@ async def test_get_fail(DictVectorStoreRecordCollection, data_model_definition): record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} await vector_store_record_collection.upsert(record) assert len(vector_store_record_collection.inner_storage) == 1 - with raises(MemoryConnectorException, match="Error getting record:"): + with raises(VectorStoreOperationException, match="Error getting record:"): await vector_store_record_collection.get("test_id") - with raises(MemoryConnectorException, match="Error getting records:"): + with raises(VectorStoreOperationException, match="Error getting records:"): + await vector_store_record_collection.get_batch(["test_id"]) + + +async def test_deserialize_in_get_fail(DictVectorStoreRecordCollection, data_model_definition): + data_model_definition.deserialize = MagicMock(side_effect=Exception) + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + await vector_store_record_collection.upsert(record) + assert len(vector_store_record_collection.inner_storage) == 1 + with raises(VectorStoreModelDeserializationException, match="Error deserializing records:"): + await vector_store_record_collection.get("test_id") + with raises(VectorStoreModelDeserializationException, match="Error deserializing records:"): await vector_store_record_collection.get_batch(["test_id"]) @@ -292,7 +339,9 @@ async def test_get_fail_multiple(DictVectorStoreRecordCollection, data_model_def patch( "semantic_kernel.data.vector_storage.vector_store_record_collection.VectorStoreRecordCollection.deserialize" ) as deserialize_mock, - raises(MemoryConnectorException, match="Error deserializing record, multiple records returned:"), + raises( + VectorStoreModelDeserializationException, match="Error deserializing record, multiple records returned:" + ), ): deserialize_mock.return_value = [ {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}, @@ -301,61 +350,49 @@ async def test_get_fail_multiple(DictVectorStoreRecordCollection, data_model_def await vector_store_record_collection.get("test_id") -async def test_serialize_fail(DictVectorStoreRecordCollection, data_model_definition): - DictVectorStoreRecordCollection.serialize = MagicMock(side_effect=Exception) 
+async def test_delete_fail(DictVectorStoreRecordCollection, data_model_definition): + DictVectorStoreRecordCollection._inner_delete = MagicMock(side_effect=Exception) vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", data_model_type=dict, data_model_definition=data_model_definition, ) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - with raises(MemoryConnectorException, match="Error serializing record"): - await vector_store_record_collection.upsert(record) - with raises(MemoryConnectorException, match="Error serializing record"): - await vector_store_record_collection.upsert_batch([record]) + await vector_store_record_collection.upsert(record) + assert len(vector_store_record_collection.inner_storage) == 1 + with raises(VectorStoreOperationException, match="Error deleting record:"): + await vector_store_record_collection.delete("test_id") + with raises(VectorStoreOperationException, match="Error deleting records:"): + await vector_store_record_collection.delete_batch(["test_id"]) + assert len(vector_store_record_collection.inner_storage) == 1 -async def test_deserialize_fail(DictVectorStoreRecordCollection, data_model_definition): - DictVectorStoreRecordCollection.deserialize = MagicMock(side_effect=Exception) +# region Serialize +async def test_serialize_in_upsert_fail(DictVectorStoreRecordCollection, data_model_definition): + DictVectorStoreRecordCollection.serialize = MagicMock(side_effect=VectorStoreModelSerializationException) vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", data_model_type=dict, data_model_definition=data_model_definition, ) record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - vector_store_record_collection.inner_storage["test_id"] = record - with raises(MemoryConnectorException, match="Error deserializing record"): - await vector_store_record_collection.get("test_id") - with raises(MemoryConnectorException, match="Error deserializing record"): - await vector_store_record_collection.get_batch(["test_id"]) + with raises(VectorStoreModelSerializationException): + await vector_store_record_collection.upsert(record) + with raises(VectorStoreModelSerializationException): + await vector_store_record_collection.upsert_batch([record]) -def test_serialize_custom_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): - data_model_type_vanilla_serialize.serialize = MagicMock(side_effect=Exception) +def test_serialize_data_model_type_serialize_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", data_model_type=data_model_type_vanilla_serialize, ) - record = data_model_type_vanilla_serialize( - content="test_content", - vector=[1.0, 2.0, 3.0], - id="test_id", - ) - with raises(VectorStoreModelSerializationException, match="Error serializing record:"): + record = MagicMock(spec=SerializeMethodProtocol) + record.serialize = MagicMock(side_effect=Exception) + with raises(VectorStoreModelSerializationException, match="Error serializing record"): vector_store_record_collection.serialize(record) -def test_deserialize_custom_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): - data_model_type_vanilla_serialize.deserialize = MagicMock(side_effect=Exception) - vector_store_record_collection = DictVectorStoreRecordCollection( - collection_name="test", - 
data_model_type=data_model_type_vanilla_serialize, - ) - record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - with raises(VectorStoreModelSerializationException, match="Error deserializing record:"): - vector_store_record_collection.deserialize(record) - - def test_serialize_data_model_to_dict_fail_mapping(DictVectorStoreRecordCollection, data_model_definition): vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", @@ -363,7 +400,7 @@ def test_serialize_data_model_to_dict_fail_mapping(DictVectorStoreRecordCollecti data_model_definition=data_model_definition, ) record = {"content": "test_content", "vector": [1.0, 2.0, 3.0]} - with raises(VectorStoreModelSerializationException, match="Error serializing record"): + with raises(KeyError): vector_store_record_collection._serialize_data_model_to_dict(record) @@ -373,10 +410,83 @@ def test_serialize_data_model_to_dict_fail_object(DictVectorStoreRecordCollectio data_model_type=data_model_type_vanilla, ) record = Mock(spec=data_model_type_vanilla) - with raises(VectorStoreModelSerializationException, match="Error serializing record"): + with raises(AttributeError): vector_store_record_collection._serialize_data_model_to_dict(record) +@mark.parametrize("vector_store_record_collection", ["type_pydantic"], indirect=True) +def test_pydantic_serialize_fail(vector_store_record_collection): + id = "test_id" + model = deepcopy(vector_store_record_collection.data_model_type) + model.model_dump = MagicMock(side_effect=Exception) + vector_store_record_collection.data_model_type = model + dict_record = {"id": id, "content": "test_content", "vector": [1.0, 2.0, 3.0]} + record = model(**dict_record) + with raises(VectorStoreModelSerializationException, match="Error serializing record"): + vector_store_record_collection.serialize(record) + + +@mark.parametrize("vector_store_record_collection", ["type_vanilla_with_to_from_dict"], indirect=True) +def test_to_dict_fail(vector_store_record_collection): + record = MagicMock(spec=ToDictMethodProtocol) + record.to_dict = MagicMock(side_effect=Exception) + with raises(VectorStoreModelSerializationException, match="Error serializing record"): + vector_store_record_collection.serialize(record) + + +def test_serialize_with_array_func(DictVectorStoreRecordCollection, data_model_type_pydantic_array): + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_pydantic_array, + ) + record = data_model_type_pydantic_array(**{ + "id": "test_id", + "content": "test_content", + "vector": np.array([1.0, 2.0, 3.0]), + }) + serialized_record = vector_store_record_collection.serialize(record) + assert serialized_record["vector"] == [1.0, 2.0, 3.0] + + +# region Deserialize + + +async def test_deserialize_definition_fail(DictVectorStoreRecordCollection, data_model_definition): + data_model_definition.deserialize = MagicMock(side_effect=Exception) + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + vector_store_record_collection.inner_storage["test_id"] = record + with raises(VectorStoreModelDeserializationException, match="Error deserializing record"): + await vector_store_record_collection.get("test_id") + with raises(VectorStoreModelDeserializationException, match="Error deserializing record"): + await 
vector_store_record_collection.get_batch(["test_id"]) + + +async def test_deserialize_definition_none(DictVectorStoreRecordCollection, data_model_definition): + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + assert vector_store_record_collection.deserialize([]) is None + assert vector_store_record_collection.deserialize({}) is None + + +def test_deserialize_type_fail(DictVectorStoreRecordCollection, data_model_type_vanilla_serialize): + vector_store_record_collection = DictVectorStoreRecordCollection( + collection_name="test", + data_model_type=data_model_type_vanilla_serialize, + ) + record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} + vector_store_record_collection.data_model_type.deserialize = MagicMock(side_effect=Exception) + with raises(VectorStoreModelDeserializationException, match="Error deserializing record"): + vector_store_record_collection.deserialize(record) + + def test_deserialize_dict_data_model_fail_sequence(DictVectorStoreRecordCollection, data_model_type_vanilla): vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", @@ -392,7 +502,7 @@ def test_deserialize_dict_data_model_fail(DictVectorStoreRecordCollection, data_ data_model_type=dict, data_model_definition=data_model_definition, ) - with raises(VectorStoreModelDeserializationException, match="Error deserializing record"): + with raises(KeyError): vector_store_record_collection._deserialize_dict_to_data_model({ "content": "test_content", "vector": [1.0, 2.0, 3.0], @@ -412,121 +522,35 @@ def test_deserialize_dict_data_model_shortcut(DictVectorStoreRecordCollection, d @mark.parametrize("vector_store_record_collection", ["type_pydantic"], indirect=True) -async def test_pydantic_fail(vector_store_record_collection): +async def test_pydantic_deserialize_fail(vector_store_record_collection): id = "test_id" - model = deepcopy(vector_store_record_collection.data_model_type) dict_record = {"id": id, "content": "test_content", "vector": [1.0, 2.0, 3.0]} - record = model(**dict_record) - model.model_dump = MagicMock(side_effect=Exception) - with raises(VectorStoreModelSerializationException, match="Error serializing record:"): - vector_store_record_collection.serialize(record) - with raises(MemoryConnectorException, match="Error serializing record:"): - await vector_store_record_collection.upsert(record) - model.model_validate = MagicMock(side_effect=Exception) - with raises(VectorStoreModelDeserializationException, match="Error deserializing record:"): + vector_store_record_collection.data_model_type.model_validate = MagicMock(side_effect=Exception) + with raises(VectorStoreModelDeserializationException, match="Error deserializing record"): vector_store_record_collection.deserialize(dict_record) @mark.parametrize("vector_store_record_collection", ["type_vanilla_with_to_from_dict"], indirect=True) -def test_to_from_dict_fail(vector_store_record_collection): +def test_from_dict_fail(vector_store_record_collection): id = "test_id" model = deepcopy(vector_store_record_collection.data_model_type) dict_record = {"id": id, "content": "test_content", "vector": [1.0, 2.0, 3.0]} - record = model(**dict_record) - model.to_dict = MagicMock(side_effect=Exception) - with raises(VectorStoreModelSerializationException, match="Error serializing record:"): - vector_store_record_collection.serialize(record) model.from_dict = MagicMock(side_effect=Exception) - 
with raises(VectorStoreModelDeserializationException, match="Error deserializing record:"): + vector_store_record_collection.data_model_type = model + with raises(VectorStoreModelDeserializationException, match="Error deserializing record"): vector_store_record_collection.deserialize(dict_record) -async def test_delete_fail(DictVectorStoreRecordCollection, data_model_definition): - DictVectorStoreRecordCollection._inner_delete = MagicMock(side_effect=Exception) - vector_store_record_collection = DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=dict, - data_model_definition=data_model_definition, - ) - record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - await vector_store_record_collection.upsert(record) - assert len(vector_store_record_collection.inner_storage) == 1 - with raises(MemoryConnectorException, match="Error deleting record:"): - await vector_store_record_collection.delete("test_id") - with raises(MemoryConnectorException, match="Error deleting records:"): - await vector_store_record_collection.delete_batch(["test_id"]) - assert len(vector_store_record_collection.inner_storage) == 1 - - -async def test_collection_operations(vector_store_record_collection): - await vector_store_record_collection.create_collection() - assert await vector_store_record_collection.does_collection_exist() - record = {"id": "id", "content": "test_content", "vector": [1.0, 2.0, 3.0]} - await vector_store_record_collection.upsert(record) - assert len(vector_store_record_collection.inner_storage) == 1 - await vector_store_record_collection.delete_collection() - assert vector_store_record_collection.inner_storage == {} - await vector_store_record_collection.create_collection_if_not_exists() - - -async def test_collection_create_if_not_exists(DictVectorStoreRecordCollection, data_model_definition): - DictVectorStoreRecordCollection.does_collection_exist = AsyncMock(return_value=False) - create_mock = AsyncMock() - DictVectorStoreRecordCollection.create_collection = create_mock +def test_deserialize_with_array_func(DictVectorStoreRecordCollection, data_model_type_pydantic_array): vector_store_record_collection = DictVectorStoreRecordCollection( collection_name="test", - data_model_type=dict, - data_model_definition=data_model_definition, + data_model_type=data_model_type_pydantic_array, ) - await vector_store_record_collection.create_collection_if_not_exists() - create_mock.assert_called_once() - - -def test_data_model_validation(data_model_type_vanilla, DictVectorStoreRecordCollection): - DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["str"]) - DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["float"]) - DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=data_model_type_vanilla, - ) - - -def test_data_model_validation_key_fail(data_model_type_vanilla, DictVectorStoreRecordCollection): - DictVectorStoreRecordCollection.supported_key_types = PropertyMock(return_value=["int"]) - with raises(VectorStoreModelValidationError, match="Key field must be one of"): - DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=data_model_type_vanilla, - ) - - -def test_data_model_validation_vector_fail(data_model_type_vanilla, DictVectorStoreRecordCollection): - DictVectorStoreRecordCollection.supported_vector_types = PropertyMock(return_value=["list[int]"]) - with raises(VectorStoreModelValidationError, match="Vector field "): - 
DictVectorStoreRecordCollection( - collection_name="test", - data_model_type=data_model_type_vanilla, - ) - - -async def test_upsert_with_vectorizing(vector_store_record_collection): - record = {"id": "test_id", "content": "test_content"} - record2 = {"id": "test_id", "content": "test_content"} - - async def embedding_func(record, type, definition): - if isinstance(record, list): - for r in record: - r["vector"] = [1.0, 2.0, 3.0] - return record - record["vector"] = [1.0, 2.0, 3.0] - return record - - await vector_store_record_collection.upsert(record, embedding_generation_function=embedding_func) - assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0] - await vector_store_record_collection.delete("test_id") - assert len(vector_store_record_collection.inner_storage) == 0 - await vector_store_record_collection.upsert_batch([record2], embedding_generation_function=embedding_func) - assert vector_store_record_collection.inner_storage["test_id"]["vector"] == [1.0, 2.0, 3.0] - - -# TODO (eavanvalkenburg): pandas container test + record = { + "id": "test_id", + "content": "test_content", + "vector": [1.0, 2.0, 3.0], + } + deserialized_record = vector_store_record_collection.deserialize(record) + assert isinstance(deserialized_record.vector, np.ndarray) + assert np.array_equal(deserialized_record.vector, np.array([1.0, 2.0, 3.0])) diff --git a/python/tests/unit/data/test_vector_store_record_definition.py b/python/tests/unit/data/test_vector_store_record_definition.py index 897623c95dc7..3f1684748d6c 100644 --- a/python/tests/unit/data/test_vector_store_record_definition.py +++ b/python/tests/unit/data/test_vector_store_record_definition.py @@ -7,6 +7,7 @@ VectorStoreRecordDefinition, VectorStoreRecordKeyField, ) +from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField from semantic_kernel.exceptions import VectorStoreModelException @@ -55,3 +56,64 @@ def test_no_matching_vector_field_fail(): "content": VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector"), } ) + + +def test_vector_and_non_vector_field_names(): + definition = VectorStoreRecordDefinition( + fields={ + "id": VectorStoreRecordKeyField(), + "content": VectorStoreRecordDataField(), + "vector": VectorStoreRecordVectorField(), + } + ) + assert definition.vector_field_names == ["vector"] + assert definition.non_vector_field_names == ["id", "content"] + + +def test_try_get_vector_field(): + definition = VectorStoreRecordDefinition( + fields={ + "id": VectorStoreRecordKeyField(), + "content": VectorStoreRecordDataField(), + "vector": VectorStoreRecordVectorField(), + } + ) + assert definition.try_get_vector_field() == definition.fields["vector"] + assert definition.try_get_vector_field("vector") == definition.fields["vector"] + + +def test_try_get_vector_field_none(): + definition = VectorStoreRecordDefinition( + fields={ + "id": VectorStoreRecordKeyField(), + "content": VectorStoreRecordDataField(), + } + ) + assert definition.try_get_vector_field() is None + with raises(VectorStoreModelException, match="Field vector not found."): + definition.try_get_vector_field("vector") + + +def test_try_get_vector_field_wrong_name_fail(): + definition = VectorStoreRecordDefinition( + fields={ + "id": VectorStoreRecordKeyField(), + "content": VectorStoreRecordDataField(), + } + ) + with raises(VectorStoreModelException, match="Field content is not a vector field."): + definition.try_get_vector_field("content") + + +def 
test_get_field_names():
+    definition = VectorStoreRecordDefinition(
+        fields={
+            "id": VectorStoreRecordKeyField(),
+            "content": VectorStoreRecordDataField(),
+            "vector": VectorStoreRecordVectorField(),
+        }
+    )
+    assert definition.get_field_names() == ["id", "content", "vector"]
+    assert definition.get_field_names(include_vector_fields=False) == ["id", "content"]
+    assert definition.get_field_names(include_key_field=False) == ["content", "vector"]
+    assert definition.get_field_names(include_vector_fields=False, include_key_field=False) == ["content"]
diff --git a/python/tests/unit/data/test_vector_store_record_utils.py b/python/tests/unit/data/test_vector_store_record_utils.py
index cfb2ea448d64..669ede25376d 100644
--- a/python/tests/unit/data/test_vector_store_record_utils.py
+++ b/python/tests/unit/data/test_vector_store_record_utils.py
@@ -40,3 +40,14 @@ async def test_add_vector_wrong_fields():
     record = {"id": "test_id", "content": "content"}
     with raises(VectorStoreModelException, match="Embedding field"):
         await utils.add_vector_to_records(record, None, data_model)
+
+
+async def test_add_vector_to_records_no_definition_fail():
+    kernel = MagicMock(spec=Kernel)
+    kernel.add_embedding_to_object = AsyncMock()
+    utils = VectorStoreRecordUtils(kernel)
+    assert utils is not None
+    record = {"id": "test_id", "content": "content"}
+    with raises(VectorStoreModelException, match="Data model definition is required"):
+        await utils.add_vector_to_records(record, dict, None)
+    kernel.add_embedding_to_object.assert_not_called()
diff --git a/python/tests/unit/data/test_vector_store_text_search.py b/python/tests/unit/data/test_vector_store_text_search.py
index 0f485349d098..70358011a592 100644
--- a/python/tests/unit/data/test_vector_store_text_search.py
+++ b/python/tests/unit/data/test_vector_store_text_search.py
@@ -2,11 +2,16 @@
 
 from unittest.mock import patch
 
+from pydantic import BaseModel
 from pytest import fixture, raises
 
 from semantic_kernel.connectors.ai.open_ai import AzureTextEmbedding
 from semantic_kernel.data import VectorStoreTextSearch
-from semantic_kernel.exceptions import VectorStoreTextSearchValidationError
+from semantic_kernel.data.text_search.text_search_result import TextSearchResult
+from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
+from semantic_kernel.exceptions import VectorStoreInitializationException
+from semantic_kernel.utils.list_handler import desync_list
 
 
 @fixture
@@ -28,6 +33,7 @@ async def test_from_vectorizable_text_search(vector_collection):
     assert search is not None
     assert text_search_result is not None
     assert search_result is not None
+    assert vsts.options_class is VectorSearchOptions
 
 
 async def test_from_vector_text_search(vector_collection):
@@ -60,10 +66,86 @@ async def test_from_vectorized_search(vector_collection, azure_openai_unit_test_
 
 def test_validation_no_embedder_for_vectorized_search(vector_collection):
-    with raises(VectorStoreTextSearchValidationError):
+    with raises(VectorStoreInitializationException):
         VectorStoreTextSearch(vectorized_search=vector_collection)
 
 
 def test_validation_no_collections():
-    with raises(VectorStoreTextSearchValidationError):
+    with raises(VectorStoreInitializationException):
         VectorStoreTextSearch()
+
+
+async def test_get_results_as_string(vector_collection):
+    test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection)
+    results = [
+        res
+        async for res in test_search._get_results_as_strings(results=desync_list([VectorSearchResult(record="test")]))
+    ]
+    assert results == ["test"]
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_strings(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == ['{"test":"test"}']
+
+    test_search = VectorStoreTextSearch.from_vector_text_search(
+        vector_text_search=vector_collection, string_mapper=lambda x: x.test
+    )
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_strings(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == ["test"]
+
+
+async def test_get_results_as_text_search_result(vector_collection):
+    test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection)
+    results = [
+        res
+        async for res in test_search._get_results_as_text_search_result(
+            results=desync_list([VectorSearchResult(record="test")])
+        )
+    ]
+    assert results == [TextSearchResult(value="test")]
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_text_search_result(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == [TextSearchResult(value='{"test":"test"}')]
+
+    test_search = VectorStoreTextSearch.from_vector_text_search(
+        vector_text_search=vector_collection, text_search_results_mapper=lambda x: TextSearchResult(value=x.test)
+    )
+
+    class TestClass(BaseModel):
+        test: str
+
+    results = [
+        res
+        async for res in test_search._get_results_as_text_search_result(
+            results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
+        )
+    ]
+
+    assert results == [TextSearchResult(value="test")]
diff --git a/python/tests/utils.py b/python/tests/utils.py
index fc27ce6e1a31..aac2ca65f408 100644
--- a/python/tests/utils.py
+++ b/python/tests/utils.py
@@ -14,6 +14,7 @@ async def retry(
     func: Callable[..., Awaitable[Any]],
     retries: int = 20,
     reset: Callable[..., None] | None = None,
+    name: str | None = None,
 ):
     """Retry the function if it raises an exception.
 
@@ -23,9 +24,10 @@
         reset (function): Function to reset the state of any variables used in the function
+        name (str): Optional name used in the retry log messages instead of the function's module
     """
 
-    logger.info(f"Running {retries} retries with func: {func.__module__}")
+    logger.info(f"Running {retries} retries with func: {name or func.__module__}")
    for i in range(retries):
-        logger.info(f"  Try {i + 1} for {func.__module__}")
+        logger.info(f"  Try {i + 1} for {name or func.__module__}")
     try:
         if reset:
             reset()
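
Note on the `name` argument added to `retry` above: it only affects logging. When `func` is a lambda, `func.__module__` is just the defining module, so `name` gives each retry log line a readable label. A minimal usage sketch, assuming `retry` awaits `func()` without arguments (as the signature suggests) and is importable as `tests.utils.retry`; the `flaky_case` coroutine is hypothetical:

    import asyncio

    from tests.utils import retry

    async def flaky_case() -> None:
        # Hypothetical stand-in for an integration-test call that
        # intermittently raises and is worth retrying.
        ...

    # Each attempt is then logged as "flaky-sample" rather than as
    # this module's name.
    asyncio.run(retry(flaky_case, retries=3, name="flaky-sample"))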