Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions tests/unit/data/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from pinecone.db_data.errors import (
VectorDictionaryMissingKeysError,
VectorDictionaryExcessKeysError,
VectorTupleLengthError,
SparseValuesTypeError,
SparseValuesMissingKeysError,
SparseValuesDictionaryExpectedError,
MetadataDictionaryExpectedError,
)


class TestVectorDictionaryMissingKeysError:
"""Test VectorDictionaryMissingKeysError exception."""

def test_error_message_includes_missing_keys(self):
"""Test that error message lists missing required fields."""
item = {"values": [0.1, 0.2]}
error = VectorDictionaryMissingKeysError(item)
assert "missing required fields" in str(error).lower()
assert "id" in str(error)

def test_error_message_with_multiple_missing_keys(self):
"""Test error message when multiple keys are missing."""
item = {}
error = VectorDictionaryMissingKeysError(item)
assert "missing required fields" in str(error).lower()


class TestVectorDictionaryExcessKeysError:
"""Test VectorDictionaryExcessKeysError exception."""

def test_error_message_includes_excess_keys(self):
"""Test that error message lists excess keys."""
item = {"id": "1", "values": [0.1, 0.2], "extra_field": "value", "another_extra": 123}
error = VectorDictionaryExcessKeysError(item)
assert "excess keys" in str(error).lower()
assert "extra_field" in str(error) or "another_extra" in str(error)

def test_error_message_includes_allowed_keys(self):
"""Test that error message includes list of allowed keys."""
item = {"id": "1", "values": [0.1, 0.2], "invalid": "key"}
error = VectorDictionaryExcessKeysError(item)
assert "allowed keys" in str(error).lower()


class TestVectorTupleLengthError:
"""Test VectorTupleLengthError exception."""

def test_error_message_includes_tuple_length(self):
"""Test that error message includes the tuple length."""
item = ("id", "values", "metadata", "extra")
error = VectorTupleLengthError(item)
assert str(len(item)) in str(error)
assert "tuple" in str(error).lower()

def test_error_message_with_length_one(self):
"""Test error message for tuple of length 1."""
item = ("id",)
error = VectorTupleLengthError(item)
assert "1" in str(error)

def test_error_message_with_length_four(self):
"""Test error message for tuple of length 4."""
item = ("id", "values", "metadata", "extra")
error = VectorTupleLengthError(item)
assert "4" in str(error)


class TestSparseValuesTypeError:
"""Test SparseValuesTypeError exception."""

def test_error_message_mentions_sparse_values(self):
"""Test that error message mentions sparse_values."""
error = SparseValuesTypeError()
assert "sparse_values" in str(error).lower()

def test_error_is_both_value_and_type_error(self):
"""Test that SparseValuesTypeError is both ValueError and TypeError."""
error = SparseValuesTypeError()
assert isinstance(error, ValueError)
assert isinstance(error, TypeError)


class TestSparseValuesMissingKeysError:
"""Test SparseValuesMissingKeysError exception."""

def test_error_message_includes_found_keys(self):
"""Test that error message includes the keys that were found."""
sparse_values_dict = {"indices": [0, 2]}
error = SparseValuesMissingKeysError(sparse_values_dict)
assert "missing required keys" in str(error).lower()
assert "indices" in str(error) or "values" in str(error)

def test_error_message_with_empty_dict(self):
"""Test error message when dictionary is empty."""
sparse_values_dict = {}
error = SparseValuesMissingKeysError(sparse_values_dict)
assert "missing required keys" in str(error).lower()


class TestSparseValuesDictionaryExpectedError:
"""Test SparseValuesDictionaryExpectedError exception."""

def test_error_message_includes_actual_type(self):
"""Test that error message includes the actual type found."""
sparse_values_dict = "not a dict"
error = SparseValuesDictionaryExpectedError(sparse_values_dict)
assert "dictionary" in str(error).lower()
assert "str" in str(error) or type(sparse_values_dict).__name__ in str(error)

def test_error_message_with_integer(self):
"""Test error message when integer is provided."""
sparse_values_dict = 123
error = SparseValuesDictionaryExpectedError(sparse_values_dict)
assert "dictionary" in str(error).lower()
assert isinstance(error, ValueError)
assert isinstance(error, TypeError)


class TestMetadataDictionaryExpectedError:
"""Test MetadataDictionaryExpectedError exception."""

def test_error_message_includes_actual_type(self):
"""Test that error message includes the actual type found."""
item = {"metadata": "not a dict"}
error = MetadataDictionaryExpectedError(item)
assert "dictionary" in str(error).lower()
assert "metadata" in str(error).lower()

def test_error_message_with_list(self):
"""Test error message when list is provided as metadata."""
item = {"metadata": [1, 2, 3]}
error = MetadataDictionaryExpectedError(item)
assert "dictionary" in str(error).lower()
assert isinstance(error, ValueError)
assert isinstance(error, TypeError)
152 changes: 152 additions & 0 deletions tests/unit/data/test_sparse_values_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import numpy as np
import pandas as pd
import pytest

from pinecone.db_data.sparse_values_factory import SparseValuesFactory
from pinecone import SparseValues
from pinecone.core.openapi.db_data.models import SparseValues as OpenApiSparseValues
from pinecone.db_data.errors import (
SparseValuesTypeError,
SparseValuesMissingKeysError,
SparseValuesDictionaryExpectedError,
)


class TestSparseValuesFactory:
"""Test SparseValuesFactory for REST API (db_data module)."""

def test_build_when_none_returns_none(self):
"""Test that None input returns None."""
assert SparseValuesFactory.build(None) is None

def test_build_when_passed_openapi_sparse_values(self):
"""Test that OpenApiSparseValues are returned unchanged."""
sv = OpenApiSparseValues(indices=[0, 2], values=[0.1, 0.3])
actual = SparseValuesFactory.build(sv)
assert actual == sv
assert actual is sv

def test_build_when_given_sparse_values_dataclass(self):
"""Test conversion from SparseValues dataclass to OpenApiSparseValues."""
sv = SparseValues(indices=[0, 2], values=[0.1, 0.3])
actual = SparseValuesFactory.build(sv)
expected = OpenApiSparseValues(indices=[0, 2], values=[0.1, 0.3])
assert isinstance(actual, OpenApiSparseValues)
assert actual.indices == expected.indices
assert actual.values == expected.values

@pytest.mark.parametrize(
"input_dict",
[
{"indices": [2], "values": [0.3]},
{"indices": [88, 102], "values": [-0.1, 0.3]},
{"indices": [0, 2, 4], "values": [0.1, 0.3, 0.5]},
{"indices": [0, 2, 4, 6], "values": [0.1, 0.3, 0.5, 0.7]},
],
)
def test_build_when_valid_dictionary(self, input_dict):
"""Test building from valid dictionary input."""
actual = SparseValuesFactory.build(input_dict)
expected = OpenApiSparseValues(indices=input_dict["indices"], values=input_dict["values"])
assert actual.indices == expected.indices
assert actual.values == expected.values

@pytest.mark.parametrize(
"input_dict",
[
{"indices": np.array([0, 2]), "values": [0.1, 0.3]},
{"indices": [0, 2], "values": np.array([0.1, 0.3])},
{"indices": np.array([0, 2]), "values": np.array([0.1, 0.3])},
{"indices": pd.array([0, 2]), "values": [0.1, 0.3]},
{"indices": [0, 2], "values": pd.array([0.1, 0.3])},
{"indices": pd.array([0, 2]), "values": pd.array([0.1, 0.3])},
],
)
def test_build_when_special_data_types(self, input_dict):
"""Test that the factory handles numpy/pandas arrays correctly."""
actual = SparseValuesFactory.build(input_dict)
expected = OpenApiSparseValues(indices=[0, 2], values=[0.1, 0.3])
assert actual.indices == expected.indices
assert actual.values == expected.values

@pytest.mark.parametrize(
"input_dict",
[{"indices": [2], "values": [0.3, 0.3]}, {"indices": [88, 102], "values": [-0.1]}],
)
def test_build_when_list_sizes_dont_match(self, input_dict):
"""Test that mismatched indices and values lengths raise ValueError."""
with pytest.raises(
ValueError, match="Sparse values indices and values must have the same length"
):
SparseValuesFactory.build(input_dict)

@pytest.mark.parametrize(
"input_dict",
[
{"indices": [2.0], "values": [0.3]},
{"indices": ["2"], "values": [0.3]},
{"indices": np.array([2.0]), "values": [0.3]},
{"indices": pd.array([2.0]), "values": [0.3]},
],
)
def test_build_when_non_integer_indices(self, input_dict):
"""Test that non-integer indices raise SparseValuesTypeError."""
with pytest.raises(SparseValuesTypeError):
SparseValuesFactory.build(input_dict)

@pytest.mark.parametrize(
"input_dict", [{"indices": [2], "values": ["3.2"]}, {"indices": [2], "values": [True]}]
)
def test_build_when_non_float_values(self, input_dict):
"""Test that non-float values raise SparseValuesTypeError."""
with pytest.raises(SparseValuesTypeError):
SparseValuesFactory.build(input_dict)

def test_build_when_missing_indices_key(self):
"""Test that missing 'indices' key raises SparseValuesMissingKeysError."""
input_dict = {"values": [0.1, 0.3]}
with pytest.raises(SparseValuesMissingKeysError) as exc_info:
SparseValuesFactory.build(input_dict)
assert "indices" in str(exc_info.value)

def test_build_when_missing_values_key(self):
"""Test that missing 'values' key raises SparseValuesMissingKeysError."""
input_dict = {"indices": [0, 2]}
with pytest.raises(SparseValuesMissingKeysError) as exc_info:
SparseValuesFactory.build(input_dict)
assert "values" in str(exc_info.value)

def test_build_when_missing_both_keys(self):
"""Test that missing both keys raises SparseValuesMissingKeysError."""
input_dict = {}
with pytest.raises(SparseValuesMissingKeysError) as exc_info:
SparseValuesFactory.build(input_dict)
assert "indices" in str(exc_info.value) or "values" in str(exc_info.value)

def test_build_when_not_a_dictionary(self):
"""Test that non-dictionary input raises SparseValuesDictionaryExpectedError."""
with pytest.raises(SparseValuesDictionaryExpectedError) as exc_info:
SparseValuesFactory.build("not a dict")
assert "dictionary" in str(exc_info.value).lower()

with pytest.raises(SparseValuesDictionaryExpectedError):
SparseValuesFactory.build(123)

with pytest.raises(SparseValuesDictionaryExpectedError):
SparseValuesFactory.build([1, 2, 3])

def test_build_when_empty_indices_list(self):
"""Test that empty indices list is handled correctly."""
input_dict = {"indices": [], "values": []}
actual = SparseValuesFactory.build(input_dict)
expected = OpenApiSparseValues(indices=[], values=[])
assert actual.indices == expected.indices
assert actual.values == expected.values

def test_build_when_empty_values_list(self):
"""Test that empty values list is handled correctly."""
input_dict = {"indices": [], "values": []}
actual = SparseValuesFactory.build(input_dict)
expected = OpenApiSparseValues(indices=[], values=[])
assert actual.indices == expected.indices
assert actual.values == expected.values
85 changes: 85 additions & 0 deletions tests/unit/utils/test_check_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from unittest.mock import patch

from pinecone.utils import check_kwargs


def example_function(arg1, arg2, arg3=None):
"""Example function for testing."""
pass


class TestCheckKwargs:
"""Test check_kwargs utility function."""

def test_no_unexpected_kwargs_no_logging(self):
"""Test that no logging occurs when all kwargs are valid."""
with patch("logging.exception") as mock_log:
check_kwargs(example_function, {"arg1", "arg2", "arg3"})
mock_log.assert_not_called()

def test_unexpected_kwargs_logs_warning(self):
"""Test that unexpected kwargs trigger logging."""
with patch("logging.exception") as mock_log:
check_kwargs(example_function, {"arg1", "arg2", "arg3", "unexpected_arg"})
mock_log.assert_called_once()
call_args = mock_log.call_args[0][0]
assert "unexpected keyword argument" in call_args.lower()
assert "unexpected_arg" in call_args

def test_multiple_unexpected_kwargs_logs_all(self):
"""Test that multiple unexpected kwargs are all logged."""
with patch("logging.exception") as mock_log:
check_kwargs(example_function, {"arg1", "arg2", "arg3", "unexpected1", "unexpected2"})
mock_log.assert_called_once()
call_args = mock_log.call_args[0][0]
assert "unexpected1" in call_args or "unexpected2" in call_args

def test_only_unexpected_kwargs(self):
"""Test when only unexpected kwargs are provided."""
with patch("logging.exception") as mock_log:
check_kwargs(example_function, {"unexpected_arg"})
mock_log.assert_called_once()
call_args = mock_log.call_args[0][0]
assert "unexpected keyword argument" in call_args.lower()

def test_empty_kwargs_set(self):
"""Test with empty kwargs set."""
with patch("logging.exception") as mock_log:
check_kwargs(example_function, set())
mock_log.assert_not_called()

def test_function_with_no_args(self):
"""Test with function that has no arguments."""

def no_args_function():
pass

with patch("logging.exception") as mock_log:
check_kwargs(no_args_function, {"any_arg"})
mock_log.assert_called_once()

def test_function_with_varargs(self):
"""Test with function that has *args."""

def varargs_function(*args):
pass

with patch("logging.exception") as mock_log:
check_kwargs(varargs_function, {"any_arg"})
mock_log.assert_called_once()

def test_function_with_kwargs(self):
"""Test with function that has **kwargs.

Note: check_kwargs only checks explicit args, not **kwargs,
so it will still log unexpected args even for functions with **kwargs.
This is the current behavior of the function.
"""

def kwargs_function(**kwargs):
pass

with patch("logging.exception") as mock_log:
check_kwargs(kwargs_function, {"any_arg"})
# check_kwargs doesn't check for **kwargs, so it will log
mock_log.assert_called_once()
Loading
Loading