"""Google Cloud Storage file storage implementation."""

import asyncio
import io
import logging
from typing import BinaryIO

import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage

from .base import BaseFileStorage

logger = logging.getLogger(__name__)

HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500

class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage.

    Synchronous operations go through ``google.cloud.storage``; asynchronous
    downloads go through ``gcloud.aio.storage`` on a lazily created aiohttp
    session. Downloaded blob contents are memoized in an in-process cache
    keyed by blob name.
    """

    def __init__(self, bucket: str) -> None:  # noqa: D107
        self.bucket_name = bucket

        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # Async resources are created lazily on first async download
        # (see _get_aio_session / _get_aio_storage).
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # Blob name -> raw bytes. NOTE(review): unbounded; consider bounding
        # it if this process serves many distinct blobs over its lifetime.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.

        """
        blob = self.bucket_client.blob(destination_blob_name)
        # if_generation_match=0 is a create-only precondition: the upload
        # fails with HTTP 412 if the blob already exists.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached contents for this blob name.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, recursively lists all files.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.

        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.

        """
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        file_stream = io.BytesIO(self._cache[file_name])
        # Give the stream a name so downstream code can identify its origin.
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, recreating it if closed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
            # A fresh session invalidates any Storage bound to the old,
            # now-closed one; force _get_aio_storage to rebuild it.
            self._aio_storage = None
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        """Return the async Storage client bound to a live session."""
        # Obtain the session first: this may reset self._aio_storage when
        # the previous session was closed and had to be recreated.
        session = self._get_aio_session()
        if self._aio_storage is None:
            self._aio_storage = Storage(session=session)
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Transient failures (timeouts, HTTP 429, HTTP 5xx) are retried with
        exponential backoff; other HTTP errors propagate immediately.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.

        """
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            except TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Retry only rate limiting (429) and server errors (5xx);
                # anything else (e.g. 404, 403) is not transient.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s "
                        "(attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    raise
            else:
                return file_stream

            if attempt < max_retries - 1:
                # Exponential backoff: 0.5s, 1s, 2s, ...
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Args:
            path: Prefix of blobs to delete.

        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        for blob in blobs:
            blob.delete()
            # Keep the local cache consistent with the bucket.
            self._cache.pop(blob.name, None)