agent/src/rag_eval/file_storage/google_cloud.py

"""Google Cloud Storage file storage implementation."""
import asyncio
import io
import logging
from typing import BinaryIO
import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage
from rag_eval.file_storage.base import BaseFileStorage
logger = logging.getLogger(__name__)

HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500


class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage."""

    def __init__(self, bucket: str) -> None:  # noqa: D107
self.bucket_name = bucket
self.storage_client = storage.Client()
self.bucket_client = self.storage_client.bucket(self.bucket_name)
self._aio_session: aiohttp.ClientSession | None = None
self._aio_storage: Storage | None = None
self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        The upload passes ``if_generation_match=0``, so it fails with a
        412 Precondition Failed error if the blob already exists.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.
        """
blob = self.bucket_client.blob(destination_blob_name)
blob.upload_from_filename(
file_path,
content_type=content_type,
if_generation_match=0,
)
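        # Invalidate any cached bytes for this blob.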
self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, recursively lists all files.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.
        """
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using the in-memory cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.
        """
if file_name not in self._cache:
blob = self.bucket_client.blob(file_name)
self._cache[file_name] = blob.download_as_bytes()
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return a shared aiohttp session, recreating it if closed."""
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session

    def _get_aio_storage(self) -> Storage:
        """Return the async Storage client, creating it on first use."""
if self._aio_storage is None:
self._aio_storage = Storage(
session=self._get_aio_session(),
)
return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously, retrying transient errors.

        Timeouts, HTTP 429, and HTTP 5xx responses are retried with
        exponential backoff; any other HTTP error is re-raised immediately.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of download attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all attempts fail, chained to the last error.
        """
if file_name in self._cache:
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
storage_client = self._get_aio_storage()
last_exception: Exception | None = None
for attempt in range(max_retries):
try:
self._cache[file_name] = await storage_client.download(
self.bucket_name, file_name,
)
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
except TimeoutError as exc:
last_exception = exc
logger.warning(
"Timeout downloading gs://%s/%s (attempt %d/%d)",
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
except aiohttp.ClientResponseError as exc:
last_exception = exc
if (
exc.status == HTTP_TOO_MANY_REQUESTS
or exc.status >= HTTP_SERVER_ERROR
):
logger.warning(
"HTTP %d downloading gs://%s/%s "
"(attempt %d/%d)",
exc.status,
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
else:
raise
else:
return file_stream
if attempt < max_retries - 1:
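                # Exponential backoff: 0.5s, 1s, 2s, ... before the next attempt.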
delay = 0.5 * (2**attempt)
await asyncio.sleep(delay)
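        # All attempts failed; surface a TimeoutError chained to the last cause.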
msg = (
f"Failed to download gs://{self.bucket_name}/{file_name} "
f"after {max_retries} attempts"
)
raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Args:
            path: Prefix of blobs to delete.
        """
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
for blob in blobs:
blob.delete()
self._cache.pop(blob.name, None)
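
    # A minimal cleanup sketch (an addition, not part of the original class):
    # the session created by _get_aio_session is otherwise never closed, and
    # aiohttp warns about unclosed sessions at interpreter shutdown.
    async def aclose(self) -> None:
        """Close the shared aiohttp session, if one was created."""
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()
        # Drop the Storage wrapper so a later call rebuilds it with a live session.
        self._aio_storage = None


# A minimal usage sketch, assuming a bucket "example-bucket" and a local file
# "report.pdf" (both hypothetical): a sync upload, then an async download with
# retry, then cleanup.
if __name__ == "__main__":
    storage_backend = GoogleCloudFileStorage(bucket="example-bucket")
    storage_backend.upload_file(
        "report.pdf",
        "reports/report.pdf",
        content_type="application/pdf",
    )

    async def _demo() -> None:
        stream = await storage_backend.async_get_file_stream("reports/report.pdf")
        print(stream.name, len(stream.read()))
        await storage_backend.aclose()

    asyncio.run(_demo())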