# Mirror of https://github.com/DrizzleTime/Foxel.git
# Synced 2026-05-11 18:10:10 +08:00
# 448 lines, 16 KiB, Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import mimetypes
|
|
import stat as statmod
|
|
from typing import List, Dict, Tuple, AsyncIterator, Optional
|
|
|
|
from fastapi import HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
import paramiko
|
|
|
|
from models import StorageAdapter
|
|
from services.logging import LogService
|
|
|
|
|
|
def _join_remote(root: str, rel: str) -> str:
|
|
root = (root or "/").rstrip("/") or "/"
|
|
rel = (rel or "").lstrip("/")
|
|
if not rel:
|
|
return root
|
|
return f"{root}/{rel}"
|
|
|
|
|
|
class SFTPAdapter:
    """Storage adapter that exposes a remote directory tree over SFTP.

    Every operation opens its own short-lived SSH connection, runs the
    blocking paramiko calls in a worker thread (``asyncio.to_thread``) and
    closes both the SFTP channel and the owning SSH client afterwards.
    """

    # Number of bytes requested per read while streaming a file.
    _STREAM_CHUNK_SIZE = 64 * 1024
    # Number of directory entries fetched per page while copying a directory.
    _COPY_PAGE_SIZE = 10_000
    # Numeric value of errno.ENOENT ("No such file or directory").
    _ENOENT = 2

    class _RangeNotSatisfiable(Exception):
        """Raised by ``_resolve_range`` when the range covers no byte of the file."""

    def __init__(self, record: StorageAdapter):
        """Read and validate the connection settings from ``record.config``.

        Raises:
            ValueError: if ``host``, ``username``/``password`` or ``root`` is
                missing from the configuration.
        """
        self.record = record
        cfg = record.config

        port = cfg.get("port")
        timeout = cfg.get("timeout")

        self.host: str = cfg.get("host")
        # A key that is present but None/empty falls back to the default
        # instead of crashing inside int().
        self.port: int = 22 if port in (None, "") else int(port)
        self.username: str | None = cfg.get("username")
        self.password: str | None = cfg.get("password")
        self.timeout: int = 15 if timeout in (None, "") else int(timeout)
        self.root_path: str = cfg.get("root")  # required
        self.allow_unknown_host: bool = bool(cfg.get("allow_unknown_host", True))

        if not self.host:
            raise ValueError("SFTP adapter requires 'host'")
        if not self.username or not self.password:
            raise ValueError("SFTP adapter requires 'username' and 'password'")
        if not self.root_path:
            raise ValueError("SFTP adapter requires 'root'")

    def get_effective_root(self, sub_path: str | None) -> str:
        """Return the configured root, optionally extended by ``sub_path``."""
        base = self.root_path.rstrip("/") or "/"
        if sub_path:
            return _join_remote(base, sub_path)
        return base

    # ------------------------------------------------------------------
    # Connection helpers
    # ------------------------------------------------------------------

    def _open_client(self) -> paramiko.SSHClient:
        """Open and authenticate an SSH connection (blocking).

        The client is closed again if ``connect`` fails, so a failed attempt
        does not leave a half-open connection behind.
        """
        ssh = paramiko.SSHClient()
        if self.allow_unknown_host:
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        else:
            # Without AutoAddPolicy paramiko rejects every host it does not
            # already know, so the user's known_hosts file has to be loaded
            # or no connection could ever succeed.
            ssh.load_system_host_keys()
        try:
            ssh.connect(
                hostname=self.host,
                port=self.port,
                username=self.username,
                password=self.password,
                timeout=self.timeout,
                allow_agent=False,
                look_for_keys=False,
            )
        except Exception:
            ssh.close()
            raise
        return ssh

    def _connect(self) -> paramiko.SFTPClient:
        """Open a connection and return only its SFTP client (blocking).

        Kept for existing callers. The owning ``SSHClient`` is not returned,
        so the caller cannot close it; code inside this class uses ``_run``
        instead, which closes both.
        """
        ssh = self._open_client()
        try:
            return ssh.open_sftp()
        except Exception:
            ssh.close()
            raise

    @staticmethod
    def _close_quietly(resource) -> None:
        """Close ``resource``, ignoring errors: cleanup must not mask the real outcome."""
        try:
            resource.close()
        except Exception:
            pass

    def _run(self, action):
        """Run ``action(sftp)`` on a fresh connection and always release it (blocking).

        Returns whatever ``action`` returns; exceptions from ``action``
        propagate after the SFTP channel and the SSH client were closed.
        """
        ssh = self._open_client()
        try:
            sftp = ssh.open_sftp()
            try:
                return action(sftp)
            finally:
                self._close_quietly(sftp)
        finally:
            self._close_quietly(ssh)

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    @classmethod
    def _is_missing(cls, exc: BaseException) -> bool:
        """Tell whether ``exc`` reports a path that does not exist."""
        return isinstance(exc, FileNotFoundError) or getattr(exc, "errno", None) == cls._ENOENT

    @staticmethod
    def _parent_of(path: str) -> str:
        """Return the directory part of ``path`` (``"/"`` when it has no inner slash)."""
        if "/" not in path.strip("/"):
            return "/"
        return path.rsplit("/", 1)[0]

    @staticmethod
    def _entry_from_attr(name: str, attr) -> Dict:
        """Convert an SFTP attribute object into the adapter's entry dict."""
        # st_mode/st_size/st_mtime may be None when the server omits them.
        is_dir = statmod.S_ISDIR(attr.st_mode or 0)
        return {
            "name": name,
            "is_dir": is_dir,
            "size": 0 if is_dir else int(attr.st_size or 0),
            "mtime": int(attr.st_mtime or 0),
            "type": "dir" if is_dir else "file",
        }

    @staticmethod
    def _sort_and_page(entries: List[Dict], page_num: int, page_size: int,
                       sort_by: str, sort_order: str) -> Tuple[List[Dict], int]:
        """Sort ``entries`` and cut out one page.

        Directories sort before files, then by ``sort_by`` (``name``, ``size``
        or ``mtime``; anything else means ``name``). ``sort_order == "desc"``
        reverses the complete ordering, including the directory/file grouping.

        Returns:
            ``(page_items, total_number_of_entries)``.
        """
        field = sort_by.lower()
        descending = sort_order.lower() == "desc"

        def sort_key(item: Dict):
            if field in ("size", "mtime"):
                secondary = item.get(field, 0)
            else:
                secondary = item["name"].lower()
            return (not item["is_dir"], secondary)

        ordered = sorted(entries, key=sort_key, reverse=descending)
        # Page numbers below 1 are treated as the first page instead of
        # producing a negative slice start.
        first = (max(page_num, 1) - 1) * page_size
        return ordered[first:first + page_size], len(ordered)

    @staticmethod
    def _resolve_range(range_header: str | None, file_size: int) -> Tuple[int, int, bool]:
        """Translate an HTTP ``Range`` header into inclusive byte offsets.

        Supports a single ``bytes=first-last``, ``bytes=first-`` or suffix
        ``bytes=-count`` specification.

        Returns:
            ``(start, end, partial)``; ``partial`` is False when no usable
            ``bytes=`` header was supplied and the whole file is selected.

        Raises:
            ValueError: the header is malformed (including ``last < first``).
            SFTPAdapter._RangeNotSatisfiable: the range lies outside the file.
        """
        last_byte = file_size - 1
        if not range_header or not range_header.startswith("bytes="):
            return 0, last_byte, False

        first_txt, dash, last_txt = range_header.removeprefix("bytes=").partition("-")
        first_txt, last_txt = first_txt.strip(), last_txt.strip()
        if not dash or not (first_txt or last_txt):
            raise ValueError("malformed byte range")

        def _number(text: str) -> int:
            # Reject signs, separators and multi-range lists up front.
            if not (text.isascii() and text.isdigit()):
                raise ValueError("malformed byte range")
            return int(text)

        if first_txt:
            start = _number(first_txt)
            end = _number(last_txt) if last_txt else last_byte
            if start >= file_size:
                raise SFTPAdapter._RangeNotSatisfiable()
            if end < start:
                raise ValueError("range end precedes range start")
            end = min(end, last_byte)
        else:
            # Suffix form: the final ``count`` bytes of the file.
            count = _number(last_txt)
            if count == 0 or file_size == 0:
                raise SFTPAdapter._RangeNotSatisfiable()
            start = max(file_size - count, 0)
            end = last_byte
        return start, end, True

    @staticmethod
    def _mkdir_p(sftp, dir_path: str) -> None:
        """Create ``dir_path`` and its ancestors, ignoring per-level failures.

        A failing ``mkdir`` usually means the level already exists; a real
        failure surfaces in whatever the caller does next with the path.
        Relative paths stay relative, absolute paths stay absolute.
        """
        absolute = dir_path.startswith("/")
        current = ""
        for part in (p for p in dir_path.strip("/").split("/") if p):
            current = f"{current}/{part}" if (current or absolute) else part
            try:
                sftp.mkdir(current)
            except IOError:
                pass

    def _rm_tree(self, sftp, dir_path: str) -> None:
        """Best-effort recursive removal of ``dir_path``; I/O errors are ignored."""
        try:
            for attr in sftp.listdir_attr(dir_path):
                child = _join_remote(dir_path, attr.filename)
                if statmod.S_ISDIR(attr.st_mode or 0):
                    self._rm_tree(sftp, child)
                else:
                    try:
                        sftp.remove(child)
                    except Exception:
                        pass
            sftp.rmdir(dir_path)
        except IOError:
            pass

    # ------------------------------------------------------------------
    # Adapter operations
    # ------------------------------------------------------------------

    async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
        """List one page of the directory ``rel`` below ``root``.

        Returns:
            ``(entries, total)`` where each entry has the keys ``name``,
            ``is_dir``, ``size``, ``mtime`` and ``type``.
        """
        path = _join_remote(root, rel)

        def _do_list(sftp) -> List[Dict]:
            return [self._entry_from_attr(a.filename, a) for a in sftp.listdir_attr(path)]

        entries = await asyncio.to_thread(self._run, _do_list)
        return self._sort_and_page(entries, page_num, page_size, sort_by, sort_order)

    async def read_file(self, root: str, rel: str) -> bytes:
        """Read the whole remote file into memory.

        Raises:
            FileNotFoundError: the file does not exist.
        """
        path = _join_remote(root, rel)

        def _do_read(sftp) -> bytes:
            try:
                with sftp.open(path, "rb") as fh:
                    return fh.read()
            except FileNotFoundError:
                raise
            except IOError as exc:
                if self._is_missing(exc):
                    raise FileNotFoundError(rel) from exc
                raise

        return await asyncio.to_thread(self._run, _do_read)

    async def write_file(self, root: str, rel: str, data: bytes):
        """Write ``data`` to ``rel``, creating missing parent directories."""
        path = _join_remote(root, rel)

        def _do_write(sftp):
            self._mkdir_p(sftp, self._parent_of(path))
            with sftp.open(path, "wb") as fh:
                fh.write(data)

        await asyncio.to_thread(self._run, _do_write)
        await LogService.info("adapter:sftp", f"Wrote file to {rel}", details={"adapter_id": self.record.id, "path": path, "size": len(data)})

    async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
        """Collect ``data_iter`` and write it to ``rel``; returns the byte count.

        The complete payload is buffered in memory before it is uploaded.
        """
        buf = bytearray()
        async for chunk in data_iter:
            if chunk:
                buf.extend(chunk)
        await self.write_file(root, rel, bytes(buf))
        return len(buf)

    async def mkdir(self, root: str, rel: str):
        """Create the directory ``rel`` together with missing ancestors.

        Raises:
            FileExistsError: the path exists but is not a directory.
            IOError: the directory could not be created.
        """
        path = _join_remote(root, rel)

        def _do_mkdir(sftp):
            self._mkdir_p(sftp, path)
            # The per-level mkdir errors are ignored, so confirm the result
            # before success is reported and logged.
            if not statmod.S_ISDIR(sftp.stat(path).st_mode or 0):
                raise FileExistsError(rel)

        await asyncio.to_thread(self._run, _do_mkdir)
        await LogService.info("adapter:sftp", f"Created directory {rel}", details={"adapter_id": self.record.id, "path": path})

    async def delete(self, root: str, rel: str):
        """Delete the file or directory tree at ``rel`` (best effort)."""
        path = _join_remote(root, rel)

        def _do_delete(sftp):
            try:
                sftp.remove(path)
                return
            except IOError:
                # Not removable as a file: fall back to treating it as a directory.
                pass
            self._rm_tree(sftp, path)

        await asyncio.to_thread(self._run, _do_delete)
        await LogService.info("adapter:sftp", f"Deleted {rel}", details={"adapter_id": self.record.id, "path": path})

    async def move(self, root: str, src_rel: str, dst_rel: str):
        """Rename ``src_rel`` to ``dst_rel``, creating the destination's parents."""
        src = _join_remote(root, src_rel)
        dst = _join_remote(root, dst_rel)

        def _do_move(sftp):
            self._mkdir_p(sftp, self._parent_of(dst))
            sftp.rename(src, dst)

        await asyncio.to_thread(self._run, _do_move)
        await LogService.info("adapter:sftp", f"Moved {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})

    async def rename(self, root: str, src_rel: str, dst_rel: str):
        """Alias of ``move``."""
        await self.move(root, src_rel, dst_rel)

    async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
        """Copy a file, or recursively a directory, from ``src_rel`` to ``dst_rel``.

        Raises:
            FileNotFoundError: the source does not exist.
            FileExistsError: a destination file exists and ``overwrite`` is False.
        """
        src = _join_remote(root, src_rel)
        dst = _join_remote(root, dst_rel)

        source = await self.stat_file(root, src_rel)
        if source["is_dir"]:
            await self.mkdir(root, dst_rel)

            # Walk the listing page by page so large directories are copied
            # completely instead of being cut off after the first page.
            page = 1
            while True:
                children, total = await self.list_dir(root, src_rel, page_num=page, page_size=self._COPY_PAGE_SIZE)
                for ent in children:
                    child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
                    child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
                    await self.copy(root, child_src, child_dst, overwrite)
                if not children or page * self._COPY_PAGE_SIZE >= total:
                    break
                page += 1
            await LogService.info("adapter:sftp", f"Copied directory {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})
            return

        # File copy: check the destination first so a refused copy does not
        # download the source for nothing.
        if not overwrite:
            try:
                await self.stat_file(root, dst_rel)
            except FileNotFoundError:
                pass
            else:
                raise FileExistsError(dst_rel)
        data = await self.read_file(root, src_rel)
        await self.write_file(root, dst_rel, data)
        await LogService.info("adapter:sftp", f"Copied {src_rel} to {dst_rel}", details={"adapter_id": self.record.id, "src": src, "dst": dst})

    async def stat_file(self, root: str, rel: str):
        """Return the entry dict (plus ``path``) for ``rel``.

        Raises:
            FileNotFoundError: the path does not exist.
        """
        path = _join_remote(root, rel)

        def _do_stat(sftp):
            try:
                attr = sftp.stat(path)
            except FileNotFoundError:
                raise
            except IOError as exc:
                if self._is_missing(exc):
                    raise FileNotFoundError(rel) from exc
                raise
            info = self._entry_from_attr(rel.split("/")[-1], attr)
            info["path"] = path
            return info

        return await asyncio.to_thread(self._run, _do_stat)

    async def exists(self, root: str, rel: str) -> bool:
        """Tell whether ``rel`` can be stat'ed.

        Any failure, including connection or permission errors, is reported
        as ``False``.
        """
        try:
            await self.stat_file(root, rel)
        except Exception:
            return False
        return True

    async def stream_file(self, root: str, rel: str, range_header: str | None):
        """Stream ``rel`` as an HTTP response, honouring a single byte range.

        Raises:
            HTTPException: 404 when the file is missing, 400 for a malformed
                ``Range`` header, 416 when the range lies outside the file.
        """
        path = _join_remote(root, rel)

        def _get_size(sftp) -> int:
            return int(sftp.stat(path).st_size or 0)

        try:
            file_size = await asyncio.to_thread(self._run, _get_size)
        except IOError as exc:
            if self._is_missing(exc):
                raise HTTPException(404, detail="File not found") from exc
            raise

        try:
            start, end, partial = self._resolve_range(range_header, file_size)
        except self._RangeNotSatisfiable:
            raise HTTPException(
                416,
                detail="Requested Range Not Satisfiable",
                headers={"Content-Range": f"bytes */{file_size}"},
            ) from None
        except ValueError:
            raise HTTPException(400, detail="Invalid Range header") from None

        mime, _ = mimetypes.guess_type(rel)
        content_type = mime or "application/octet-stream"

        status = 200
        headers = {
            "Accept-Ranges": "bytes",
            "Content-Type": content_type,
            "Content-Length": str(file_size),
        }
        if partial:
            status = 206
            headers["Content-Length"] = str(end - start + 1)
            headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"

        chunk_size = self._STREAM_CHUNK_SIZE

        async def agen():
            remaining = end - start + 1
            if remaining <= 0:
                # Empty file: nothing to send, so no connection is opened.
                return

            opened: list = []  # resources in the order they were opened

            def _open():
                ssh = self._open_client()
                opened.append(ssh)
                sftp = ssh.open_sftp()
                opened.append(sftp)
                fh = sftp.open(path, "rb")
                opened.append(fh)
                fh.seek(start)
                return fh

            def _release():
                for resource in reversed(opened):
                    self._close_quietly(resource)

            try:
                fh = await asyncio.to_thread(_open)
                # Each blocking read is awaited on its own: the next chunk is
                # only fetched once the previous one has been consumed.
                while remaining > 0:
                    data = await asyncio.to_thread(fh.read, min(chunk_size, remaining))
                    if not data:
                        break
                    remaining -= len(data)
                    yield data
            finally:
                await asyncio.to_thread(_release)

        return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
|
|
|
|
|
|
# Identifier under which this adapter implementation is exposed.
ADAPTER_TYPE = "sftp"

# Description of the configuration fields, one dict per field.
# The keys match the ones SFTPAdapter.__init__ reads from ``record.config``.
CONFIG_SCHEMA = [
    {
        "key": "host",
        "label": "主机",
        "type": "string",
        "required": True,
        "placeholder": "sftp.example.com",
    },
    {
        "key": "port",
        "label": "端口",
        "type": "number",
        "required": False,
        "default": 22,
    },
    {
        "key": "username",
        "label": "用户名",
        "type": "string",
        "required": True,
    },
    {
        "key": "password",
        "label": "密码",
        "type": "password",
        "required": True,
    },
    {
        "key": "root",
        "label": "根路径",
        "type": "string",
        "required": True,
        "placeholder": "/data",
    },
    {
        "key": "timeout",
        "label": "超时(秒)",
        "type": "number",
        "required": False,
        "default": 15,
    },
    {
        "key": "allow_unknown_host",
        "label": "允许未知主机指纹",
        "type": "boolean",
        "required": False,
        "default": True,
    },
]
|
|
|
|
|
|
def ADAPTER_FACTORY(rec: StorageAdapter):
    """Build the SFTP adapter instance for the storage record ``rec``."""
    adapter = SFTPAdapter(rec)
    return adapter
|