feat: Add sorting functionality to the virtual file system and adapter list methods

This commit is contained in:
shiyu
2025-09-06 16:15:24 +08:00
parent 57919aa7ae
commit 74ffc0bb30
12 changed files with 211 additions and 48 deletions

View File

@@ -10,7 +10,7 @@ from models import StorageAdapter
@runtime_checkable
class BaseAdapter(Protocol):
record: StorageAdapter
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]: ...
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]: ...
async def read_file(self, root: str, rel: str) -> bytes: ...
async def write_file(self, root: str, rel: str, data: bytes): ...
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]): ...

View File

@@ -46,25 +46,18 @@ class LocalAdapter:
return str(Path(root) / sub_path)
return root
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
rel = rel.strip('/')
base = _safe_join(root, rel) if rel else Path(root)
if not base.exists():
return [], 0
if not base.is_dir():
raise NotADirectoryError(rel)
# 获取所有文件名并排序
all_names = await asyncio.to_thread(lambda: sorted(os.listdir(base), key=str.lower))
total_count = len(all_names)
# 计算分页范围
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
page_names = all_names[start_idx:end_idx]
all_names = await asyncio.to_thread(os.listdir, base)
entries = []
for name in page_names:
for name in all_names:
fp = base / name
try:
st = await asyncio.to_thread(fp.stat)
@@ -79,10 +72,35 @@ class LocalAdapter:
"mode": stat.S_IMODE(st.st_mode),
"type": "dir" if is_dir else "file",
})
# 排序
reverse = sort_order.lower() == "desc"
# 按目录优先排序
entries.sort(key=lambda x: (not x["is_dir"], x["name"].lower()))
return entries, total_count
def get_sort_key(item):
# 基础排序键,目录优先
key = (not item["is_dir"],)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item["size"],)
elif sort_field == "mtime":
key += (item["mtime"],)
else: # 默认按名称
key += (item["name"].lower(),)
return key
entries.sort(key=get_sort_key, reverse=reverse)
total_count = len(entries)
# 分页
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
page_entries = entries[start_idx:end_idx]
return page_entries, total_count
async def read_file(self, root: str, rel: str) -> bytes:
fp = _safe_join(root, rel)

View File

@@ -114,7 +114,7 @@ class OneDriveAdapter:
"type": "dir" if is_dir else "file",
}
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
"""
列出目录内容。
由于 Graph API 不支持基于偏移($skip)的分页,此方法将获取所有项目,
@@ -122,6 +122,8 @@ class OneDriveAdapter:
:param rel: 相对路径。
:param page_num: 页码。
:param page_size: 每页大小。
:param sort_by: 排序字段
:param sort_order: 排序顺序
:return: 文件/目录列表和总数。
"""
api_path = self._get_api_path(rel)
@@ -149,8 +151,23 @@ class OneDriveAdapter:
resp = await self._request("GET", full_url=next_link)
formatted_items = [self._format_item(item) for item in all_items]
formatted_items.sort(key=lambda x: (
not x["is_dir"], x["name"].lower()))
# 排序
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item["is_dir"],)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item["size"],)
elif sort_field == "mtime":
key += (item["mtime"],)
else:
key += (item["name"].lower(),)
return key
formatted_items.sort(key=get_sort_key, reverse=reverse)
total_count = len(formatted_items)
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size

View File

@@ -52,7 +52,7 @@ class S3Adapter:
def _get_client(self):
return self.session.client("s3", endpoint_url=self.endpoint_url)
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
prefix = self._get_s3_key(rel)
if prefix and not prefix.endswith("/"):
prefix += "/"
@@ -91,7 +91,21 @@ class S3Adapter:
})
# 在内存中排序和分页
all_items.sort(key=lambda x: (not x["is_dir"], x["name"].lower()))
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item["is_dir"],)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item["size"],)
elif sort_field == "mtime":
key += (item["mtime"],)
else:
key += (item["name"].lower(),)
return key
all_items.sort(key=get_sort_key, reverse=reverse)
total_count = len(all_items)
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size

View File

@@ -62,7 +62,7 @@ class TelegramAdapter:
def get_effective_root(self, sub_path: str | None) -> str:
return ""
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
if rel:
return [], 0
@@ -70,7 +70,7 @@ class TelegramAdapter:
entries = []
try:
await client.connect()
messages = await client.get_messages(self.chat_id, limit=50)
messages = await client.get_messages(self.chat_id, limit=200)
for message in messages:
if not message:
continue
@@ -113,7 +113,30 @@ class TelegramAdapter:
if client.is_connected():
await client.disconnect()
return entries, len(entries)
# 排序
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item["is_dir"],)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item["size"],)
elif sort_field == "mtime":
key += (item["mtime"],)
else:
key += (item["name"].lower(),)
return key
entries.sort(key=get_sort_key, reverse=reverse)
total_count = len(entries)
# 分页
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
page_entries = entries[start_idx:end_idx]
return page_entries, total_count
async def read_file(self, root: str, rel: str) -> bytes:
try:

View File

@@ -39,7 +39,7 @@ class WebDAVAdapter:
rel = rel.strip('/')
return self.base_url if not rel else urljoin(self.base_url, quote(rel) + ('/' if rel.endswith('/') else ''))
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50) -> Tuple[List[Dict], int]:
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
raw_url = self._build_url(rel)
url = raw_url if raw_url.endswith('/') else raw_url + '/'
depth = "1"
@@ -92,16 +92,39 @@ class WebDAVAdapter:
"d:collection", NS) is not None if rt_el is not None else href_path.endswith('/')
size = int(
size_el.text) if size_el is not None and size_el.text and size_el.text.isdigit() else 0
from email.utils import parsedate_to_datetime
mtime = 0
if lm_el is not None and lm_el.text:
try:
mtime = int(parsedate_to_datetime(lm_el.text).timestamp())
except Exception:
mtime = 0
all_entries.append({
"name": name,
"is_dir": is_dir,
"size": 0 if is_dir else size,
"mtime": 0,
"mtime": mtime,
"type": "dir" if is_dir else "file",
})
# 排序所有条目
all_entries.sort(key=lambda x: (not x["is_dir"], x["name"].lower()))
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item["is_dir"],)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item["size"],)
elif sort_field == "mtime":
key += (item["mtime"],)
else:
key += (item["name"].lower(),)
return key
all_entries.sort(key=get_sort_key, reverse=reverse)
total_count = len(all_entries)
# 应用分页

View File

@@ -59,7 +59,7 @@ async def _ensure_method(adapter: Any, method: str):
return func
async def list_virtual_dir(path: str, page_num: int = 1, page_size: int = 50) -> Dict:
async def list_virtual_dir(path: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Dict:
norm = (path if path.startswith('/') else '/' + path).rstrip('/') or '/'
adapters = await StorageAdapter.filter(enabled=True)
@@ -100,7 +100,7 @@ async def list_virtual_dir(path: str, page_num: int = 1, page_size: int = 50) ->
if adapter_model and adapter_instance:
list_dir = await _ensure_method(adapter_instance, "list_dir")
try:
adapter_entries, adapter_total = await list_dir(effective_root, rel, page_num, page_size)
adapter_entries, adapter_total = await list_dir(effective_root, rel, page_num, page_size, sort_by, sort_order)
except NotADirectoryError:
raise HTTPException(400, detail="Not a directory")
@@ -118,17 +118,32 @@ async def list_virtual_dir(path: str, page_num: int = 1, page_size: int = 50) ->
ent['is_image'] = is_image_filename(ent['name'])
else:
ent['is_image'] = False
all_entries = adapter_entries + mount_entries
all_entries.sort(key=lambda x: (not x.get("is_dir"), x["name"].lower()))
total_entries = adapter_total + len(mount_entries)
if mount_entries:
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item.get("is_dir"),)
sort_field = sort_by.lower()
if sort_field == "name":
key += (item["name"].lower(),)
elif sort_field == "size":
key += (item.get("size", 0),)
elif sort_field == "mtime":
key += (item.get("mtime", 0),)
else:
key += (item["name"].lower(),)
return key
all_entries.sort(key=get_sort_key, reverse=reverse)
total_entries = adapter_total + len(mount_entries)
start_idx = (page_num - 1) * page_size
end_idx = start_idx + page_size
page_entries = all_entries[start_idx:end_idx]
return page(page_entries, total_entries, page_num, page_size)
else:
return page(adapter_entries, adapter_total, page_num, page_size)
return page(adapter_entries, adapter_total, page_num, page_size)
async def read_file(path: str) -> Union[bytes, Any]: