252 lines
8.5 KiB
Python
252 lines
8.5 KiB
Python
"""Tests for RAG services."""
|
|
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from capa_de_integracion.services.rag import (
|
|
EchoRAGService,
|
|
HTTPRAGService,
|
|
RAGServiceBase,
|
|
)
|
|
from capa_de_integracion.services.rag.base import Message, RAGRequest, RAGResponse
|
|
|
|
|
|
class TestEchoRAGService:
    """Behavioral checks for the echo-based RAG stand-in."""

    @pytest.mark.asyncio
    async def test_echo_default_prefix(self):
        """Without a prefix argument the reply starts with ``Echo: ``."""
        echo = EchoRAGService()

        reply = await echo.query([{"role": "user", "content": "Hello"}])

        assert reply == "Echo: Hello"

    @pytest.mark.asyncio
    async def test_echo_custom_prefix(self):
        """A caller-supplied prefix replaces the default one."""
        echo = EchoRAGService(prefix="Bot: ")

        reply = await echo.query([{"role": "user", "content": "Hello"}])

        assert reply == "Bot: Hello"

    @pytest.mark.asyncio
    async def test_echo_multiple_messages(self):
        """When several user turns exist, only the most recent one is echoed."""
        conversation = [
            {"role": "user", "content": "First message"},
            {"role": "assistant", "content": "Response"},
            {"role": "user", "content": "Last message"},
        ]

        reply = await EchoRAGService().query(conversation)

        assert reply == "Echo: Last message"

    @pytest.mark.asyncio
    async def test_echo_mixed_roles(self):
        """System and assistant turns are skipped when looking for user text."""
        conversation = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "User question"},
            {"role": "assistant", "content": "Assistant response"},
        ]

        reply = await EchoRAGService().query(conversation)

        assert reply == "Echo: User question"

    @pytest.mark.asyncio
    async def test_echo_no_messages_error(self):
        """An empty conversation is rejected with ValueError."""
        echo = EchoRAGService()

        with pytest.raises(ValueError, match="No messages provided"):
            await echo.query([])

    @pytest.mark.asyncio
    async def test_echo_no_user_message_error(self):
        """A conversation with no user turn is rejected with ValueError."""
        echo = EchoRAGService()
        conversation = [
            {"role": "system", "content": "System"},
            {"role": "assistant", "content": "Assistant"},
        ]

        with pytest.raises(ValueError, match="No user message found"):
            await echo.query(conversation)

    @pytest.mark.asyncio
    async def test_echo_close(self):
        """close() completes without raising."""
        await EchoRAGService().close()

    @pytest.mark.asyncio
    async def test_echo_context_manager(self):
        """The service can be queried inside ``async with``."""
        async with EchoRAGService() as echo:
            reply = await echo.query([{"role": "user", "content": "Test"}])

        assert reply == "Echo: Test"
|
|
|
|
|
|
class TestHTTPRAGService:
    """Tests for HTTPRAGService with ``httpx.AsyncClient`` patched out."""

    # Endpoint handed to every service instance under test.
    _ENDPOINT = "http://test.example.com/rag"

    @staticmethod
    def _json_response(payload):
        """Return a fake HTTP response whose ``json()`` yields *payload*."""
        fake = Mock()
        fake.json.return_value = payload
        fake.raise_for_status = Mock()
        return fake

    @staticmethod
    def _install_client(client_cls, **attrs):
        """Make the patched client class return an AsyncMock carrying *attrs*."""
        client = AsyncMock()
        for name, value in attrs.items():
            setattr(client, name, value)
        client_cls.return_value = client
        return client

    @pytest.mark.asyncio
    async def test_http_successful_query(self):
        """The ``response`` field is returned and messages are posted as JSON."""
        ok = self._json_response({"response": "AI response"})

        with patch("httpx.AsyncClient") as client_cls:
            client = self._install_client(
                client_cls, post=AsyncMock(return_value=ok),
            )
            service = HTTPRAGService(
                endpoint_url=self._ENDPOINT,
                max_connections=10,
                max_keepalive_connections=5,
                timeout=15.0,
            )

            reply = await service.query([{"role": "user", "content": "Hello"}])

            assert reply == "AI response"
            client.post.assert_called_once()
            first = client.post.call_args.kwargs["json"]["messages"][0]
            assert first["role"] == "user"
            assert first["content"] == "Hello"

    @pytest.mark.asyncio
    async def test_http_status_error(self):
        """An HTTPStatusError from the client reaches the caller."""
        failing = Mock()
        failing.status_code = 500
        failing.text = "Internal Server Error"

        with patch("httpx.AsyncClient") as client_cls:
            boom = httpx.HTTPStatusError(
                "Error", request=Mock(), response=failing,
            )
            self._install_client(client_cls, post=AsyncMock(side_effect=boom))
            service = HTTPRAGService(endpoint_url=self._ENDPOINT)

            with pytest.raises(httpx.HTTPStatusError):
                await service.query([{"role": "user", "content": "Hello"}])

    @pytest.mark.asyncio
    async def test_http_request_error(self):
        """A transport-level RequestError reaches the caller."""
        with patch("httpx.AsyncClient") as client_cls:
            boom = httpx.RequestError("Connection failed", request=Mock())
            self._install_client(client_cls, post=AsyncMock(side_effect=boom))
            service = HTTPRAGService(endpoint_url=self._ENDPOINT)

            with pytest.raises(httpx.RequestError):
                await service.query([{"role": "user", "content": "Hello"}])

    @pytest.mark.asyncio
    async def test_http_close(self):
        """close() closes the underlying client exactly once."""
        with patch("httpx.AsyncClient") as client_cls:
            client = self._install_client(client_cls, aclose=AsyncMock())

            await HTTPRAGService(endpoint_url=self._ENDPOINT).close()

            client.aclose.assert_called_once()

    @pytest.mark.asyncio
    async def test_http_context_manager(self):
        """``async with`` allows querying and closes the client on exit."""
        ok = self._json_response({"response": "AI response"})

        with patch("httpx.AsyncClient") as client_cls:
            client = self._install_client(
                client_cls,
                post=AsyncMock(return_value=ok),
                aclose=AsyncMock(),
            )

            async with HTTPRAGService(endpoint_url=self._ENDPOINT) as service:
                reply = await service.query([{"role": "user", "content": "Test"}])
                assert reply == "AI response"

            client.aclose.assert_called_once()
|
|
|
|
|
|
class TestRAGModels:
    """Construction and field-access checks for the RAG data models."""

    def test_message_model(self):
        """A Message keeps the role and content it was built with."""
        message = Message(role="user", content="Hello")

        assert (message.role, message.content) == ("user", "Hello")

    def test_rag_request_model(self):
        """A RAGRequest holds every message given, with the first one in place."""
        request = RAGRequest(
            messages=[
                Message(role="user", content="Hello"),
                Message(role="assistant", content="Hi"),
            ],
        )

        assert len(request.messages) == 2
        assert request.messages[0].role == "user"

    def test_rag_response_model(self):
        """A RAGResponse exposes the text passed as ``response``."""
        assert RAGResponse(response="AI response").response == "AI response"
|
|
|
|
|
|
class TestRAGServiceBase:
    """Tests for the async context-manager behavior RAGServiceBase provides."""

    @pytest.mark.asyncio
    async def test_base_context_manager_calls_close(self):
        """Leaving ``async with`` calls close(); entering it does not."""

        class MockRAGService(RAGServiceBase):
            """Minimal concrete subclass that records whether close() ran."""

            def __init__(self):
                self.closed = False

            async def query(self, messages):
                return "test"

            async def close(self):
                self.closed = True

        service = MockRAGService()

        async with service:
            # Without this check, an __aenter__ that closed the service
            # would still satisfy the post-block assertion below.
            assert service.closed is False

        assert service.closed is True
|