diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index ad60ae00..6a830afa 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -73,8 +73,9 @@ async def fetch_image( # 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择 cache_path = cache_path.with_suffix(".jpg") - # 缓存对像 - cache_backend = get_async_file_cache_backend(base=settings.CACHE_PATH) + # 缓存对像,缓存过期时间为全局图片缓存天数 + cache_backend = get_async_file_cache_backend(base=settings.CACHE_PATH, + ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600) if use_cache: content = await cache_backend.get(cache_path.as_posix(), region="images") diff --git a/app/chain/recommend.py b/app/chain/recommend.py index cde78336..83391c03 100644 --- a/app/chain/recommend.py +++ b/app/chain/recommend.py @@ -110,8 +110,9 @@ class RecommendChain(ChainBase, metaclass=Singleton): if not cache_path.suffix: cache_path = cache_path.with_suffix(".jpg") - # 获取缓存后端 - cache_backend = get_file_cache_backend(base=settings.CACHE_PATH) + # 获取缓存后端,并设置缓存时间为全局配置的缓存天数 + cache_backend = get_file_cache_backend(base=settings.CACHE_PATH, + ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600) # 本地存在缓存图片,则直接跳过 if cache_backend.get(cache_path.as_posix(), region="images"): diff --git a/app/core/cache.py b/app/core/cache.py index 6111b3e1..1638674f 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -5,7 +5,7 @@ import threading from abc import ABC, abstractmethod from functools import wraps from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple import aiofiles import aioshutil @@ -84,7 +84,7 @@ class CacheBackend(ABC): pass @abstractmethod - def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Generator[Tuple[str, Any], None, None]: """ 获取指定区域的所有缓存项 @@ -197,7 +197,7 @@ class AsyncCacheBackend(ABC): pass @abstractmethod - async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]: """ 获取指定区域的所有缓存项 @@ -352,7 +352,7 @@ class CacheToolsBackend(CacheBackend): region_cache.clear() logger.info("Cleared all cache") - def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Generator[Tuple[str, Any], None, None]: """ 获取指定区域的所有缓存项 @@ -361,8 +361,9 @@ class CacheToolsBackend(CacheBackend): """ region_cache = self.__get_region_cache(region) if region_cache is None: - return {} - return dict(region_cache.items()) + yield from () + for item in region_cache.items(): + yield item def close(self) -> None: """ @@ -436,7 +437,7 @@ class RedisBackend(CacheBackend): """ self.redis_helper.clear(region=region) - def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Generator[Tuple[Any, Any], None, None]: """ 获取指定区域的所有缓存项 @@ -517,14 +518,15 @@ class AsyncRedisBackend(AsyncCacheBackend): """ await self.redis_helper.clear(region=region) - async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]: """ 获取指定区域的所有缓存项 :param region: 缓存的区 :return: 返回一个字典,包含所有缓存键值对 """ - return await self.redis_helper.items(region=region) + async for item in self.redis_helper.items(region=region): + yield item async def close(self) -> None: """ @@ -572,7 +574,7 @@ class FileBackend(CacheBackend): :param region: 缓存的区 :return: 存在返回 True,否则返回 False """ - cache_path = self.base / key + cache_path = self.base / region / key return cache_path.exists() def get(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> Optional[Any]: @@ -623,7 +625,7 @@ class FileBackend(CacheBackend): else: shutil.rmtree(item, ignore_errors=True) - def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Generator[Tuple[str, Any], None, None]: """ 获取指定区域的所有缓存项 @@ -632,11 +634,11 @@ class FileBackend(CacheBackend): """ cache_path = self.base / region if not cache_path.exists(): - return {} + yield from () for item in cache_path.iterdir(): if item.is_file(): with open(item, 'r') as f: - yield f.read() + yield item.name, f.read() def close(self) -> None: """ @@ -722,20 +724,20 @@ class AsyncFileBackend(AsyncCacheBackend): # 清理指定缓存区 cache_path = AsyncPath(self.base) / region if await cache_path.exists(): - for item in cache_path.iterdir(): + async for item in cache_path.iterdir(): if await item.is_file(): await item.unlink() else: await aioshutil.rmtree(item, ignore_errors=True) else: # 清除所有区域的缓存 - for item in AsyncPath(self.base).iterdir(): + async for item in AsyncPath(self.base).iterdir(): if await item.is_file(): await item.unlink() else: await aioshutil.rmtree(item, ignore_errors=True) - async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> Dict[str, Any]: + async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]: """ 获取指定区域的所有缓存项 @@ -744,11 +746,11 @@ class AsyncFileBackend(AsyncCacheBackend): """ cache_path = AsyncPath(self.base) / region if not await cache_path.exists(): - yield None - for item in cache_path.iterdir(): + yield "", None + async for item in cache_path.iterdir(): if await item.is_file(): async with aiofiles.open(item, 'r') as f: - yield await f.read() + yield item.name, await f.read() async def close(self) -> None: """ @@ -757,23 +759,27 @@ class AsyncFileBackend(AsyncCacheBackend): pass -def get_file_cache_backend(base: Path = settings.TEMP_PATH) -> CacheBackend: +def get_file_cache_backend(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend: """ 获取文件缓存后端实例(Redis或文件系统) """ if settings.CACHE_BACKEND_TYPE == "redis": - return RedisBackend() + # 如果使用 Redis,则设置缓存的存活时间为配置的天数转换为秒 + return RedisBackend(ttl=ttl or settings.TEMP_FILE_DAYS * 24 * 3600) else: + # 如果使用文件系统,在停止服务时会自动清理过期文件 return FileBackend(base=base) -def get_async_file_cache_backend(base: Path = settings.TEMP_PATH) -> AsyncCacheBackend: +def get_async_file_cache_backend(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> AsyncCacheBackend: """ 获取文件异步缓存后端实例(Redis或文件系统) """ if settings.CACHE_BACKEND_TYPE == "redis": - return AsyncRedisBackend() + # 如果使用 Redis,则设置缓存的存活时间为配置的天数转换为秒 + return AsyncRedisBackend(ttl=ttl or settings.TEMP_FILE_DAYS * 24 * 3600) else: + # 如果使用文件系统,在停止服务时会自动清理过期文件 return AsyncFileBackend(base=base) diff --git a/app/helper/redis.py b/app/helper/redis.py index 8fa72be3..74107e09 100644 --- a/app/helper/redis.py +++ b/app/helper/redis.py @@ -1,6 +1,6 @@ import json import pickle -from typing import Any, Optional +from typing import Any, Optional, Generator, Tuple, AsyncGenerator from urllib.parse import quote import redis @@ -245,7 +245,7 @@ class RedisHelper(metaclass=Singleton): except Exception as e: logger.error(f"Failed to clear cache, region: {region}, error: {e}") - def items(self, region: Optional[str] = None): + def items(self, region: Optional[str] = None) -> Generator[Tuple[Any, Any], None, None]: """ 获取指定区域的所有缓存键值对 @@ -525,7 +525,7 @@ class AsyncRedisHelper(metaclass=Singleton): except Exception as e: logger.error(f"Failed to clear cache (async), region: {region}, error: {e}") - async def items(self, region: Optional[str] = None): + async def items(self, region: Optional[str] = None) -> AsyncGenerator[Tuple[Any, Any], None]: """ 获取指定区域的所有缓存键值对 @@ -537,12 +537,12 @@ class AsyncRedisHelper(metaclass=Singleton): if region: cache_region = self.get_region(quote(region)) redis_key = f"{cache_region}:key:*" - for key in self.client.scan_iter(redis_key): + async for key in self.client.scan_iter(redis_key): value = await self.client.get(key) if value is not None: yield key, self.deserialize(value) else: - for key in self.client.scan_iter("*"): + async for key in self.client.scan_iter("*"): value = await self.client.get(key) if value is not None: yield key, self.deserialize(value)