"""Google Cloud Storage file storage implementation."""

import asyncio
import io
import logging
from typing import BinaryIO

import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage

from rag_eval.file_storage.base import BaseFileStorage

logger = logging.getLogger(__name__)

HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500


class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage.

    Downloads are memoized in an in-process byte cache keyed by blob
    name; uploads and deletes invalidate the matching entries.
    NOTE(review): the cache is unbounded — consider eviction if this
    instance is long-lived and reads many distinct blobs.
    """

    def __init__(self, bucket: str) -> None:  # noqa: D107
        self.bucket_name = bucket
        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # Blob name -> raw bytes; shared by the sync and async read paths.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.
        """
        blob = self.bucket_client.blob(destination_blob_name)
        # if_generation_match=0 makes this a create-only upload: GCS
        # rejects it with 412 Precondition Failed if the object exists.
        # NOTE(review): confirm callers never overwrite — otherwise this
        # precondition should be dropped or made configurable.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached copy so later reads refetch fresh bytes.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, recursively lists all files.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.
        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name,
            prefix=path,
        )
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.
        """
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        file_stream = io.BytesIO(self._cache[file_name])
        # Expose the blob name as .name so consumers that expect a
        # file-like object with a name (e.g. format sniffers) work.
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return a live shared aiohttp session, creating one if needed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
            # BUGFIX: any existing Storage client holds the old, closed
            # session; force it to be rebuilt against the new one.
            self._aio_storage = None
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        """Return the async Storage client bound to the current session."""
        # Fetch the session FIRST: it may invalidate _aio_storage when it
        # has to replace a closed session (see _get_aio_session).
        session = self._get_aio_session()
        if self._aio_storage is None:
            self._aio_storage = Storage(session=session)
        return self._aio_storage

    async def aclose(self) -> None:
        """Close the shared aiohttp session and drop the async client.

        Safe to call multiple times; a later async download transparently
        creates a fresh session.
        """
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()
        self._aio_session = None
        self._aio_storage = None

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Retries (with exponential backoff: 0.5s, 1s, 2s, ...) on timeouts,
        HTTP 429, and HTTP 5xx; any other HTTP error is raised immediately.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.
        """
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None
        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            # BUGFIX: on Python 3.10 asyncio.TimeoutError is NOT the
            # builtin TimeoutError, and aiohttp raises the asyncio one;
            # on 3.11+ they alias, so the duplicate is harmless.
            except (TimeoutError, asyncio.TimeoutError) as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s "
                        "(attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    # Non-transient client error (404, 403, ...): do not
                    # retry, surface it to the caller unchanged.
                    raise
            else:
                return file_stream
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Args:
            path: Prefix of blobs to delete.
        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name,
            prefix=path,
        )
        for blob in blobs:
            blob.delete()
            # Keep the local cache consistent with the bucket.
            self._cache.pop(blob.name, None)