diff --git a/tests/unit/data/test_errors.py b/tests/unit/data/test_errors.py new file mode 100644 index 000000000..c4b2f0265 --- /dev/null +++ b/tests/unit/data/test_errors.py @@ -0,0 +1,136 @@ +from pinecone.db_data.errors import ( + VectorDictionaryMissingKeysError, + VectorDictionaryExcessKeysError, + VectorTupleLengthError, + SparseValuesTypeError, + SparseValuesMissingKeysError, + SparseValuesDictionaryExpectedError, + MetadataDictionaryExpectedError, +) + + +class TestVectorDictionaryMissingKeysError: + """Test VectorDictionaryMissingKeysError exception.""" + + def test_error_message_includes_missing_keys(self): + """Test that error message lists missing required fields.""" + item = {"values": [0.1, 0.2]} + error = VectorDictionaryMissingKeysError(item) + assert "missing required fields" in str(error).lower() + assert "id" in str(error) + + def test_error_message_with_multiple_missing_keys(self): + """Test error message when multiple keys are missing.""" + item = {} + error = VectorDictionaryMissingKeysError(item) + assert "missing required fields" in str(error).lower() + + +class TestVectorDictionaryExcessKeysError: + """Test VectorDictionaryExcessKeysError exception.""" + + def test_error_message_includes_excess_keys(self): + """Test that error message lists excess keys.""" + item = {"id": "1", "values": [0.1, 0.2], "extra_field": "value", "another_extra": 123} + error = VectorDictionaryExcessKeysError(item) + assert "excess keys" in str(error).lower() + assert "extra_field" in str(error) or "another_extra" in str(error) + + def test_error_message_includes_allowed_keys(self): + """Test that error message includes list of allowed keys.""" + item = {"id": "1", "values": [0.1, 0.2], "invalid": "key"} + error = VectorDictionaryExcessKeysError(item) + assert "allowed keys" in str(error).lower() + + +class TestVectorTupleLengthError: + """Test VectorTupleLengthError exception.""" + + def test_error_message_includes_tuple_length(self): + """Test that error message includes the tuple length.""" + item = ("id", "values", "metadata", "extra") + error = VectorTupleLengthError(item) + assert str(len(item)) in str(error) + assert "tuple" in str(error).lower() + + def test_error_message_with_length_one(self): + """Test error message for tuple of length 1.""" + item = ("id",) + error = VectorTupleLengthError(item) + assert "1" in str(error) + + def test_error_message_with_length_four(self): + """Test error message for tuple of length 4.""" + item = ("id", "values", "metadata", "extra") + error = VectorTupleLengthError(item) + assert "4" in str(error) + + +class TestSparseValuesTypeError: + """Test SparseValuesTypeError exception.""" + + def test_error_message_mentions_sparse_values(self): + """Test that error message mentions sparse_values.""" + error = SparseValuesTypeError() + assert "sparse_values" in str(error).lower() + + def test_error_is_both_value_and_type_error(self): + """Test that SparseValuesTypeError is both ValueError and TypeError.""" + error = SparseValuesTypeError() + assert isinstance(error, ValueError) + assert isinstance(error, TypeError) + + +class TestSparseValuesMissingKeysError: + """Test SparseValuesMissingKeysError exception.""" + + def test_error_message_includes_found_keys(self): + """Test that error message includes the keys that were found.""" + sparse_values_dict = {"indices": [0, 2]} + error = SparseValuesMissingKeysError(sparse_values_dict) + assert "missing required keys" in str(error).lower() + assert "indices" in str(error) or "values" in str(error) + + def test_error_message_with_empty_dict(self): + """Test error message when dictionary is empty.""" + sparse_values_dict = {} + error = SparseValuesMissingKeysError(sparse_values_dict) + assert "missing required keys" in str(error).lower() + + +class TestSparseValuesDictionaryExpectedError: + """Test SparseValuesDictionaryExpectedError exception.""" + + def test_error_message_includes_actual_type(self): + """Test that error message includes the actual type found.""" + sparse_values_dict = "not a dict" + error = SparseValuesDictionaryExpectedError(sparse_values_dict) + assert "dictionary" in str(error).lower() + assert "str" in str(error) or type(sparse_values_dict).__name__ in str(error) + + def test_error_message_with_integer(self): + """Test error message when integer is provided.""" + sparse_values_dict = 123 + error = SparseValuesDictionaryExpectedError(sparse_values_dict) + assert "dictionary" in str(error).lower() + assert isinstance(error, ValueError) + assert isinstance(error, TypeError) + + +class TestMetadataDictionaryExpectedError: + """Test MetadataDictionaryExpectedError exception.""" + + def test_error_message_includes_actual_type(self): + """Test that error message includes the actual type found.""" + item = {"metadata": "not a dict"} + error = MetadataDictionaryExpectedError(item) + assert "dictionary" in str(error).lower() + assert "metadata" in str(error).lower() + + def test_error_message_with_list(self): + """Test error message when list is provided as metadata.""" + item = {"metadata": [1, 2, 3]} + error = MetadataDictionaryExpectedError(item) + assert "dictionary" in str(error).lower() + assert isinstance(error, ValueError) + assert isinstance(error, TypeError) diff --git a/tests/unit/data/test_sparse_values_factory.py b/tests/unit/data/test_sparse_values_factory.py new file mode 100644 index 000000000..3b87381aa --- /dev/null +++ b/tests/unit/data/test_sparse_values_factory.py @@ -0,0 +1,152 @@ +import numpy as np +import pandas as pd +import pytest + +from pinecone.db_data.sparse_values_factory import SparseValuesFactory +from pinecone import SparseValues +from pinecone.core.openapi.db_data.models import SparseValues as OpenApiSparseValues +from pinecone.db_data.errors import ( + SparseValuesTypeError, + SparseValuesMissingKeysError, + SparseValuesDictionaryExpectedError, +) + + +class TestSparseValuesFactory: + """Test SparseValuesFactory for REST API (db_data module).""" + + def test_build_when_none_returns_none(self): + """Test that None input returns None.""" + assert SparseValuesFactory.build(None) is None + + def test_build_when_passed_openapi_sparse_values(self): + """Test that OpenApiSparseValues are returned unchanged.""" + sv = OpenApiSparseValues(indices=[0, 2], values=[0.1, 0.3]) + actual = SparseValuesFactory.build(sv) + assert actual == sv + assert actual is sv + + def test_build_when_given_sparse_values_dataclass(self): + """Test conversion from SparseValues dataclass to OpenApiSparseValues.""" + sv = SparseValues(indices=[0, 2], values=[0.1, 0.3]) + actual = SparseValuesFactory.build(sv) + expected = OpenApiSparseValues(indices=[0, 2], values=[0.1, 0.3]) + assert isinstance(actual, OpenApiSparseValues) + assert actual.indices == expected.indices + assert actual.values == expected.values + + @pytest.mark.parametrize( + "input_dict", + [ + {"indices": [2], "values": [0.3]}, + {"indices": [88, 102], "values": [-0.1, 0.3]}, + {"indices": [0, 2, 4], "values": [0.1, 0.3, 0.5]}, + {"indices": [0, 2, 4, 6], "values": [0.1, 0.3, 0.5, 0.7]}, + ], + ) + def test_build_when_valid_dictionary(self, input_dict): + """Test building from valid dictionary input.""" + actual = SparseValuesFactory.build(input_dict) + expected = OpenApiSparseValues(indices=input_dict["indices"], values=input_dict["values"]) + assert actual.indices == expected.indices + assert actual.values == expected.values + + @pytest.mark.parametrize( + "input_dict", + [ + {"indices": np.array([0, 2]), "values": [0.1, 0.3]}, + {"indices": [0, 2], "values": np.array([0.1, 0.3])}, + {"indices": np.array([0, 2]), "values": np.array([0.1, 0.3])}, + {"indices": pd.array([0, 2]), "values": [0.1, 0.3]}, + {"indices": [0, 2], "values": pd.array([0.1, 0.3])}, + {"indices": pd.array([0, 2]), "values": pd.array([0.1, 0.3])}, + ], + ) + def test_build_when_special_data_types(self, input_dict): + """Test that the factory handles numpy/pandas arrays correctly.""" + actual = SparseValuesFactory.build(input_dict) + expected = OpenApiSparseValues(indices=[0, 2], values=[0.1, 0.3]) + assert actual.indices == expected.indices + assert actual.values == expected.values + + @pytest.mark.parametrize( + "input_dict", + [{"indices": [2], "values": [0.3, 0.3]}, {"indices": [88, 102], "values": [-0.1]}], + ) + def test_build_when_list_sizes_dont_match(self, input_dict): + """Test that mismatched indices and values lengths raise ValueError.""" + with pytest.raises( + ValueError, match="Sparse values indices and values must have the same length" + ): + SparseValuesFactory.build(input_dict) + + @pytest.mark.parametrize( + "input_dict", + [ + {"indices": [2.0], "values": [0.3]}, + {"indices": ["2"], "values": [0.3]}, + {"indices": np.array([2.0]), "values": [0.3]}, + {"indices": pd.array([2.0]), "values": [0.3]}, + ], + ) + def test_build_when_non_integer_indices(self, input_dict): + """Test that non-integer indices raise SparseValuesTypeError.""" + with pytest.raises(SparseValuesTypeError): + SparseValuesFactory.build(input_dict) + + @pytest.mark.parametrize( + "input_dict", [{"indices": [2], "values": ["3.2"]}, {"indices": [2], "values": [True]}] + ) + def test_build_when_non_float_values(self, input_dict): + """Test that non-float values raise SparseValuesTypeError.""" + with pytest.raises(SparseValuesTypeError): + SparseValuesFactory.build(input_dict) + + def test_build_when_missing_indices_key(self): + """Test that missing 'indices' key raises SparseValuesMissingKeysError.""" + input_dict = {"values": [0.1, 0.3]} + with pytest.raises(SparseValuesMissingKeysError) as exc_info: + SparseValuesFactory.build(input_dict) + assert "indices" in str(exc_info.value) + + def test_build_when_missing_values_key(self): + """Test that missing 'values' key raises SparseValuesMissingKeysError.""" + input_dict = {"indices": [0, 2]} + with pytest.raises(SparseValuesMissingKeysError) as exc_info: + SparseValuesFactory.build(input_dict) + assert "values" in str(exc_info.value) + + def test_build_when_missing_both_keys(self): + """Test that missing both keys raises SparseValuesMissingKeysError.""" + input_dict = {} + with pytest.raises(SparseValuesMissingKeysError) as exc_info: + SparseValuesFactory.build(input_dict) + assert "indices" in str(exc_info.value) or "values" in str(exc_info.value) + + def test_build_when_not_a_dictionary(self): + """Test that non-dictionary input raises SparseValuesDictionaryExpectedError.""" + with pytest.raises(SparseValuesDictionaryExpectedError) as exc_info: + SparseValuesFactory.build("not a dict") + assert "dictionary" in str(exc_info.value).lower() + + with pytest.raises(SparseValuesDictionaryExpectedError): + SparseValuesFactory.build(123) + + with pytest.raises(SparseValuesDictionaryExpectedError): + SparseValuesFactory.build([1, 2, 3]) + + def test_build_when_empty_indices_list(self): + """Test that empty indices list is handled correctly.""" + input_dict = {"indices": [], "values": []} + actual = SparseValuesFactory.build(input_dict) + expected = OpenApiSparseValues(indices=[], values=[]) + assert actual.indices == expected.indices + assert actual.values == expected.values + + def test_build_when_empty_values_list(self): + """Test that empty values list is handled correctly.""" + input_dict = {"indices": [], "values": []} + actual = SparseValuesFactory.build(input_dict) + expected = OpenApiSparseValues(indices=[], values=[]) + assert actual.indices == expected.indices + assert actual.values == expected.values diff --git a/tests/unit/utils/test_check_kwargs.py b/tests/unit/utils/test_check_kwargs.py new file mode 100644 index 000000000..8c7d9655e --- /dev/null +++ b/tests/unit/utils/test_check_kwargs.py @@ -0,0 +1,85 @@ +from unittest.mock import patch + +from pinecone.utils import check_kwargs + + +def example_function(arg1, arg2, arg3=None): + """Example function for testing.""" + pass + + +class TestCheckKwargs: + """Test check_kwargs utility function.""" + + def test_no_unexpected_kwargs_no_logging(self): + """Test that no logging occurs when all kwargs are valid.""" + with patch("logging.exception") as mock_log: + check_kwargs(example_function, {"arg1", "arg2", "arg3"}) + mock_log.assert_not_called() + + def test_unexpected_kwargs_logs_warning(self): + """Test that unexpected kwargs trigger logging.""" + with patch("logging.exception") as mock_log: + check_kwargs(example_function, {"arg1", "arg2", "arg3", "unexpected_arg"}) + mock_log.assert_called_once() + call_args = mock_log.call_args[0][0] + assert "unexpected keyword argument" in call_args.lower() + assert "unexpected_arg" in call_args + + def test_multiple_unexpected_kwargs_logs_all(self): + """Test that multiple unexpected kwargs are all logged.""" + with patch("logging.exception") as mock_log: + check_kwargs(example_function, {"arg1", "arg2", "arg3", "unexpected1", "unexpected2"}) + mock_log.assert_called_once() + call_args = mock_log.call_args[0][0] + assert "unexpected1" in call_args or "unexpected2" in call_args + + def test_only_unexpected_kwargs(self): + """Test when only unexpected kwargs are provided.""" + with patch("logging.exception") as mock_log: + check_kwargs(example_function, {"unexpected_arg"}) + mock_log.assert_called_once() + call_args = mock_log.call_args[0][0] + assert "unexpected keyword argument" in call_args.lower() + + def test_empty_kwargs_set(self): + """Test with empty kwargs set.""" + with patch("logging.exception") as mock_log: + check_kwargs(example_function, set()) + mock_log.assert_not_called() + + def test_function_with_no_args(self): + """Test with function that has no arguments.""" + + def no_args_function(): + pass + + with patch("logging.exception") as mock_log: + check_kwargs(no_args_function, {"any_arg"}) + mock_log.assert_called_once() + + def test_function_with_varargs(self): + """Test with function that has *args.""" + + def varargs_function(*args): + pass + + with patch("logging.exception") as mock_log: + check_kwargs(varargs_function, {"any_arg"}) + mock_log.assert_called_once() + + def test_function_with_kwargs(self): + """Test with function that has **kwargs. + + Note: check_kwargs only checks explicit args, not **kwargs, + so it will still log unexpected args even for functions with **kwargs. + This is the current behavior of the function. + """ + + def kwargs_function(**kwargs): + pass + + with patch("logging.exception") as mock_log: + check_kwargs(kwargs_function, {"any_arg"}) + # check_kwargs doesn't check for **kwargs, so it will log + mock_log.assert_called_once() diff --git a/tests/unit/utils/test_error_handling.py b/tests/unit/utils/test_error_handling.py new file mode 100644 index 000000000..8faa2a959 --- /dev/null +++ b/tests/unit/utils/test_error_handling.py @@ -0,0 +1,97 @@ +import pytest + +from pinecone.utils.error_handling import validate_and_convert_errors, ProtocolError + + +class TestValidateAndConvertErrors: + """Test validate_and_convert_errors decorator.""" + + def test_successful_function_execution(self): + """Test that successful function execution is not affected.""" + + @validate_and_convert_errors + def test_func(): + return "success" + + result = test_func() + assert result == "success" + + def test_unrelated_exception_passed_through(self): + """Test that unrelated exceptions are passed through unchanged.""" + + @validate_and_convert_errors + def test_func(): + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + test_func() + + def test_max_retry_error_with_protocol_error_reason(self): + """Test that MaxRetryError with ProtocolError reason is converted.""" + + @validate_and_convert_errors + def test_func(): + from urllib3.exceptions import MaxRetryError, ProtocolError as Urllib3ProtocolError + + error = MaxRetryError( + pool=None, url="http://test.com", reason=Urllib3ProtocolError("test") + ) + raise error + + with pytest.raises(ProtocolError) as exc_info: + test_func() + assert "Failed to connect" in str(exc_info.value) + assert "http://test.com" in str(exc_info.value) + + def test_max_retry_error_with_non_protocol_reason(self): + """Test that MaxRetryError with non-ProtocolError reason is passed through.""" + + from urllib3.exceptions import MaxRetryError + + @validate_and_convert_errors + def test_func(): + # Create a MaxRetryError with a non-ProtocolError reason (ValueError) + # We'll use a simple ValueError as the reason + error = MaxRetryError(pool=None, url="http://test.com", reason=ValueError("test")) + raise error + + with pytest.raises(MaxRetryError): + test_func() + + def test_urllib3_protocol_error_converted(self): + """Test that urllib3 ProtocolError is converted to ProtocolError.""" + + @validate_and_convert_errors + def test_func(): + from urllib3.exceptions import ProtocolError as Urllib3ProtocolError + + raise Urllib3ProtocolError("connection failed") + + with pytest.raises(ProtocolError) as exc_info: + test_func() + assert "Connection failed" in str(exc_info.value) + assert "index host" in str(exc_info.value).lower() + + def test_preserves_function_signature(self): + """Test that function signature is preserved.""" + + @validate_and_convert_errors + def test_func(arg1: str, arg2: int = 10) -> str: + return f"{arg1}_{arg2}" + + assert test_func("test", 20) == "test_20" + assert test_func("test") == "test_10" + + def test_exception_chaining_preserved(self): + """Test that exception chaining is preserved.""" + + @validate_and_convert_errors + def test_func(): + from urllib3.exceptions import ProtocolError as Urllib3ProtocolError + + original = Urllib3ProtocolError("original error") + raise original + + with pytest.raises(ProtocolError) as exc_info: + test_func() + assert exc_info.value.__cause__ is not None diff --git a/tests/unit/utils/test_fix_tuple_length.py b/tests/unit/utils/test_fix_tuple_length.py new file mode 100644 index 000000000..19e849eaf --- /dev/null +++ b/tests/unit/utils/test_fix_tuple_length.py @@ -0,0 +1,64 @@ +from pinecone.utils.fix_tuple_length import fix_tuple_length + + +class TestFixTupleLength: + """Test fix_tuple_length utility function.""" + + def test_tuple_shorter_than_target(self): + """Test extending a tuple that's shorter than target length.""" + result = fix_tuple_length(("a", "b"), 4) + assert result == ("a", "b", None, None) + assert len(result) == 4 + + def test_tuple_equal_to_target(self): + """Test tuple that's already the target length.""" + result = fix_tuple_length(("a", "b", "c"), 3) + assert result == ("a", "b", "c") + assert len(result) == 3 + + def test_tuple_longer_than_target(self): + """Test tuple that's longer than target length (should be unchanged).""" + result = fix_tuple_length(("a", "b", "c", "d"), 3) + assert result == ("a", "b", "c", "d") + assert len(result) == 4 + + def test_empty_tuple_extended(self): + """Test extending an empty tuple.""" + result = fix_tuple_length((), 3) + assert result == (None, None, None) + assert len(result) == 3 + + def test_single_element_tuple(self): + """Test extending a single-element tuple.""" + result = fix_tuple_length(("a",), 3) + assert result == ("a", None, None) + assert len(result) == 3 + + def test_extend_to_length_one(self): + """Test extending to length 1.""" + result = fix_tuple_length((), 1) + assert result == (None,) + assert len(result) == 1 + + def test_extend_to_length_zero(self): + """Test extending to length 0 (edge case).""" + result = fix_tuple_length((), 0) + assert result == () + assert len(result) == 0 + + def test_preserves_original_values(self): + """Test that original values are preserved in correct positions.""" + result = fix_tuple_length(("id", "values", "metadata"), 5) + assert result[0] == "id" + assert result[1] == "values" + assert result[2] == "metadata" + assert result[3] is None + assert result[4] is None + + def test_with_none_values(self): + """Test tuple that already contains None values.""" + result = fix_tuple_length(("a", None, "c"), 5) + assert result == ("a", None, "c", None, None) + assert result[1] is None + assert result[3] is None + assert result[4] is None