Files
knowledge-search-mcp/tests/test_search.py
2026-03-03 18:34:57 +00:00

111 lines
3.6 KiB
Python

"""Tests for vector search functionality."""
import io
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from knowledge_search_mcp import (
GoogleCloudFileStorage,
GoogleCloudVectorSearch,
LRUCache,
SourceNamespace,
)
class TestGoogleCloudFileStorage:
    """Tests for GoogleCloudFileStorage construction and cache behavior."""

    def test_init(self):
        """A new storage object records its bucket and owns an LRU cache.

        The cache is expected to be an ``LRUCache`` with ``max_size == 100``.
        """
        storage = GoogleCloudFileStorage(bucket="test-bucket")

        assert storage.bucket_name == "test-bucket"
        assert isinstance(storage._cache, LRUCache)
        assert storage._cache.max_size == 100

    @pytest.mark.asyncio
    async def test_cache_hit(self):
        """Cached files are returned without fetching from the backend.

        The storage-client factory is patched so the test stays hermetic and
        so that any attempt to obtain a client on a cache hit fails the test
        explicitly instead of silently reaching for a real client.
        """
        storage = GoogleCloudFileStorage(bucket="test-bucket")
        test_content = b"cached content"
        storage._cache.put("test.md", test_content)

        with patch.object(storage, "_get_aio_storage") as mock_storage_getter:
            result = await storage.async_get_file_stream("test.md")

        assert result.read() == test_content
        assert result.name == "test.md"
        # The whole point of the cache: no client may be requested on a hit.
        mock_storage_getter.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss(self):
        """Uncached files are downloaded via the storage client and cached."""
        storage = GoogleCloudFileStorage(bucket="test-bucket")
        test_content = b"fetched content"

        # Stand in for the real storage client so no network call is made.
        with patch.object(storage, "_get_aio_storage") as mock_storage_getter:
            mock_storage = AsyncMock()
            mock_storage.download = AsyncMock(return_value=test_content)
            mock_storage_getter.return_value = mock_storage

            result = await storage.async_get_file_stream("test.md")

        assert result.read() == test_content
        # The downloaded bytes must have been stored for subsequent reads.
        assert storage._cache.get("test.md") == test_content
        # A single successful fetch should hit the backend exactly once.
        mock_storage.download.assert_awaited_once()
class TestGoogleCloudVectorSearch:
    """Checks construction and endpoint configuration of GoogleCloudVectorSearch."""

    def test_init(self):
        """Constructor arguments are exposed as attributes on the client."""
        ctor_kwargs = {
            "project_id": "test-project",
            "location": "us-central1",
            "bucket": "test-bucket",
            "index_name": "test-index",
        }

        client = GoogleCloudVectorSearch(**ctor_kwargs)

        assert client.project_id == "test-project"
        assert client.location == "us-central1"
        assert client.index_name == "test-index"

    def test_configure_index_endpoint(self):
        """Configuring an endpoint stores its name and public domain."""
        client = GoogleCloudVectorSearch(
            project_id="test-project",
            location="us-central1",
            bucket="test-bucket",
        )

        client.configure_index_endpoint(
            name="test-endpoint",
            public_domain="test.domain.com",
        )

        assert client._endpoint_name == "test-endpoint"
        assert client._endpoint_domain == "test.domain.com"

    def test_configure_index_endpoint_validation(self):
        """Empty endpoint name or domain is rejected with a ValueError."""
        client = GoogleCloudVectorSearch(
            project_id="test-project",
            location="us-central1",
            bucket="test-bucket",
        )
        # (call arguments, regex the error message must match), checked in order
        rejected_inputs = (
            ({"name": "", "public_domain": "test.com"}, "endpoint name"),
            ({"name": "test", "public_domain": ""}, "endpoint domain"),
        )

        for call_kwargs, message_pattern in rejected_inputs:
            with pytest.raises(ValueError, match=message_pattern):
                client.configure_index_endpoint(**call_kwargs)
class TestSourceNamespace:
    """Checks the members of the SourceNamespace enum."""

    def test_source_namespace_values(self):
        """Each expected member carries its expected string value."""
        expected_values = {
            "EDUCACION_FINANCIERA": "Educacion Financiera",
            "PRODUCTOS_Y_SERVICIOS": "Productos y Servicios",
            "FUNCIONALIDADES_APP_MOVIL": "Funcionalidades de la App Movil",
        }

        for member_name, expected in expected_values.items():
            member = getattr(SourceNamespace, member_name)
            assert member.value == expected