111 lines
3.6 KiB
Python
111 lines
3.6 KiB
Python
"""Tests for vector search functionality."""
|
|
|
|
import io
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from knowledge_search_mcp import (
|
|
GoogleCloudFileStorage,
|
|
GoogleCloudVectorSearch,
|
|
LRUCache,
|
|
SourceNamespace,
|
|
)
|
|
|
|
|
|
class TestGoogleCloudFileStorage:
    """Exercise GoogleCloudFileStorage construction and its cache paths."""

    def test_init(self):
        """Check the bucket name and the cache created by the constructor."""
        gcs = GoogleCloudFileStorage(bucket="test-bucket")

        assert gcs.bucket_name == "test-bucket"
        # The constructor is expected to attach an LRUCache with max_size 100.
        assert isinstance(gcs._cache, LRUCache)
        assert gcs._cache.max_size == 100

    @pytest.mark.asyncio
    async def test_cache_hit(self):
        """Check that a pre-cached entry comes back as a named stream.

        The cache is seeded directly and no GCS client is mocked here, so
        the stream content is expected to come from the cache.
        """
        gcs = GoogleCloudFileStorage(bucket="test-bucket")
        payload = b"cached content"
        gcs._cache.put("test.md", payload)

        stream = await gcs.async_get_file_stream("test.md")

        assert stream.read() == payload
        assert stream.name == "test.md"

    @pytest.mark.asyncio
    async def test_cache_miss(self):
        """Check that an uncached file is downloaded and then cached."""
        gcs = GoogleCloudFileStorage(bucket="test-bucket")
        payload = b"fetched content"

        # Stand-in storage client whose ``download`` resolves to the payload.
        fake_client = AsyncMock()
        fake_client.download = AsyncMock(return_value=payload)

        # Replace the client getter so no real GCS client is used.
        with patch.object(gcs, "_get_aio_storage", return_value=fake_client):
            stream = await gcs.async_get_file_stream("test.md")

            assert stream.read() == payload
            # The downloaded bytes should now be stored under the file name.
            assert gcs._cache.get("test.md") == payload
|
|
|
|
|
|
class TestGoogleCloudVectorSearch:
    """Exercise GoogleCloudVectorSearch construction and endpoint setup."""

    @staticmethod
    def _new_client(**extra):
        """Build a client with the shared test arguments plus any extras."""
        return GoogleCloudVectorSearch(
            project_id="test-project",
            location="us-central1",
            bucket="test-bucket",
            **extra,
        )

    def test_init(self):
        """Check that constructor arguments are exposed as attributes."""
        client = self._new_client(index_name="test-index")

        assert client.project_id == "test-project"
        assert client.location == "us-central1"
        assert client.index_name == "test-index"

    def test_configure_index_endpoint(self):
        """Check that configuring an endpoint stores its name and domain."""
        client = self._new_client()

        client.configure_index_endpoint(
            name="test-endpoint",
            public_domain="test.domain.com",
        )

        assert client._endpoint_name == "test-endpoint"
        assert client._endpoint_domain == "test.domain.com"

    def test_configure_index_endpoint_validation(self):
        """Check that an empty name or domain raises ValueError."""
        client = self._new_client()

        # Each case pairs the expected error-message pattern with the
        # invalid arguments that should trigger it; order is significant.
        invalid_cases = (
            ("endpoint name", {"name": "", "public_domain": "test.com"}),
            ("endpoint domain", {"name": "test", "public_domain": ""}),
        )
        for pattern, bad_kwargs in invalid_cases:
            with pytest.raises(ValueError, match=pattern):
                client.configure_index_endpoint(**bad_kwargs)
|
|
|
|
|
class TestSourceNamespace:
    """Exercise the values exposed by the SourceNamespace enum."""

    def test_source_namespace_values(self):
        """Check each namespace member against its expected string value."""
        expected_labels = (
            (SourceNamespace.EDUCACION_FINANCIERA, "Educacion Financiera"),
            (SourceNamespace.PRODUCTOS_Y_SERVICIOS, "Productos y Servicios"),
            (
                SourceNamespace.FUNCIONALIDADES_APP_MOVIL,
                "Funcionalidades de la App Movil",
            ),
        )
        for member, label in expected_labels:
            assert member.value == label
|