"""Google Cloud Storage file storage implementation."""

import asyncio
import io
import logging
from typing import BinaryIO

import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage

from .base import BaseFileStorage

logger = logging.getLogger(__name__)

# HTTP statuses treated as transient by ``async_get_file_stream``:
# exactly 429, or anything at/above 500.
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500


class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage.

    Downloaded blob contents are kept in an in-process dictionary keyed by
    blob name. The cache is unbounded; entries are only evicted when the same
    instance uploads to or deletes that blob name.
    """

    def __init__(self, bucket: str) -> None:
        """Create the synchronous client and empty async/cache state.

        Args:
            bucket: Name of the bucket every operation targets.
        """
        self.bucket_name = bucket
        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # The async session/client are created lazily on first async use.
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # Blob name -> full blob contents.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.
        """
        blob = self.bucket_client.blob(destination_blob_name)
        # NOTE(review): a generation-match precondition of 0 is documented by
        # GCS as "only if the object does not exist yet", so this presumably
        # refuses to overwrite an existing blob -- confirm this is intended.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any cached copy so the next read fetches the uploaded data.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, no prefix filter is applied.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.
        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name,
            prefix=path,
        )
        return [blob.name for blob in blobs]

    def _cached_stream(self, file_name: str) -> BinaryIO:
        """Wrap the cached bytes of ``file_name`` in a fresh named stream.

        The caller must have populated ``self._cache[file_name]`` first.
        """
        file_stream = io.BytesIO(self._cache[file_name])
        file_stream.name = file_name
        return file_stream

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.
        """
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        return self._cached_stream(file_name)

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it when needed.

        A new session is created when none exists yet or the previous one has
        been closed. Creating a session also discards the cached async
        storage client, because that client was bound to the old session.
        """
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
            self._aio_storage = None
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        """Return an async storage client bound to a usable session."""
        # Always go through the session getter: it replaces a closed session
        # and resets ``_aio_storage``, so a stale client is never handed out.
        session = self._get_aio_session()
        if self._aio_storage is None:
            self._aio_storage = Storage(session=session)
        return self._aio_storage

    async def aclose(self) -> None:
        """Close the lazily created aiohttp session, if one is open.

        The instance stays usable afterwards: the next async call creates a
        new session.
        """
        session = self._aio_session
        self._aio_session = None
        self._aio_storage = None
        if session is not None and not session.closed:
            await session.close()

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Timeouts and HTTP 429/5xx responses are retried with an exponential
        delay (0.5s, 1s, 2s, ...) between attempts. Any other HTTP error
        status is re-raised immediately.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of download attempts in total.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.
            aiohttp.ClientResponseError: On a non-retryable HTTP status.
        """
        if file_name in self._cache:
            return self._cached_stream(file_name)

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                payload = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
            # Before Python 3.11 ``asyncio.TimeoutError`` is not the builtin
            # ``TimeoutError``; name both so timeouts are retried everywhere.
            except (TimeoutError, asyncio.TimeoutError) as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                retryable = (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                )
                if not retryable:
                    raise
                last_exception = exc
                logger.warning(
                    "HTTP %d downloading gs://%s/%s "
                    "(attempt %d/%d)",
                    exc.status,
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            else:
                self._cache[file_name] = payload
                return self._cached_stream(file_name)

            # Back off before the next attempt, but not after the last one.
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Each deleted blob is also evicted from the local cache.

        Args:
            path: Prefix of blobs to delete.
        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name,
            prefix=path,
        )
        for blob in blobs:
            blob.delete()
            self._cache.pop(blob.name, None)