Files
Foxel/domain/adapters/providers/sftp.py
2026-02-09 13:19:28 +08:00

474 lines
16 KiB
Python

import asyncio
import mimetypes
import stat as statmod
from typing import List, Dict, Tuple, AsyncIterator, Optional
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
import paramiko
from models import StorageAdapter
def _join_remote(root: str, rel: str) -> str:
root = (root or "/").rstrip("/") or "/"
rel = (rel or "").lstrip("/")
if not rel:
return root
return f"{root}/{rel}"
class SFTPAdapter:
def __init__(self, record: StorageAdapter):
self.record = record
cfg = record.config
self.host: str = cfg.get("host")
self.port: int = int(cfg.get("port", 22))
self.username: str | None = cfg.get("username")
self.password: str | None = cfg.get("password")
self.timeout: int = int(cfg.get("timeout", 15))
self.root_path: str = cfg.get("root") # 必填
self.allow_unknown_host: bool = bool(cfg.get("allow_unknown_host", True))
if not self.host:
raise ValueError("SFTP adapter requires 'host'")
if not self.username or not self.password:
raise ValueError("SFTP adapter requires 'username' and 'password'")
if not self.root_path:
raise ValueError("SFTP adapter requires 'root'")
def get_effective_root(self, sub_path: str | None) -> str:
base = self.root_path.rstrip("/") or "/"
if sub_path:
return _join_remote(base, sub_path)
return base
def _connect(self) -> paramiko.SFTPClient:
ssh = paramiko.SSHClient()
if self.allow_unknown_host:
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(
hostname=self.host,
port=self.port,
username=self.username,
password=self.password,
timeout=self.timeout,
allow_agent=False,
look_for_keys=False,
)
return ssh.open_sftp()
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
path = _join_remote(root, rel)
def _do_list() -> List[Dict]:
sftp = self._connect()
try:
attrs = sftp.listdir_attr(path)
entries: List[Dict] = []
for a in attrs:
name = a.filename
is_dir = statmod.S_ISDIR(a.st_mode)
entries.append({
"name": name,
"is_dir": is_dir,
"size": 0 if is_dir else int(a.st_size or 0),
"mtime": int(a.st_mtime or 0),
"type": "dir" if is_dir else "file",
})
return entries
finally:
try:
sftp.close()
except Exception:
pass
entries = await asyncio.to_thread(_do_list)
reverse = sort_order.lower() == "desc"
def get_sort_key(item):
key = (not item["is_dir"],)
f = sort_by.lower()
if f == "name":
key += (item["name"].lower(),)
elif f == "size":
key += (item.get("size", 0),)
elif f == "mtime":
key += (item.get("mtime", 0),)
else:
key += (item["name"].lower(),)
return key
entries.sort(key=get_sort_key, reverse=reverse)
total = len(entries)
start = (page_num - 1) * page_size
end = start + page_size
return entries[start:end], total
async def read_file(self, root: str, rel: str) -> bytes:
path = _join_remote(root, rel)
def _do_read() -> bytes:
sftp = self._connect()
try:
with sftp.open(path, "rb") as f:
return f.read()
except FileNotFoundError:
raise
except IOError as e:
if getattr(e, "errno", None) == 2:
raise FileNotFoundError(rel)
raise
finally:
try:
sftp.close()
except Exception:
pass
return await asyncio.to_thread(_do_read)
async def write_file(self, root: str, rel: str, data: bytes):
path = _join_remote(root, rel)
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
parts = [p for p in dir_path.strip("/").split("/") if p]
cur = "/"
for p in parts:
cur = _join_remote(cur, p)
try:
sftp.mkdir(cur)
except IOError:
# likely exists
pass
def _do_write():
sftp = self._connect()
try:
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
_ensure_dirs(sftp, parent)
with sftp.open(path, "wb") as f:
f.write(data)
finally:
try:
sftp.close()
except Exception:
pass
await asyncio.to_thread(_do_write)
async def write_upload_file(self, root: str, rel: str, file_obj, filename: str | None, file_size: int | None = None, content_type: str | None = None):
path = _join_remote(root, rel)
def _ensure_dirs(sftp: paramiko.SFTPClient, dir_path: str):
parts = [p for p in dir_path.strip("/").split("/") if p]
cur = "/"
for p in parts:
cur = _join_remote(cur, p)
try:
sftp.mkdir(cur)
except IOError:
pass
def _do_upload():
sftp = self._connect()
try:
parent = "/" if "/" not in path.strip("/") else path.rsplit("/", 1)[0]
_ensure_dirs(sftp, parent)
try:
if callable(getattr(file_obj, "seek", None)):
file_obj.seek(0)
except Exception:
pass
with sftp.open(path, "wb") as f:
import shutil
shutil.copyfileobj(file_obj, f)
finally:
try:
sftp.close()
except Exception:
pass
await asyncio.to_thread(_do_upload)
return {"size": file_size or 0}
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
buf = bytearray()
async for chunk in data_iter:
if chunk:
buf.extend(chunk)
await self.write_file(root, rel, bytes(buf))
return len(buf)
async def mkdir(self, root: str, rel: str):
path = _join_remote(root, rel)
def _do_mkdir():
sftp = self._connect()
try:
parts = [p for p in path.strip("/").split("/") if p]
cur = "/"
for p in parts:
cur = _join_remote(cur, p)
try:
sftp.mkdir(cur)
except IOError:
pass
finally:
try:
sftp.close()
except Exception:
pass
await asyncio.to_thread(_do_mkdir)
async def delete(self, root: str, rel: str):
path = _join_remote(root, rel)
def _do_delete():
sftp = self._connect()
try:
# Try file remove first
try:
sftp.remove(path)
return
except IOError:
pass
def _rm_tree(dp: str):
try:
for a in sftp.listdir_attr(dp):
child = _join_remote(dp, a.filename)
if statmod.S_ISDIR(a.st_mode):
_rm_tree(child)
else:
try:
sftp.remove(child)
except Exception:
pass
sftp.rmdir(dp)
except IOError:
pass
_rm_tree(path)
finally:
try:
sftp.close()
except Exception:
pass
await asyncio.to_thread(_do_delete)
async def move(self, root: str, src_rel: str, dst_rel: str):
src = _join_remote(root, src_rel)
dst = _join_remote(root, dst_rel)
def _do_move():
sftp = self._connect()
try:
# ensure dst parent exists
parent = "/" if "/" not in dst.strip("/") else dst.rsplit("/", 1)[0]
parts = [p for p in parent.strip("/").split("/") if p]
cur = "/"
for p in parts:
cur = _join_remote(cur, p)
try:
sftp.mkdir(cur)
except IOError:
pass
sftp.rename(src, dst)
finally:
try:
sftp.close()
except Exception:
pass
await asyncio.to_thread(_do_move)
async def rename(self, root: str, src_rel: str, dst_rel: str):
await self.move(root, src_rel, dst_rel)
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
src = _join_remote(root, src_rel)
dst = _join_remote(root, dst_rel)
def _is_dir() -> bool:
sftp = self._connect()
try:
st = sftp.stat(src)
return statmod.S_ISDIR(st.st_mode)
finally:
try:
sftp.close()
except Exception:
pass
if await asyncio.to_thread(_is_dir):
await self.mkdir(root, dst_rel)
children, _ = await self.list_dir(root, src_rel, page_num=1, page_size=10_000)
for ent in children:
child_src = f"{src_rel.rstrip('/')}/{ent['name']}"
child_dst = f"{dst_rel.rstrip('/')}/{ent['name']}"
await self.copy(root, child_src, child_dst, overwrite)
return
# file copy
data = await self.read_file(root, src_rel)
if not overwrite:
try:
await self.stat_file(root, dst_rel)
raise FileExistsError(dst_rel)
except FileNotFoundError:
pass
await self.write_file(root, dst_rel, data)
async def stat_file(self, root: str, rel: str):
path = _join_remote(root, rel)
def _do_stat():
sftp = self._connect()
try:
st = sftp.stat(path)
is_dir = statmod.S_ISDIR(st.st_mode)
info = {
"name": rel.split("/")[-1],
"is_dir": is_dir,
"size": 0 if is_dir else int(st.st_size or 0),
"mtime": int(st.st_mtime or 0),
"type": "dir" if is_dir else "file",
"path": path,
}
return info
except FileNotFoundError:
raise
except IOError as e:
if getattr(e, "errno", None) == 2:
raise FileNotFoundError(rel)
raise
finally:
try:
sftp.close()
except Exception:
pass
return await asyncio.to_thread(_do_stat)
async def exists(self, root: str, rel: str) -> bool:
try:
await self.stat_file(root, rel)
return True
except FileNotFoundError:
return False
except Exception:
return False
async def stream_file(self, root: str, rel: str, range_header: str | None):
path = _join_remote(root, rel)
def _get_stat():
sftp = self._connect()
try:
st = sftp.stat(path)
return int(st.st_size or 0)
finally:
try:
sftp.close()
except Exception:
pass
file_size = await asyncio.to_thread(_get_stat)
if file_size is None:
raise HTTPException(404, detail="File not found")
mime, _ = mimetypes.guess_type(rel)
content_type = mime or "application/octet-stream"
start = 0
end = file_size - 1
status = 200
headers = {
"Accept-Ranges": "bytes",
"Content-Type": content_type,
"Content-Length": str(file_size),
}
if range_header and range_header.startswith("bytes="):
try:
s, e = (range_header.removeprefix("bytes=").split("-", 1))
if s.strip():
start = int(s)
if e.strip():
end = int(e)
if start >= file_size:
raise HTTPException(416, detail="Requested Range Not Satisfiable")
if end >= file_size:
end = file_size - 1
status = 206
headers["Content-Length"] = str(end - start + 1)
headers["Content-Range"] = f"bytes {start}-{end}/{file_size}"
except ValueError:
raise HTTPException(400, detail="Invalid Range header")
queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(maxsize=8)
def _worker():
sftp = self._connect()
try:
with sftp.open(path, "rb") as f:
f.seek(start)
remaining = end - start + 1
chunk_size = 64 * 1024
while remaining > 0:
to_read = chunk_size if remaining > chunk_size else remaining
data = f.read(to_read)
if not data:
break
try:
queue.put_nowait(data)
except Exception:
break
remaining -= len(data)
try:
queue.put_nowait(None)
except Exception:
pass
finally:
try:
sftp.close()
except Exception:
pass
async def agen():
worker_fut = asyncio.to_thread(_worker)
try:
while True:
chunk = await queue.get()
if chunk is None:
break
yield chunk
finally:
try:
await worker_fut
except Exception:
pass
return StreamingResponse(agen(), status_code=status, headers=headers, media_type=content_type)
ADAPTER_TYPE = "sftp"
CONFIG_SCHEMA = [
{"key": "host", "label": "主机", "type": "string", "required": True, "placeholder": "sftp.example.com"},
{"key": "port", "label": "端口", "type": "number", "required": False, "default": 22},
{"key": "username", "label": "用户名", "type": "string", "required": True},
{"key": "password", "label": "密码", "type": "password", "required": True},
{"key": "root", "label": "根路径", "type": "string", "required": True, "placeholder": "/data"},
{"key": "timeout", "label": "超时(秒)", "type": "number", "required": False, "default": 15},
{"key": "allow_unknown_host", "label": "允许未知主机指纹", "type": "boolean", "required": False, "default": True},
]
def ADAPTER_FACTORY(rec: StorageAdapter):
return SFTPAdapter(rec)