code-shredding
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
89
tests/conftest.py
Normal file
89
tests/conftest.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Shared pytest fixtures for knowledge_pipeline tests."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from knowledge_pipeline.chunker.base_chunker import BaseChunker, Document
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gcs_client():
    """Fake GCS client with a pre-wired client -> bucket -> blob chain."""
    fake_blob = Mock()

    fake_bucket = Mock()
    fake_bucket.blob.return_value = fake_blob
    fake_bucket.list_blobs.return_value = []

    fake_client = Mock()
    fake_client.bucket.return_value = fake_bucket
    return fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chunker():
    """BaseChunker stub: 1000-char limit, one canned chunk from process_text."""
    canned_chunk = {
        "page_content": "Test chunk content",
        "metadata": {"id": "test_chunk"},
    }
    stub = Mock(spec=BaseChunker)
    stub.max_chunk_size = 1000
    stub.process_text.return_value = [canned_chunk]
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embedder():
    """Embedder stub whose sync call yields a single 3-dim embedding."""
    canned_result = Mock()
    canned_result.embeddings = [[0.1, 0.2, 0.3]]

    stub = Mock()
    stub.embed_documents_sync.return_value = canned_result
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_converter():
    """MarkItDown converter stub that always yields fixed markdown text."""
    conversion = Mock()
    conversion.text_content = "# Markdown Content\n\nTest content here."

    stub = Mock()
    stub.convert.return_value = conversion
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunks() -> list[Document]:
    """Three consecutive chunks of the same document, ids doc_1_0..doc_1_2."""
    texts = ("First chunk content", "Second chunk content", "Third chunk content")
    return [
        {"page_content": text, "metadata": {"id": f"doc_1_{position}"}}
        for position, text in enumerate(texts)
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_embeddings():
    """Three 5-dimensional embedding vectors."""
    vectors = (
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.6, 0.7, 0.8, 0.9, 1.0],
        [0.2, 0.3, 0.4, 0.5, 0.6],
    )
    return list(vectors)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_vectors():
    """Two vector records in the matching-engine upload format."""
    records = [("doc_1_0", [0.1, 0.2, 0.3]), ("doc_1_1", [0.4, 0.5, 0.6])]
    return [
        {
            "id": record_id,
            "embedding": embedding,
            # Fresh restricts list per record so tests can mutate safely.
            "restricts": [{"namespace": "source", "allow": ["documents"]}],
        }
        for record_id, embedding in records
    ]
|
||||
553
tests/test_pipeline.py
Normal file
553
tests/test_pipeline.py
Normal file
@@ -0,0 +1,553 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||
DistanceMeasureType,
|
||||
)
|
||||
|
||||
from knowledge_pipeline.chunker.base_chunker import BaseChunker
|
||||
from knowledge_pipeline.pipeline import (
|
||||
_parse_gcs_uri,
|
||||
build_vectors,
|
||||
create_vector_index,
|
||||
gather_pdfs,
|
||||
normalize_string,
|
||||
process_file,
|
||||
run_pipeline,
|
||||
split_into_chunks,
|
||||
upload_to_gcs,
|
||||
)
|
||||
|
||||
|
||||
class TestParseGcsUri:
    """Tests for the _parse_gcs_uri helper."""

    def test_basic_gcs_uri(self):
        assert _parse_gcs_uri("gs://my-bucket/path/to/file.pdf") == (
            "my-bucket",
            "path/to/file.pdf",
        )

    def test_gcs_uri_with_nested_path(self):
        assert _parse_gcs_uri("gs://test-bucket/deep/nested/path/file.txt") == (
            "test-bucket",
            "deep/nested/path/file.txt",
        )

    def test_gcs_uri_bucket_only(self):
        # A trailing slash with no object path yields an empty path.
        assert _parse_gcs_uri("gs://my-bucket/") == ("my-bucket", "")

    def test_gcs_uri_no_trailing_slash(self):
        # A bare bucket URI also yields an empty path.
        assert _parse_gcs_uri("gs://bucket-name") == ("bucket-name", "")
|
||||
|
||||
|
||||
class TestNormalizeString:
    """Tests for the normalize_string function."""

    def test_normalize_basic_string(self):
        assert normalize_string("Hello World") == "hello_world"

    def test_normalize_special_characters(self):
        # Punctuation other than '-' and '.' is stripped outright.
        assert normalize_string("File#Name@2024!.pdf") == "filename2024.pdf"

    def test_normalize_unicode(self):
        # Accented characters are reduced to their ASCII base form.
        assert normalize_string("Café Münchën") == "cafe_munchen"

    def test_normalize_multiple_spaces(self):
        assert normalize_string("Multiple Spaces Here") == "multiple_spaces_here"

    def test_normalize_with_hyphens_and_periods(self):
        # Hyphens and periods are legal filename characters and pass through.
        assert normalize_string("valid-filename.2024") == "valid-filename.2024"

    def test_normalize_empty_string(self):
        assert normalize_string("") == ""

    def test_normalize_only_special_chars(self):
        assert normalize_string("@#$%^&*()") == ""
|
||||
|
||||
|
||||
class TestGatherFiles:
    """Tests for the gather_pdfs function.

    NOTE: the class keeps its historical name so existing test selections
    (``-k TestGatherFiles``) still match, but the function under test is
    ``gather_pdfs`` — the previous docstring referred to a nonexistent
    ``gather_files``.
    """

    @staticmethod
    def _client_with_blobs(blob_names):
        """Return (client, bucket) Mocks whose bucket lists *blob_names*."""
        blobs = []
        for name in blob_names:
            blob = Mock()
            blob.name = name  # Mock(name=...) is reserved, so assign afterwards
            blobs.append(blob)

        bucket = Mock()
        bucket.list_blobs.return_value = blobs
        client = Mock()
        client.bucket.return_value = bucket
        return client, bucket

    def test_gather_files_finds_pdfs(self):
        client, _ = self._client_with_blobs(
            ["docs/file1.pdf", "docs/file2.pdf", "docs/readme.txt"]
        )

        files = gather_pdfs("gs://my-bucket/docs", client)

        # Only .pdf blobs are returned, expanded to full gs:// URIs.
        assert len(files) == 2
        assert "gs://my-bucket/docs/file1.pdf" in files
        assert "gs://my-bucket/docs/file2.pdf" in files
        assert "gs://my-bucket/docs/readme.txt" not in files

    def test_gather_files_no_pdfs(self):
        client, _ = self._client_with_blobs(["docs/readme.txt"])

        files = gather_pdfs("gs://my-bucket/docs", client)

        assert len(files) == 0

    def test_gather_files_empty_bucket(self):
        client, _ = self._client_with_blobs([])

        files = gather_pdfs("gs://my-bucket/docs", client)

        assert len(files) == 0

    def test_gather_files_correct_prefix(self):
        client, bucket = self._client_with_blobs([])

        gather_pdfs("gs://my-bucket/docs/subfolder", client)

        # The bucket name and object prefix must be parsed out of the URI.
        client.bucket.assert_called_once_with("my-bucket")
        bucket.list_blobs.assert_called_once_with(prefix="docs/subfolder")
|
||||
|
||||
|
||||
class TestSplitIntoChunks:
    """Tests for the split_into_chunks function."""

    @staticmethod
    def _chunker(max_size):
        """Build a BaseChunker stub with the given max chunk size."""
        stub = Mock(spec=BaseChunker)
        stub.max_chunk_size = max_size
        return stub

    def test_split_small_text_single_chunk(self):
        chunker = self._chunker(1000)

        chunks = split_into_chunks("Small text", "test_file", chunker)

        assert len(chunks) == 1
        assert chunks[0]["page_content"] == "Small text"
        assert chunks[0]["metadata"]["id"] == "test_file"
        # Text below the limit must bypass the chunker entirely.
        chunker.process_text.assert_not_called()

    def test_split_large_text_multiple_chunks(self):
        chunker = self._chunker(10)
        source_text = "This is a very long text that needs to be split into chunks"
        # Canned chunker output for text exceeding max_chunk_size.
        chunker.process_text.return_value = [
            {"page_content": piece, "metadata": {}}
            for piece in ("This is a very", "long text that", "needs to be split")
        ]

        chunks = split_into_chunks(source_text, "test_file", chunker)

        assert len(chunks) == 3
        # Each chunk id gets a positional suffix appended to the file id.
        assert [chunk["metadata"]["id"] for chunk in chunks] == [
            "test_file_0",
            "test_file_1",
            "test_file_2",
        ]
        chunker.process_text.assert_called_once_with(source_text)

    def test_split_exactly_max_size(self):
        chunker = self._chunker(10)
        boundary_text = "0123456789"  # length equals max_chunk_size exactly

        chunks = split_into_chunks(boundary_text, "test_file", chunker)

        # At exactly the limit the text still fits in one chunk.
        assert len(chunks) == 1
        assert chunks[0]["page_content"] == boundary_text
        chunker.process_text.assert_not_called()
|
||||
|
||||
|
||||
class TestUploadToGcs:
    """Tests for the upload_to_gcs function."""

    @staticmethod
    def _wired_client():
        """Return (client, bucket) Mocks with the bucket/blob chain connected."""
        bucket = Mock()
        bucket.blob.return_value = Mock()
        client = Mock()
        client.bucket.return_value = bucket
        return client, bucket

    def test_upload_single_chunk_and_vectors(self):
        client, bucket = self._wired_client()
        chunks = [{"page_content": "Test content", "metadata": {"id": "chunk_1"}}]
        vectors = [{"id": "chunk_1", "embedding": [0.1, 0.2]}]

        upload_to_gcs(
            chunks,
            vectors,
            "gs://my-bucket/contents",
            "gs://my-bucket/vectors/vectors.jsonl",
            client,
        )

        requested_names = [call[0][0] for call in bucket.blob.call_args_list]
        # One markdown object per chunk plus the single vectors file.
        assert "contents/chunk_1.md" in requested_names
        assert "vectors/vectors.jsonl" in requested_names

    def test_upload_multiple_chunks(self):
        client, bucket = self._wired_client()
        chunks = [
            {"page_content": f"Content {n}", "metadata": {"id": f"chunk_{n}"}}
            for n in (1, 2, 3)
        ]
        vectors = [{"id": "chunk_1", "embedding": [0.1]}]

        upload_to_gcs(
            chunks,
            vectors,
            "gs://my-bucket/contents",
            "gs://my-bucket/vectors/vectors.jsonl",
            client,
        )

        # Three chunk objects plus one vectors object, in order.
        assert bucket.blob.call_count == 4
        assert [call[0][0] for call in bucket.blob.call_args_list] == [
            "contents/chunk_1.md",
            "contents/chunk_2.md",
            "contents/chunk_3.md",
            "vectors/vectors.jsonl",
        ]
|
||||
|
||||
|
||||
class TestBuildVectors:
    """Tests for the build_vectors function."""

    def test_build_vectors_basic(self):
        chunks = [
            {"metadata": {"id": "doc_1"}, "page_content": "content 1"},
            {"metadata": {"id": "doc_2"}, "page_content": "content 2"},
        ]
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        vectors = build_vectors(chunks, embeddings, "documents/reports")

        assert len(vectors) == 2
        first, second = vectors
        assert first["id"] == "doc_1"
        assert first["embedding"] == [0.1, 0.2, 0.3]
        # Only the top-level folder feeds the "source" restrict token.
        assert first["restricts"] == [
            {"namespace": "source", "allow": ["documents"]}
        ]
        assert second["id"] == "doc_2"
        assert second["embedding"] == [0.4, 0.5, 0.6]

    def test_build_vectors_empty_source(self):
        chunks = [{"metadata": {"id": "doc_1"}, "page_content": "content"}]

        vectors = build_vectors(chunks, [[0.1, 0.2]], "")

        assert len(vectors) == 1
        # An empty source folder yields an empty allow token, not an error.
        assert vectors[0]["restricts"] == [{"namespace": "source", "allow": [""]}]

    def test_build_vectors_nested_path(self):
        chunks = [{"metadata": {"id": "doc_1"}, "page_content": "content"}]

        vectors = build_vectors(chunks, [[0.1]], "a/b/c/d")

        # Deeply nested paths still restrict on the first segment only.
        assert vectors[0]["restricts"] == [{"namespace": "source", "allow": ["a"]}]
|
||||
|
||||
|
||||
class TestCreateVectorIndex:
    """Tests for create_vector_index function."""

    # patch decorators apply bottom-up: the first test argument is the
    # MatchingEngineIndex patch, the second the endpoint patch.
    @patch("knowledge_pipeline.pipeline.aiplatform.MatchingEngineIndexEndpoint")
    @patch("knowledge_pipeline.pipeline.aiplatform.MatchingEngineIndex")
    def test_create_vector_index(self, mock_index_class, mock_endpoint_class):
        mock_index = Mock()
        mock_endpoint = Mock()

        mock_index_class.create_tree_ah_index.return_value = mock_index
        mock_endpoint_class.create.return_value = mock_endpoint

        create_vector_index(
            index_name="test-index",
            index_vectors_dir="gs://bucket/vectors",
            index_dimensions=768,
            index_distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
            index_deployment="test_index_deployed",
            index_machine_type="e2-standard-16",
        )

        # Index creation must forward the caller's settings plus the
        # pipeline's fixed tuning values (neighbors count, leaf sizing).
        mock_index_class.create_tree_ah_index.assert_called_once_with(
            display_name="test-index",
            contents_delta_uri="gs://bucket/vectors",
            dimensions=768,
            approximate_neighbors_count=150,
            distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )

        # The endpoint display name derives from the index name ("-endpoint").
        mock_endpoint_class.create.assert_called_once_with(
            display_name="test-index-endpoint",
            public_endpoint_enabled=True,
        )

        # Deployment is kicked off asynchronously (sync=False).
        mock_endpoint.deploy_index.assert_called_once_with(
            index=mock_index,
            deployed_index_id="test_index_deployed",
            machine_type="e2-standard-16",
            sync=False,
        )
|
||||
|
||||
|
||||
class TestProcessFile:
    """Tests for process_file function."""

    def test_process_file_success(self):
        # Happy path: download -> convert -> chunk -> embed, all mocked.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Mock dependencies
            mock_client = Mock()
            mock_bucket = Mock()
            mock_blob = Mock()
            mock_client.bucket.return_value = mock_bucket
            mock_bucket.blob.return_value = mock_blob

            mock_converter = Mock()
            mock_result = Mock()
            mock_result.text_content = "Converted markdown content"
            mock_converter.convert.return_value = mock_result

            mock_embedder = Mock()
            mock_embeddings_result = Mock()
            mock_embeddings_result.embeddings = [[0.1, 0.2, 0.3]]
            mock_embedder.embed_documents_sync.return_value = mock_embeddings_result

            # Large max_chunk_size keeps the converted text in one chunk.
            mock_chunker = Mock(spec=BaseChunker)
            mock_chunker.max_chunk_size = 1000

            file_uri = "gs://my-bucket/docs/test-file.pdf"

            chunks, vectors = process_file(
                file_uri,
                temp_path,
                mock_client,
                mock_converter,
                mock_embedder,
                mock_chunker,
            )

            # Verify download was called
            mock_client.bucket.assert_called_with("my-bucket")
            mock_bucket.blob.assert_called_with("docs/test-file.pdf")
            assert mock_blob.download_to_filename.called

            # Verify converter was called
            assert mock_converter.convert.called

            # Verify embedder was called
            mock_embedder.embed_documents_sync.assert_called_once()

            # Verify results
            assert len(chunks) == 1
            assert chunks[0]["page_content"] == "Converted markdown content"
            assert len(vectors) == 1
            assert vectors[0]["embedding"] == [0.1, 0.2, 0.3]

    def test_process_file_cleans_up_temp_file(self):
        # Failure path: a converter error must still remove the downloaded
        # temp file (process_file presumably uses try/finally — this is the
        # behavior pinned here).
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            mock_client = Mock()
            mock_bucket = Mock()
            mock_blob = Mock()
            mock_client.bucket.return_value = mock_bucket
            mock_bucket.blob.return_value = mock_blob

            mock_converter = Mock()
            mock_converter.convert.side_effect = Exception("Conversion failed")

            mock_embedder = Mock()
            mock_chunker = Mock(spec=BaseChunker)

            file_uri = "gs://my-bucket/docs/test.pdf"

            # This should raise an exception but still clean up
            with pytest.raises(Exception, match="Conversion failed"):
                process_file(
                    file_uri,
                    temp_path,
                    mock_client,
                    mock_converter,
                    mock_embedder,
                    mock_chunker,
                )

            # File should be cleaned up even after exception
            temp_file = temp_path / "test.pdf"
            assert not temp_file.exists()
|
||||
|
||||
|
||||
class TestRunPipeline:
    """Tests for run_pipeline function."""

    # patch decorators apply bottom-up: gather_pdfs is the first test
    # argument, create_vector_index the last.
    @patch("knowledge_pipeline.pipeline.create_vector_index")
    @patch("knowledge_pipeline.pipeline.upload_to_gcs")
    @patch("knowledge_pipeline.pipeline.process_file")
    @patch("knowledge_pipeline.pipeline.gather_pdfs")
    def test_run_pipeline_integration(
        self,
        mock_gather,
        mock_process,
        mock_upload,
        mock_create_index,
    ):
        # Mock settings
        mock_settings = Mock()
        mock_settings.index_origin = "gs://bucket/input"
        mock_settings.index_contents_dir = "gs://bucket/contents"
        mock_settings.index_vectors_jsonl_path = "gs://bucket/vectors/vectors.jsonl"
        mock_settings.index_name = "test-index"
        mock_settings.index_vectors_dir = "gs://bucket/vectors"
        mock_settings.index_dimensions = 768
        mock_settings.index_distance_measure_type = (
            DistanceMeasureType.DOT_PRODUCT_DISTANCE
        )
        mock_settings.index_deployment = "test_index_deployed"
        mock_settings.index_machine_type = "e2-standard-16"

        # GCS client wired through settings, as run_pipeline expects.
        mock_gcs_client = Mock()
        mock_bucket = Mock()
        mock_blob = Mock()
        mock_gcs_client.bucket.return_value = mock_bucket
        mock_bucket.blob.return_value = mock_blob
        mock_settings.gcs_client = mock_gcs_client

        mock_settings.converter = Mock()
        mock_settings.embedder = Mock()
        mock_settings.chunker = Mock()

        # Mock gather_files to return test files
        mock_gather.return_value = ["gs://bucket/input/file1.pdf"]

        # Mock process_file to return chunks and vectors
        mock_chunks = [{"page_content": "content", "metadata": {"id": "chunk_1"}}]
        mock_vectors = [
            {
                "id": "chunk_1",
                "embedding": [0.1, 0.2],
                "restricts": [{"namespace": "source", "allow": ["input"]}],
            }
        ]
        mock_process.return_value = (mock_chunks, mock_vectors)

        run_pipeline(mock_settings)

        # Verify all steps were called
        mock_gather.assert_called_once_with("gs://bucket/input", mock_gcs_client)
        mock_process.assert_called_once()
        mock_upload.assert_called_once_with(
            mock_chunks,
            mock_vectors,
            "gs://bucket/contents",
            "gs://bucket/vectors/vectors.jsonl",
            mock_gcs_client,
        )
        mock_create_index.assert_called_once()

    @patch("knowledge_pipeline.pipeline.create_vector_index")
    @patch("knowledge_pipeline.pipeline.upload_to_gcs")
    @patch("knowledge_pipeline.pipeline.process_file")
    @patch("knowledge_pipeline.pipeline.gather_pdfs")
    def test_run_pipeline_multiple_files(
        self,
        mock_gather,
        mock_process,
        mock_upload,
        mock_create_index,
    ):
        # Same settings wiring as the single-file test above.
        mock_settings = Mock()
        mock_settings.index_origin = "gs://bucket/input"
        mock_settings.index_contents_dir = "gs://bucket/contents"
        mock_settings.index_vectors_jsonl_path = "gs://bucket/vectors/vectors.jsonl"
        mock_settings.index_name = "test-index"
        mock_settings.index_vectors_dir = "gs://bucket/vectors"
        mock_settings.index_dimensions = 768
        mock_settings.index_distance_measure_type = (
            DistanceMeasureType.DOT_PRODUCT_DISTANCE
        )
        mock_settings.index_deployment = "test_index_deployed"
        mock_settings.index_machine_type = "e2-standard-16"

        mock_gcs_client = Mock()
        mock_bucket = Mock()
        mock_blob = Mock()
        mock_gcs_client.bucket.return_value = mock_bucket
        mock_bucket.blob.return_value = mock_blob
        mock_settings.gcs_client = mock_gcs_client

        mock_settings.converter = Mock()
        mock_settings.embedder = Mock()
        mock_settings.chunker = Mock()

        # Return multiple files
        mock_gather.return_value = [
            "gs://bucket/input/file1.pdf",
            "gs://bucket/input/file2.pdf",
        ]

        mock_process.return_value = (
            [{"page_content": "content", "metadata": {"id": "chunk_1"}}],
            [{"id": "chunk_1", "embedding": [0.1], "restricts": []}],
        )

        run_pipeline(mock_settings)

        # Verify process_file was called for each file
        assert mock_process.call_count == 2
        # Upload is called once with all accumulated chunks and vectors
        mock_upload.assert_called_once()
|
||||
Reference in New Issue
Block a user