mirror of
https://github.com/DrizzleTime/Foxel.git
synced 2026-05-07 05:52:57 +08:00
492 lines
21 KiB
Python
492 lines
21 KiB
Python
from typing import List, Dict, Optional, Tuple, AsyncIterator
|
||
import httpx
|
||
from urllib.parse import urljoin, quote
|
||
from urllib.parse import urlparse, unquote
|
||
import xml.etree.ElementTree as ET
|
||
from models import StorageAdapter
|
||
import mimetypes
|
||
import logging
|
||
from fastapi import HTTPException
|
||
from fastapi.responses import StreamingResponse, Response
|
||
|
||
# Prefix map passed to ElementTree find()/findall() so "d:..." paths resolve
# to elements in the WebDAV "DAV:" XML namespace.
NS = {"d": "DAV:"}
|
||
|
||
|
||
class WebDAVAdapter:
|
||
def __init__(self, record: StorageAdapter):
    """Build the adapter from a stored adapter record.

    Reads ``base_url``, ``username``, ``password`` and ``timeout`` from
    ``record.config``. The base URL is normalised so that it always ends
    with exactly one trailing slash.

    Args:
        record: Adapter record whose ``config`` mapping holds the settings.

    Raises:
        ValueError: If ``base_url`` is missing or is not an http/https URL.
    """
    self.record = record
    cfg = record.config
    # ``or ""`` tolerates a key that is present but stored as None.
    self.base_url: str = (cfg.get("base_url") or "").rstrip('/') + '/'
    # Check the full scheme prefix: a bare "http" prefix would also accept
    # unrelated schemes such as "httpx://". URL schemes are case-insensitive.
    if not self.base_url.lower().startswith(("http://", "https://")):
        raise ValueError("webdav requires base_url http/https")
    self.username = cfg.get("username")
    self.password = cfg.get("password")
    self.timeout = cfg.get("timeout", 15)
|
||
|
||
def get_effective_root(self, sub_path: str | None) -> str:
|
||
base_url = self.record.config.get("base_url", "").rstrip('/') + '/'
|
||
if sub_path:
|
||
return base_url + sub_path.strip('/') + '/'
|
||
return base_url
|
||
|
||
def _client(self):
    """Create a new ``httpx.AsyncClient`` configured for this adapter.

    Basic-auth credentials are attached only when a username is set; the
    client uses the configured timeout and follows redirects.
    """
    credentials = None
    if self.username:
        credentials = (self.username, self.password)
    return httpx.AsyncClient(
        auth=credentials,
        timeout=self.timeout,
        follow_redirects=True,
    )
|
||
|
||
def _build_url(self, rel: str):
|
||
rel = rel.strip('/')
|
||
return self.base_url if not rel else urljoin(self.base_url, quote(rel) + ('/' if rel.endswith('/') else ''))
|
||
|
||
async def list_dir(self, root: str, rel: str, page_num: int = 1, page_size: int = 50, sort_by: str = "name", sort_order: str = "asc") -> Tuple[List[Dict], int]:
    """List the direct children of a remote directory.

    Sends a ``PROPFIND`` with ``Depth: 1`` and converts every child entry
    of the multistatus reply into a dict with the keys ``name``,
    ``is_dir``, ``size``, ``mtime`` and ``type``.

    Args:
        root: Not used by this method.
        rel: Directory path relative to the adapter base URL.
        page_num: 1-based page index; values below 1 are treated as 1.
        page_size: Number of entries per page.
        sort_by: ``"name"``, ``"size"`` or ``"mtime"``; anything else
            falls back to name.
        sort_order: ``"desc"`` reverses the ordering.

    Returns:
        ``(page_entries, total_count)`` where ``total_count`` is the number
        of entries before pagination.
    """
    from email.utils import parsedate_to_datetime

    raw_url = self._build_url(rel)
    url = raw_url if raw_url.endswith('/') else raw_url + '/'
    body = """<?xml version="1.0" encoding="utf-8" ?>
<d:propfind xmlns:d="DAV:">
  <d:prop>
    <d:displayname />
    <d:getcontentlength />
    <d:getlastmodified />
    <d:resourcetype />
  </d:prop>
</d:propfind>"""
    async with self._client() as client:
        # ``content=`` is the httpx parameter for a raw text/bytes body;
        # passing a string through ``data=`` is deprecated.
        resp = await client.request("PROPFIND", url, content=body, headers={"Depth": "1"})
        resp.raise_for_status()
        xml_text = resp.text
    root_el = ET.fromstring(xml_text)

    # Compare paths in decoded form: the server may percent-encode hrefs
    # differently from quote() (e.g. "(" vs "%28"), which would otherwise
    # make every entry fail the prefix test and yield an empty listing.
    base_path = unquote(urlparse(url).path)
    if not base_path.endswith('/'):
        base_path += '/'

    all_entries: List[Dict] = []
    seen = set()
    for resp_el in root_el.findall("d:response", NS):
        href_el = resp_el.find("d:href", NS)
        if href_el is None:
            continue
        href_path = unquote(urlparse((href_el.text or "").strip()).path or "")
        if not href_path.startswith(base_path):
            continue
        rel_path = href_path[len(base_path):].strip('/')
        if not rel_path:
            # The directory itself is part of a Depth: 1 reply; skip it.
            continue
        name = rel_path.split('/')[0]
        if not name or name in seen:
            continue
        propstat = resp_el.find("d:propstat", NS)
        if propstat is None:
            continue
        prop = propstat.find("d:prop", NS)
        if prop is None:
            continue
        # Mark as seen only once the entry is actually usable.
        seen.add(name)

        size_el = prop.find("d:getcontentlength", NS)
        lm_el = prop.find("d:getlastmodified", NS)
        rt_el = prop.find("d:resourcetype", NS)
        if rt_el is not None:
            is_dir = rt_el.find("d:collection", NS) is not None
        else:
            # No resourcetype reported: fall back to the href shape.
            is_dir = href_path.endswith('/')

        size_text = (size_el.text or "").strip() if size_el is not None else ""
        size = int(size_text) if size_text.isdigit() else 0

        mtime = 0
        if lm_el is not None and lm_el.text:
            try:
                mtime = int(parsedate_to_datetime(lm_el.text).timestamp())
            except Exception:
                mtime = 0

        all_entries.append({
            "name": name,
            "is_dir": is_dir,
            "size": 0 if is_dir else size,
            "mtime": mtime,
            "type": "dir" if is_dir else "file",
        })

    # Sort all entries, directories grouped ahead of files.
    # NOTE(review): reverse=True also flips that grouping, so directories
    # come last for "desc" - kept as in the original; confirm intent.
    sort_field = sort_by.lower()
    if sort_field not in ("name", "size", "mtime"):
        sort_field = "name"

    def get_sort_key(item):
        value = item["name"].lower() if sort_field == "name" else item[sort_field]
        return (not item["is_dir"], value)

    all_entries.sort(key=get_sort_key, reverse=sort_order.lower() == "desc")

    total_count = len(all_entries)

    # Apply pagination; clamp the page so a value < 1 cannot produce a
    # negative slice start.
    start_idx = (max(page_num, 1) - 1) * page_size
    end_idx = start_idx + page_size
    page_entries = all_entries[start_idx:end_idx]

    return page_entries, total_count
|
||
|
||
async def read_file(self, root: str, rel: str) -> bytes:
    """Download a whole file and return its bytes.

    Raises:
        FileNotFoundError: If the server answers 404.
    """
    target = self._build_url(rel)
    async with self._client() as http:
        response = await http.get(target)
        missing = response.status_code == 404
        if missing:
            raise FileNotFoundError(rel)
        response.raise_for_status()
        payload = response.content
        return payload
|
||
|
||
async def write_file(self, root: str, rel: str, data: bytes):
    """Upload *data* to the given relative path with a single ``PUT``.

    Any error status reported by the server is raised via
    ``raise_for_status``.
    """
    target = self._build_url(rel)
    async with self._client() as http:
        response = await http.put(target, content=data)
        response.raise_for_status()
|
||
|
||
async def mkdir(self, root: str, rel: str):
    """Create a remote collection (directory) with ``MKCOL``.

    A 201 (created) or 405 (reported by the server, typically when the
    collection already exists) is accepted silently; other error statuses
    are raised via ``raise_for_status``.
    """
    url = self._build_url(rel)
    # Guarantee the collection URL ends with a slash here rather than
    # relying on the path handed to _build_url, which strips slashes.
    if not url.endswith('/'):
        url += '/'
    async with self._client() as client:
        resp = await client.request("MKCOL", url)
        if resp.status_code not in (201, 405):
            resp.raise_for_status()
|
||
|
||
async def delete(self, root: str, rel: str):
    """Delete a remote resource.

    Statuses 204, 200 and 404 are all treated as success, so deleting a
    path that is already absent does not raise.
    """
    tolerated = (204, 200, 404)
    target = self._build_url(rel)
    async with self._client() as http:
        response = await http.delete(target)
        if response.status_code in tolerated:
            return
        response.raise_for_status()
|
||
|
||
async def move(self, root: str, src_rel: str, dst_rel: str):
    """Move a resource using the WebDAV ``MOVE`` method.

    The destination is sent as an absolute URL in the ``Destination``
    header; error statuses are raised via ``raise_for_status``.
    """
    source_url = self._build_url(src_rel)
    move_headers = {"Destination": self._build_url(dst_rel)}
    async with self._client() as http:
        response = await http.request("MOVE", source_url, headers=move_headers)
        response.raise_for_status()
|
||
|
||
async def rename(self, root: str, src_rel: str, dst_rel: str):
    """Rename a resource.

    On WebDAV a rename is the same ``MOVE`` request as a move, so this
    simply delegates to :meth:`move`.
    """
    await self.move(root, src_rel, dst_rel)
|
||
|
||
async def get_file_size(self, root: str, rel: str) -> int:
    """Return the size of a remote file in bytes.

    Tries a ``HEAD`` request first and falls back to a ``PROPFIND`` for
    ``getcontentlength`` when the ``Content-Length`` header is missing or
    not a plain number.

    Returns:
        The size in bytes, or 0 when neither request reports one.

    Raises:
        FileNotFoundError: If the ``HEAD`` request answers 404.
    """
    url = self._build_url(rel)
    async with self._client() as client:
        # Use a HEAD request to fetch the file information.
        resp = await client.head(url)
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()

        content_length = (resp.headers.get('content-length') or "").strip()
        # Validate before converting so a malformed header falls through
        # to PROPFIND instead of raising ValueError.
        if content_length.isdigit():
            return int(content_length)

        # HEAD did not return a usable content-length; try PROPFIND.
        body = """<?xml version="1.0" encoding="utf-8" ?>
<d:propfind xmlns:d="DAV:">
  <d:prop>
    <d:getcontentlength />
  </d:prop>
</d:propfind>"""
        # ``content=`` is the httpx parameter for a raw body; a string
        # passed through ``data=`` is deprecated.
        resp = await client.request("PROPFIND", url, content=body, headers={"Depth": "0"})
        resp.raise_for_status()

        root_el = ET.fromstring(resp.text)
        for resp_el in root_el.findall("d:response", NS):
            propstat = resp_el.find("d:propstat", NS)
            if propstat is None:
                continue
            prop = propstat.find("d:prop", NS)
            if prop is None:
                continue
            size_el = prop.find("d:getcontentlength", NS)
            if size_el is not None and size_el.text and size_el.text.isdigit():
                return int(size_el.text)

    return 0
|
||
|
||
async def read_file_range(self, root: str, rel: str, start: int, end: Optional[int] = None) -> bytes:
    """Read a byte range of a remote file.

    Args:
        root: Not used by this method.
        rel: File path relative to the adapter base URL.
        start: First byte offset (inclusive).
        end: Last byte offset (inclusive); ``None`` reads to end of file.

    Returns:
        The requested bytes.

    Raises:
        FileNotFoundError: If the server answers 404.
    """
    url = self._build_url(rel)

    # Build the Range header.
    if end is None:
        range_header = f"bytes={start}-"
    else:
        range_header = f"bytes={start}-{end}"

    async with self._client() as client:
        resp = await client.get(url, headers={"Range": range_header})
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        if resp.status_code not in (200, 206):  # 206 is Partial Content
            resp.raise_for_status()
        if resp.status_code == 200 and (start or end is not None):
            # The upstream ignored the Range header and sent the whole
            # file; cut out the requested window ourselves so the caller
            # never receives bytes outside the range.
            stop = None if end is None else end + 1
            return resp.content[start:stop]
        return resp.content
|
||
|
||
async def stream_file(self, root: str, rel: str, range_header: str | None):
    """Stream a remote file, honouring a single HTTP byte range.

    The upstream is first probed (``HEAD``, then a ``bytes=0-0`` ``GET``)
    for the total size and range support. Without a client range and
    without upstream range support the file is passed through with one
    ``GET``; otherwise it is fetched in fixed-size segments, each of which
    is retried a limited number of times.

    Args:
        root: Not used by this method.
        rel: File path relative to the adapter base URL.
        range_header: Value of the client's ``Range`` header, or ``None``.
            Only a single ``bytes=`` range (including the suffix form
            ``bytes=-N``) is supported.

    Returns:
        A ``StreamingResponse`` with status 200 or 206.

    Raises:
        HTTPException: 404 when the upstream reports the file missing,
            416 when the range is malformed or cannot be satisfied.
    """
    url = self._build_url(rel)
    mime, _ = mimetypes.guess_type(rel)
    content_type = mime or "application/octet-stream"
    logger = logging.getLogger(__name__)
    timeout = self.timeout
    auth = (self.username, self.password) if self.username else None

    client_start = 0
    client_end = None
    suffix_len = None  # set for "bytes=-N" (the last N bytes of the file)
    status_code = 200
    if range_header and range_header.startswith("bytes="):
        status_code = 206
        first, sep, last = range_header.removeprefix("bytes=").partition("-")
        first, last = first.strip(), last.strip()
        try:
            if not sep or not (first or last):
                raise ValueError(range_header)
            if first:
                client_start = int(first)
                if last:
                    client_end = int(last)
            else:
                suffix_len = int(last)
                if suffix_len <= 0:
                    raise ValueError(range_header)
        except ValueError:
            # Malformed or multi-range header: answer 416 rather than
            # letting the ValueError surface as an internal server error.
            raise HTTPException(416, detail="Requested Range Not Satisfiable")

    total_size = None
    accept_ranges = False
    async with httpx.AsyncClient(timeout=timeout, auth=auth, follow_redirects=True) as client:
        try:
            head_resp = await client.head(url)
            if head_resp.status_code == 404:
                raise HTTPException(404, detail="File not found")
            if head_resp.status_code == 200:
                cl = head_resp.headers.get("Content-Length")
                if cl and cl.isdigit():
                    total_size = int(cl)
                ar = head_resp.headers.get("Accept-Ranges", "").lower()
                accept_ranges = "bytes" in ar
        except HTTPException:
            raise
        except Exception as e:
            logger.debug("HEAD failed %s err=%s", url, e)
        if total_size is None and (client_end is None):
            # HEAD gave no size: ask for one byte and read the total from
            # the Content-Range header of the reply.
            try:
                probe_req = client.build_request("GET", url, headers={"Range": "bytes=0-0"})
                probe_resp = await client.send(probe_req, stream=True)
                try:
                    if probe_resp.status_code in (200, 206):
                        cr = probe_resp.headers.get("Content-Range")
                        if cr and "/" in cr:
                            try:
                                total_size = int(cr.rsplit("/", 1)[1])
                            except ValueError:
                                # e.g. "bytes 0-0/*": total length unknown
                                pass
                finally:
                    await probe_resp.aclose()
            except Exception as e:
                logger.debug("Probe 0-0 failed %s err=%s", url, e)

    if suffix_len is not None:
        # A suffix range can only be resolved once the size is known.
        if total_size is None:
            raise HTTPException(416, detail="Requested Range Not Satisfiable")
        client_start = max(total_size - suffix_len, 0)
    if total_size is not None and (client_end is None or client_end >= total_size):
        # Open-ended or over-long range: clamp to the last existing byte.
        client_end = total_size - 1
    # Only an explicit range request can be unsatisfiable; a plain request
    # for an empty file (client_end == -1) must still succeed.
    if status_code == 206 and client_end is not None and client_end < client_start:
        raise HTTPException(416, detail="Requested Range Not Satisfiable")

    # Client requested no range and the upstream does not support Range:
    # pass the body straight through.
    if status_code == 200 and (range_header is None) and not accept_ranges:
        # The client must outlive this function because the body is read
        # lazily by the response; it is closed by the generator instead of
        # an ``async with`` block that would close it on return.
        client = httpx.AsyncClient(timeout=timeout, auth=auth, follow_redirects=True)
        try:
            req = client.build_request("GET", url)
            resp = await client.send(req, stream=True)
        except BaseException:
            await client.aclose()
            raise
        if resp.status_code == 404:
            await resp.aclose()
            await client.aclose()
            raise HTTPException(404, detail="File not found")
        upstream_ct = resp.headers.get("Content-Type", content_type)

        async def passthrough():
            try:
                async for chunk in resp.aiter_bytes():
                    if chunk:
                        yield chunk
            finally:
                await resp.aclose()
                await client.aclose()

        return StreamingResponse(passthrough(), status_code=resp.status_code,
                                 headers={"Accept-Ranges": "bytes",
                                          "X-VFS-Remote-Status": str(resp.status_code)},
                                 media_type=upstream_ct)

    SEGMENT_SIZE = 5 * 1024 * 1024
    MAX_RETRY_PER_SEG = 3
    FIRST_BYTE_MAX_RETRY = 3

    resp_headers = {
        "Accept-Ranges": "bytes",
        "Content-Type": content_type,
        "X-VFS-Segmented": "1",
    }
    if status_code == 206 and total_size is not None:
        resp_headers["Content-Range"] = f"bytes {client_start}-{client_end}/{total_size}"

    async def segmented_body():
        # ``current`` is the absolute offset of the next byte owed to the
        # caller. It advances with every chunk yielded, so a retry resumes
        # exactly where the stream broke and never re-sends bytes.
        current = client_start
        first_byte_sent = False
        while client_end is None or current <= client_end:
            seg_end = current + SEGMENT_SIZE - 1
            if client_end is not None:
                seg_end = min(seg_end, client_end)
            attempt = 0
            ok = False
            while attempt < MAX_RETRY_PER_SEG and not ok:
                attempt += 1
                headers_req = {"Range": f"bytes={current}-{seg_end}"}
                try:
                    async with httpx.AsyncClient(timeout=timeout, auth=auth, follow_redirects=True) as cseg:
                        req = cseg.build_request("GET", url, headers=headers_req)
                        rseg = await cseg.send(req, stream=True)
                        try:
                            if rseg.status_code == 206:
                                async for chunk in rseg.aiter_bytes():
                                    if chunk:
                                        first_byte_sent = True
                                        current += len(chunk)
                                        yield chunk
                                ok = True
                            elif rseg.status_code == 200:
                                # Upstream ignored Range and is sending the
                                # whole file from byte 0: drop what precedes
                                # ``current``, stop at ``client_end``, and
                                # finish - this one response covers the rest.
                                upstream_pos = 0
                                async for chunk in rseg.aiter_bytes():
                                    chunk_start = upstream_pos
                                    upstream_pos += len(chunk)
                                    if upstream_pos <= current:
                                        continue
                                    piece = chunk[max(current - chunk_start, 0):]
                                    if client_end is not None:
                                        piece = piece[:client_end - current + 1]
                                    if piece:
                                        first_byte_sent = True
                                        current += len(piece)
                                        yield piece
                                    if client_end is not None and current > client_end:
                                        break
                                return
                            elif rseg.status_code == 404:
                                if not first_byte_sent:
                                    raise HTTPException(404, detail="File not found")
                                return
                            elif rseg.status_code == 416 and client_end is None:
                                # Size unknown and we stepped past the end
                                # of the file: normal end of stream.
                                return
                            else:
                                logger.warning("Segment unexpected status %s %s-%s %s", rel, current, seg_end, rseg.status_code)
                        finally:
                            await rseg.aclose()
                except (httpx.HTTPError, httpx.StreamError) as e:
                    if not first_byte_sent and attempt >= FIRST_BYTE_MAX_RETRY:
                        raise HTTPException(502, detail=f"Upstream error before first byte err={e}")
                    logger.warning("Segment error %s %s-%s attempt=%d err=%s", rel, current, seg_end, attempt, e)
                except Exception as e:
                    if not first_byte_sent:
                        raise
                    logger.error("Segment unexpected %s %s-%s attempt=%d err=%s", rel, current, seg_end, attempt, e)
            if not ok:
                logger.error("Abort streaming %s at %s-%s", rel, current, seg_end)
                break
            if current <= seg_end:
                # The upstream delivered fewer bytes than requested without
                # an error: the file ended here.
                break

    return StreamingResponse(segmented_body(), status_code=status_code, headers=resp_headers, media_type=content_type)
|
||
|
||
async def stat_file(self, root: str, rel: str, include_metadata: bool = False):
    """Return metadata for a single remote file or directory.

    Sends a ``PROPFIND`` with ``Depth: 0`` and builds a dict with the keys
    ``name``, ``is_dir``, ``size``, ``mtime``, ``type`` and ``path``.

    Args:
        root: Not used by this method.
        rel: Path relative to the adapter base URL.
        include_metadata: When true and the target is a file, an ``exif``
            key is added. For names guessed to be images the file is
            downloaded and its EXIF tags are read with Pillow; the value
            is ``None`` when nothing could be read.

    Raises:
        FileNotFoundError: If the server answers 404.
    """
    from email.utils import parsedate_to_datetime

    url = self._build_url(rel)
    async with self._client() as client:
        # Fetch the properties via PROPFIND.
        body = """<?xml version="1.0" encoding="utf-8" ?>
<d:propfind xmlns:d="DAV:">
  <d:prop>
    <d:getcontentlength />
    <d:getlastmodified />
    <d:resourcetype />
  </d:prop>
</d:propfind>"""
        # ``content=`` is the httpx parameter for a raw body; a string
        # passed through ``data=`` is deprecated.
        resp = await client.request("PROPFIND", url, content=body, headers={"Depth": "0"})
        if resp.status_code == 404:
            raise FileNotFoundError(rel)
        resp.raise_for_status()
        root_el = ET.fromstring(resp.text)
        info = {
            # Strip trailing slashes first so "dir/" is named "dir", not "".
            "name": rel.rstrip("/").split("/")[-1],
            "is_dir": False,
            "size": None,
            "mtime": None,
            "type": "file",
            "path": url,
        }
        for resp_el in root_el.findall("d:response", NS):
            propstat = resp_el.find("d:propstat", NS)
            if propstat is None:
                continue
            prop = propstat.find("d:prop", NS)
            if prop is None:
                continue
            size_el = prop.find("d:getcontentlength", NS)
            lm_el = prop.find("d:getlastmodified", NS)
            rt_el = prop.find("d:resourcetype", NS)
            is_dir = rt_el.find("d:collection", NS) is not None if rt_el is not None else False
            info["is_dir"] = is_dir
            info["type"] = "dir" if is_dir else "file"
            if size_el is not None and size_el.text and size_el.text.isdigit():
                info["size"] = int(size_el.text)
            elif info["size"] is None:
                info["size"] = 0
            if lm_el is not None and lm_el.text:
                try:
                    info["mtime"] = int(parsedate_to_datetime(lm_el.text).timestamp())
                except Exception:
                    info["mtime"] = 0
            elif info["mtime"] is None:
                info["mtime"] = 0
        if include_metadata and not info["is_dir"]:
            exif = None
            mime, _ = mimetypes.guess_type(info["name"])
            if mime and mime.startswith("image/"):
                # Best effort: any failure (download, decode, image format
                # without EXIF support) leaves exif as None.
                try:
                    resp_img = await client.get(url)
                    if resp_img.status_code == 200:
                        from PIL import Image
                        from io import BytesIO
                        img = Image.open(BytesIO(resp_img.content))
                        exif_data = img._getexif()
                        if exif_data:
                            exif = {str(k): str(v) for k, v in exif_data.items()}
                except Exception:
                    exif = None
            info["exif"] = exif
        return info
|
||
|
||
async def exists(self, root: str, rel: str) -> bool:
    """Report whether a remote path answers a ``HEAD`` with 200 or 204.

    Any exception raised while sending the request is treated as
    "does not exist" and yields ``False``.
    """
    target = self._build_url(rel)
    async with self._client() as http:
        try:
            head = await http.head(target)
        except Exception:
            return False
        return head.status_code in (200, 204)
|
||
|
||
async def write_file_stream(self, root: str, rel: str, data_iter: AsyncIterator[bytes]):
    """Upload a file from an async stream of byte chunks with ``PUT``.

    Empty chunks from *data_iter* are skipped. Returns ``True`` once the
    server has accepted the upload; error statuses are raised via
    ``raise_for_status``.
    """
    target = self._build_url(rel)

    async def non_empty_chunks():
        async for piece in data_iter:
            if not piece:
                continue
            yield piece

    async with self._client() as http:
        response = await http.put(target, content=non_empty_chunks())
        response.raise_for_status()
        return True
|
||
|
||
async def copy(self, root: str, src_rel: str, dst_rel: str, overwrite: bool = False):
    """Copy a resource using the WebDAV ``COPY`` method.

    Args:
        root: Not used by this method.
        src_rel: Source path relative to the adapter base URL.
        dst_rel: Destination path relative to the adapter base URL.
        overwrite: Sent as the ``Overwrite`` header (``T`` or ``F``).

    Raises:
        FileExistsError: If the server answers 412.
        FileNotFoundError: If the server answers 404.
    """
    source_url = self._build_url(src_rel)
    target_url = self._build_url(dst_rel)
    overwrite_flag = "T" if overwrite else "F"
    copy_headers = {"Destination": target_url, "Overwrite": overwrite_flag}
    async with self._client() as http:
        response = await http.request("COPY", source_url, headers=copy_headers)
        status = response.status_code
        if status == 412:
            raise FileExistsError(dst_rel)
        if status == 404:
            raise FileNotFoundError(src_rel)
        response.raise_for_status()
|
||
|
||
# Identifier of this adapter type.
ADAPTER_TYPE = "webdav"

# Description of the configuration fields read by the adapter: key, display
# label, field type, whether it is required, and optional placeholder/default.
CONFIG_SCHEMA = [
    {
        "key": "base_url",
        "label": "基础地址",
        "type": "string",
        "required": True,
        "placeholder": "https://example.com/dav/",
    },
    {
        "key": "username",
        "label": "用户名",
        "type": "string",
        "required": False,
    },
    {
        "key": "password",
        "label": "密码",
        "type": "password",
        "required": False,
    },
    {
        "key": "timeout",
        "label": "超时(秒)",
        "type": "number",
        "required": False,
        "default": 15,
    },
]
|
||
def ADAPTER_FACTORY(rec):
    """Create a ``WebDAVAdapter`` for the given adapter record."""
    adapter = WebDAVAdapter(rec)
    return adapter
|