Add more tests
This commit is contained in:
436
tests/test_validation_services.py
Normal file
436
tests/test_validation_services.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""Tests for validation service functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from knowledge_search_mcp.services.validation import (
|
||||
validate_genai_access,
|
||||
validate_gcs_access,
|
||||
validate_vector_search_access,
|
||||
)
|
||||
from knowledge_search_mcp.config import Settings
|
||||
|
||||
|
||||
class TestValidateGenAIAccess:
|
||||
"""Tests for validate_genai_access function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create mock settings."""
|
||||
settings = MagicMock(spec=Settings)
|
||||
settings.embedding_model = "models/text-embedding-004"
|
||||
settings.project_id = "test-project"
|
||||
settings.location = "us-central1"
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_genai_client(self):
|
||||
"""Create a mock genai client."""
|
||||
client = MagicMock()
|
||||
client.aio = MagicMock()
|
||||
client.aio.models = MagicMock()
|
||||
return client
|
||||
|
||||
async def test_successful_validation(self, mock_genai_client, mock_settings):
|
||||
"""Test successful GenAI access validation."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = [0.1] * 768 # Typical embedding dimension
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
# Assert
|
||||
assert error is None
|
||||
mock_genai_client.aio.models.embed_content.assert_called_once()
|
||||
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
|
||||
assert call_kwargs["model"] == "models/text-embedding-004"
|
||||
assert call_kwargs["contents"] == "test"
|
||||
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
|
||||
|
||||
async def test_empty_response(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of empty response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = []
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error == "Embedding validation returned empty response"
|
||||
|
||||
async def test_none_response(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of None response."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=None)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error == "Embedding validation returned empty response"
|
||||
|
||||
async def test_api_permission_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of permission denied error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=PermissionError("Permission denied for GenAI API")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Permission denied for GenAI API" in error
|
||||
|
||||
async def test_api_quota_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of quota exceeded error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("Quota exceeded")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Quota exceeded" in error
|
||||
|
||||
async def test_network_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of network error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ConnectionError("Network unreachable")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Network unreachable" in error
|
||||
|
||||
async def test_invalid_model_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of invalid model error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ValueError("Invalid model name")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Invalid model name" in error
|
||||
|
||||
async def test_embeddings_with_zero_values(self, mock_genai_client, mock_settings):
|
||||
"""Test validation with empty embedding values."""
|
||||
mock_response = MagicMock()
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = []
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
# Should succeed even with empty values, as long as embeddings exist
|
||||
assert error is None
|
||||
|
||||
|
||||
class TestValidateGCSAccess:
|
||||
"""Tests for validate_gcs_access function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create mock settings."""
|
||||
settings = MagicMock(spec=Settings)
|
||||
settings.bucket = "test-bucket"
|
||||
settings.project_id = "test-project"
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_search(self):
|
||||
"""Create a mock vector search client."""
|
||||
vs = MagicMock()
|
||||
vs.storage = MagicMock()
|
||||
return vs
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock aiohttp session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.text = AsyncMock(return_value='{"items": []}')
|
||||
return response
|
||||
|
||||
async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test successful GCS bucket access validation."""
|
||||
# Setup mocks
|
||||
mock_response.status = 200
|
||||
mock_response.ok = True
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is None
|
||||
mock_session.get.assert_called_once()
|
||||
call_args = mock_session.get.call_args
|
||||
assert "test-bucket" in call_args[0][0]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer fake-access-token"
|
||||
|
||||
async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 403 access denied."""
|
||||
mock_response.status = 403
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Access denied to bucket 'test-bucket'" in error
|
||||
assert "permissions" in error.lower()
|
||||
|
||||
async def test_bucket_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 404 bucket not found."""
|
||||
mock_response.status = 404
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Bucket 'test-bucket' not found" in error
|
||||
assert "bucket name" in error.lower()
|
||||
|
||||
async def test_server_error_500(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 500 server error."""
|
||||
mock_response.status = 500
|
||||
mock_response.ok = False
|
||||
mock_response.text = AsyncMock(return_value='{"error": "Internal server error"}')
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Failed to access bucket 'test-bucket': 500" in error
|
||||
|
||||
async def test_token_acquisition_error(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of token acquisition error."""
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(side_effect=Exception("Failed to get access token"))
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GCS:" in error
|
||||
assert "Failed to get access token" in error
|
||||
|
||||
async def test_network_error(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of network error."""
|
||||
mock_session.get = MagicMock(side_effect=ConnectionError("Network unreachable"))
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GCS:" in error
|
||||
assert "Network unreachable" in error
|
||||
|
||||
|
||||
class TestValidateVectorSearchAccess:
|
||||
"""Tests for validate_vector_search_access function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create mock settings."""
|
||||
settings = MagicMock(spec=Settings)
|
||||
settings.endpoint_name = "projects/test/locations/us-central1/indexEndpoints/test-endpoint"
|
||||
settings.location = "us-central1"
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_search(self):
|
||||
"""Create a mock vector search client."""
|
||||
vs = MagicMock()
|
||||
vs._async_get_auth_headers = AsyncMock(return_value={"Authorization": "Bearer fake-token"})
|
||||
return vs
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock aiohttp session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.text = AsyncMock(return_value='{"name": "test-endpoint"}')
|
||||
return response
|
||||
|
||||
async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test successful vector search endpoint validation."""
|
||||
mock_response.status = 200
|
||||
mock_response.ok = True
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is None
|
||||
mock_vector_search._async_get_auth_headers.assert_called_once()
|
||||
mock_session.get.assert_called_once()
|
||||
call_args = mock_session.get.call_args
|
||||
assert "us-central1-aiplatform.googleapis.com" in call_args[0][0]
|
||||
assert "test-endpoint" in call_args[0][0]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer fake-token"
|
||||
|
||||
async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 403 access denied."""
|
||||
mock_response.status = 403
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Access denied to endpoint" in error
|
||||
assert "test-endpoint" in error
|
||||
assert "permissions" in error.lower()
|
||||
|
||||
async def test_endpoint_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 404 endpoint not found."""
|
||||
mock_response.status = 404
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "not found" in error.lower()
|
||||
assert "test-endpoint" in error
|
||||
|
||||
async def test_server_error_503(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 503 service unavailable."""
|
||||
mock_response.status = 503
|
||||
mock_response.ok = False
|
||||
mock_response.text = AsyncMock(return_value='{"error": "Service unavailable"}')
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Failed to access endpoint" in error
|
||||
assert "503" in error
|
||||
|
||||
async def test_auth_header_error(self, mock_vector_search, mock_settings):
|
||||
"""Test handling of authentication header error."""
|
||||
mock_vector_search._async_get_auth_headers = AsyncMock(
|
||||
side_effect=Exception("Failed to get auth headers")
|
||||
)
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Vector Search:" in error
|
||||
assert "Failed to get auth headers" in error
|
||||
|
||||
async def test_network_timeout(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of network timeout."""
|
||||
mock_session.get = MagicMock(side_effect=TimeoutError("Request timed out"))
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Vector Search:" in error
|
||||
assert "Request timed out" in error
|
||||
|
||||
async def test_connection_error(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of connection error."""
|
||||
mock_session.get = MagicMock(side_effect=ConnectionError("Connection refused"))
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Vector Search:" in error
|
||||
assert "Connection refused" in error
|
||||
|
||||
async def test_endpoint_url_construction(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test that endpoint URL is constructed correctly."""
|
||||
mock_response.status = 200
|
||||
mock_response.ok = True
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
# Custom location
|
||||
mock_settings.location = "europe-west1"
|
||||
mock_settings.endpoint_name = "projects/my-project/locations/europe-west1/indexEndpoints/my-endpoint"
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is None
|
||||
call_args = mock_session.get.call_args
|
||||
url = call_args[0][0]
|
||||
assert "europe-west1-aiplatform.googleapis.com" in url
|
||||
assert "my-endpoint" in url
|
||||
Reference in New Issue
Block a user