mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-13 17:29:43 +08:00
2
This commit is contained in:
32
src/core/__init__.py
Normal file
32
src/core/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
核心功能模块
|
||||
"""
|
||||
|
||||
from .oauth import OAuthManager, OAuthStart, generate_oauth_url, submit_callback_url
|
||||
from .http_client import (
|
||||
OpenAIHTTPClient,
|
||||
HTTPClient,
|
||||
HTTPClientError,
|
||||
RequestConfig,
|
||||
create_http_client,
|
||||
create_openai_client,
|
||||
)
|
||||
from .register import RegistrationEngine, RegistrationResult
|
||||
from .utils import setup_logging, get_data_dir
|
||||
|
||||
__all__ = [
|
||||
'OAuthManager',
|
||||
'OAuthStart',
|
||||
'generate_oauth_url',
|
||||
'submit_callback_url',
|
||||
'OpenAIHTTPClient',
|
||||
'HTTPClient',
|
||||
'HTTPClientError',
|
||||
'RequestConfig',
|
||||
'create_http_client',
|
||||
'create_openai_client',
|
||||
'RegistrationEngine',
|
||||
'RegistrationResult',
|
||||
'setup_logging',
|
||||
'get_data_dir',
|
||||
]
|
||||
420
src/core/http_client.py
Normal file
420
src/core/http_client.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
HTTP 客户端封装
|
||||
基于 curl_cffi 的 HTTP 请求封装,支持代理和错误处理
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Optional, Dict, Any, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
from curl_cffi.requests import Session, Response
|
||||
|
||||
from ..config.constants import ERROR_MESSAGES
|
||||
from ..config.settings import get_settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestConfig:
|
||||
"""HTTP 请求配置"""
|
||||
timeout: int = 30
|
||||
max_retries: int = 3
|
||||
retry_delay: float = 1.0
|
||||
impersonate: str = "chrome"
|
||||
verify_ssl: bool = True
|
||||
follow_redirects: bool = True
|
||||
|
||||
|
||||
class HTTPClientError(Exception):
|
||||
"""HTTP 客户端异常"""
|
||||
pass
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
"""
|
||||
HTTP 客户端封装
|
||||
支持代理、重试、错误处理和会话管理
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None,
|
||||
session: Optional[Session] = None
|
||||
):
|
||||
"""
|
||||
初始化 HTTP 客户端
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL,如 "http://127.0.0.1:7890"
|
||||
config: 请求配置
|
||||
session: 可重用的会话对象
|
||||
"""
|
||||
self.proxy_url = proxy_url
|
||||
self.config = config or RequestConfig()
|
||||
self._session = session
|
||||
|
||||
@property
|
||||
def proxies(self) -> Optional[Dict[str, str]]:
|
||||
"""获取代理配置"""
|
||||
if not self.proxy_url:
|
||||
return None
|
||||
return {
|
||||
"http": self.proxy_url,
|
||||
"https": self.proxy_url,
|
||||
}
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
"""获取会话对象(单例)"""
|
||||
if self._session is None:
|
||||
self._session = Session(
|
||||
proxies=self.proxies,
|
||||
impersonate=self.config.impersonate,
|
||||
verify=self.config.verify_ssl,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
return self._session
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
**kwargs
|
||||
) -> Response:
|
||||
"""
|
||||
发送 HTTP 请求
|
||||
|
||||
Args:
|
||||
method: HTTP 方法 (GET, POST, PUT, DELETE, etc.)
|
||||
url: 请求 URL
|
||||
**kwargs: 其他请求参数
|
||||
|
||||
Returns:
|
||||
Response 对象
|
||||
|
||||
Raises:
|
||||
HTTPClientError: 请求失败
|
||||
"""
|
||||
# 设置默认参数
|
||||
kwargs.setdefault("timeout", self.config.timeout)
|
||||
kwargs.setdefault("allow_redirects", self.config.follow_redirects)
|
||||
|
||||
# 添加代理配置
|
||||
if self.proxies and "proxies" not in kwargs:
|
||||
kwargs["proxies"] = self.proxies
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
|
||||
# 检查响应状态码
|
||||
if response.status_code >= 400:
|
||||
logger.warning(
|
||||
f"HTTP {response.status_code} for {method} {url}"
|
||||
f" (attempt {attempt + 1}/{self.config.max_retries})"
|
||||
)
|
||||
|
||||
# 如果是服务器错误,重试
|
||||
if response.status_code >= 500 and attempt < self.config.max_retries - 1:
|
||||
time.sleep(self.config.retry_delay * (attempt + 1))
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
except (cffi_requests.RequestsError, ConnectionError, TimeoutError) as e:
|
||||
last_exception = e
|
||||
logger.warning(
|
||||
f"请求失败: {method} {url} (attempt {attempt + 1}/{self.config.max_retries}): {e}"
|
||||
)
|
||||
|
||||
if attempt < self.config.max_retries - 1:
|
||||
time.sleep(self.config.retry_delay * (attempt + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
raise HTTPClientError(
|
||||
f"请求失败,最大重试次数已达: {method} {url} - {last_exception}"
|
||||
)
|
||||
|
||||
def get(self, url: str, **kwargs) -> Response:
|
||||
"""发送 GET 请求"""
|
||||
return self.request("GET", url, **kwargs)
|
||||
|
||||
def post(self, url: str, data: Any = None, json: Any = None, **kwargs) -> Response:
|
||||
"""发送 POST 请求"""
|
||||
return self.request("POST", url, data=data, json=json, **kwargs)
|
||||
|
||||
def put(self, url: str, data: Any = None, json: Any = None, **kwargs) -> Response:
|
||||
"""发送 PUT 请求"""
|
||||
return self.request("PUT", url, data=data, json=json, **kwargs)
|
||||
|
||||
def delete(self, url: str, **kwargs) -> Response:
|
||||
"""发送 DELETE 请求"""
|
||||
return self.request("DELETE", url, **kwargs)
|
||||
|
||||
def head(self, url: str, **kwargs) -> Response:
|
||||
"""发送 HEAD 请求"""
|
||||
return self.request("HEAD", url, **kwargs)
|
||||
|
||||
def options(self, url: str, **kwargs) -> Response:
|
||||
"""发送 OPTIONS 请求"""
|
||||
return self.request("OPTIONS", url, **kwargs)
|
||||
|
||||
def patch(self, url: str, data: Any = None, json: Any = None, **kwargs) -> Response:
|
||||
"""发送 PATCH 请求"""
|
||||
return self.request("PATCH", url, data=data, json=json, **kwargs)
|
||||
|
||||
def download_file(self, url: str, filepath: str, chunk_size: int = 8192) -> None:
|
||||
"""
|
||||
下载文件
|
||||
|
||||
Args:
|
||||
url: 文件 URL
|
||||
filepath: 保存路径
|
||||
chunk_size: 块大小
|
||||
|
||||
Raises:
|
||||
HTTPClientError: 下载失败
|
||||
"""
|
||||
try:
|
||||
response = self.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPClientError(f"下载文件失败: {url} - {e}")
|
||||
|
||||
def check_proxy(self, test_url: str = "https://httpbin.org/ip") -> bool:
|
||||
"""
|
||||
检查代理是否可用
|
||||
|
||||
Args:
|
||||
test_url: 测试 URL
|
||||
|
||||
Returns:
|
||||
bool: 代理是否可用
|
||||
"""
|
||||
if not self.proxy_url:
|
||||
return False
|
||||
|
||||
try:
|
||||
response = self.get(test_url, timeout=10)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""关闭会话"""
|
||||
if self._session:
|
||||
self._session.close()
|
||||
self._session = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class OpenAIHTTPClient(HTTPClient):
|
||||
"""
|
||||
OpenAI 专用 HTTP 客户端
|
||||
包含 OpenAI API 特定的请求方法
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None
|
||||
):
|
||||
"""
|
||||
初始化 OpenAI HTTP 客户端
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
config: 请求配置
|
||||
"""
|
||||
super().__init__(proxy_url, config)
|
||||
|
||||
# OpenAI 特定的默认配置
|
||||
if config is None:
|
||||
self.config.timeout = 30
|
||||
self.config.max_retries = 3
|
||||
|
||||
# 默认请求头
|
||||
self.default_headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
"Accept": "application/json",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"Connection": "keep-alive",
|
||||
"Sec-Fetch-Dest": "empty",
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
"Sec-Fetch-Site": "same-site",
|
||||
}
|
||||
|
||||
def check_ip_location(self) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检查 IP 地理位置
|
||||
|
||||
Returns:
|
||||
Tuple[是否支持, 位置信息]
|
||||
"""
|
||||
try:
|
||||
response = self.get("https://cloudflare.com/cdn-cgi/trace", timeout=10)
|
||||
trace_text = response.text
|
||||
|
||||
# 解析位置信息
|
||||
import re
|
||||
loc_match = re.search(r"loc=([A-Z]+)", trace_text)
|
||||
loc = loc_match.group(1) if loc_match else None
|
||||
|
||||
# 检查是否支持
|
||||
if loc in ["CN", "HK", "MO", "TW"]:
|
||||
return False, loc
|
||||
return True, loc
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查 IP 地理位置失败: {e}")
|
||||
return False, None
|
||||
|
||||
def send_openai_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 OpenAI API 请求
|
||||
|
||||
Args:
|
||||
endpoint: API 端点
|
||||
method: HTTP 方法
|
||||
data: 表单数据
|
||||
json_data: JSON 数据
|
||||
headers: 请求头
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
响应 JSON 数据
|
||||
|
||||
Raises:
|
||||
HTTPClientError: 请求失败
|
||||
"""
|
||||
# 合并请求头
|
||||
request_headers = self.default_headers.copy()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# 设置 Content-Type
|
||||
if json_data is not None and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
elif data is not None and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
||||
try:
|
||||
response = self.request(
|
||||
method,
|
||||
endpoint,
|
||||
data=data,
|
||||
json=json_data,
|
||||
headers=request_headers,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 检查响应状态码
|
||||
response.raise_for_status()
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_response": response.text}
|
||||
|
||||
except cffi_requests.RequestsError as e:
|
||||
raise HTTPClientError(f"OpenAI 请求失败: {endpoint} - {e}")
|
||||
|
||||
def check_sentinel(self, did: str, proxies: Optional[Dict] = None) -> Optional[str]:
|
||||
"""
|
||||
检查 Sentinel 拦截
|
||||
|
||||
Args:
|
||||
did: Device ID
|
||||
proxies: 代理配置
|
||||
|
||||
Returns:
|
||||
Sentinel token 或 None
|
||||
"""
|
||||
from ..config.constants import OPENAI_API_ENDPOINTS
|
||||
|
||||
try:
|
||||
sen_req_body = f'{{"p":"","id":"{did}","flow":"authorize_continue"}}'
|
||||
|
||||
response = self.post(
|
||||
OPENAI_API_ENDPOINTS["sentinel"],
|
||||
headers={
|
||||
"origin": "https://sentinel.openai.com",
|
||||
"referer": "https://sentinel.openai.com/backend-api/sentinel/frame.html?sv=20260219f9f6",
|
||||
"content-type": "text/plain;charset=UTF-8",
|
||||
},
|
||||
data=sen_req_body,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json().get("token")
|
||||
else:
|
||||
logger.warning(f"Sentinel 检查失败: {response.status_code}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Sentinel 检查异常: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_http_client(
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None
|
||||
) -> HTTPClient:
|
||||
"""
|
||||
创建 HTTP 客户端工厂函数
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
config: 请求配置
|
||||
|
||||
Returns:
|
||||
HTTPClient 实例
|
||||
"""
|
||||
return HTTPClient(proxy_url, config)
|
||||
|
||||
|
||||
def create_openai_client(
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None
|
||||
) -> OpenAIHTTPClient:
|
||||
"""
|
||||
创建 OpenAI HTTP 客户端工厂函数
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
config: 请求配置
|
||||
|
||||
Returns:
|
||||
OpenAIHTTPClient 实例
|
||||
"""
|
||||
return OpenAIHTTPClient(proxy_url, config)
|
||||
336
src/core/oauth.py
Normal file
336
src/core/oauth.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
OpenAI OAuth 授权模块
|
||||
从 main.py 中提取的 OAuth 相关函数
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..config.constants import (
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_AUTH_URL,
|
||||
OAUTH_TOKEN_URL,
|
||||
OAUTH_REDIRECT_URI,
|
||||
OAUTH_SCOPE,
|
||||
)
|
||||
|
||||
|
||||
def _b64url_no_pad(raw: bytes) -> str:
|
||||
"""Base64 URL 编码(无填充)"""
|
||||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def _sha256_b64url_no_pad(s: str) -> str:
|
||||
"""SHA256 哈希后 Base64 URL 编码"""
|
||||
return _b64url_no_pad(hashlib.sha256(s.encode("ascii")).digest())
|
||||
|
||||
|
||||
def _random_state(nbytes: int = 16) -> str:
|
||||
"""生成随机 state"""
|
||||
return secrets.token_urlsafe(nbytes)
|
||||
|
||||
|
||||
def _pkce_verifier() -> str:
|
||||
"""生成 PKCE code_verifier"""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
|
||||
def _parse_callback_url(callback_url: str) -> Dict[str, str]:
|
||||
"""解析回调 URL"""
|
||||
candidate = callback_url.strip()
|
||||
if not candidate:
|
||||
return {"code": "", "state": "", "error": "", "error_description": ""}
|
||||
|
||||
if "://" not in candidate:
|
||||
if candidate.startswith("?"):
|
||||
candidate = f"http://localhost{candidate}"
|
||||
elif any(ch in candidate for ch in "/?#") or ":" in candidate:
|
||||
candidate = f"http://{candidate}"
|
||||
elif "=" in candidate:
|
||||
candidate = f"http://localhost/?{candidate}"
|
||||
|
||||
parsed = urllib.parse.urlparse(candidate)
|
||||
query = urllib.parse.parse_qs(parsed.query, keep_blank_values=True)
|
||||
fragment = urllib.parse.parse_qs(parsed.fragment, keep_blank_values=True)
|
||||
|
||||
for key, values in fragment.items():
|
||||
if key not in query or not query[key] or not (query[key][0] or "").strip():
|
||||
query[key] = values
|
||||
|
||||
def get1(k: str) -> str:
|
||||
v = query.get(k, [""])
|
||||
return (v[0] or "").strip()
|
||||
|
||||
code = get1("code")
|
||||
state = get1("state")
|
||||
error = get1("error")
|
||||
error_description = get1("error_description")
|
||||
|
||||
if code and not state and "#" in code:
|
||||
code, state = code.split("#", 1)
|
||||
|
||||
if not error and error_description:
|
||||
error, error_description = error_description, ""
|
||||
|
||||
return {
|
||||
"code": code,
|
||||
"state": state,
|
||||
"error": error,
|
||||
"error_description": error_description,
|
||||
}
|
||||
|
||||
|
||||
def _jwt_claims_no_verify(id_token: str) -> Dict[str, Any]:
|
||||
"""解析 JWT ID Token(不验证签名)"""
|
||||
if not id_token or id_token.count(".") < 2:
|
||||
return {}
|
||||
payload_b64 = id_token.split(".")[1]
|
||||
pad = "=" * ((4 - (len(payload_b64) % 4)) % 4)
|
||||
try:
|
||||
payload = base64.urlsafe_b64decode((payload_b64 + pad).encode("ascii"))
|
||||
return json.loads(payload.decode("utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _decode_jwt_segment(seg: str) -> Dict[str, Any]:
|
||||
"""解码 JWT 片段"""
|
||||
raw = (seg or "").strip()
|
||||
if not raw:
|
||||
return {}
|
||||
pad = "=" * ((4 - (len(raw) % 4)) % 4)
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode((raw + pad).encode("ascii"))
|
||||
return json.loads(decoded.decode("utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _to_int(v: Any) -> int:
|
||||
"""转换为整数"""
|
||||
try:
|
||||
return int(v)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _post_form(url: str, data: Dict[str, str], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""发送 POST 表单请求"""
|
||||
body = urllib.parse.urlencode(data).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
raw = resp.read()
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(
|
||||
f"token exchange failed: {resp.status}: {raw.decode('utf-8', 'replace')}"
|
||||
)
|
||||
return json.loads(raw.decode("utf-8"))
|
||||
except urllib.error.HTTPError as exc:
|
||||
raw = exc.read()
|
||||
raise RuntimeError(
|
||||
f"token exchange failed: {exc.code}: {raw.decode('utf-8', 'replace')}"
|
||||
) from exc
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OAuthStart:
|
||||
"""OAuth 开始信息"""
|
||||
auth_url: str
|
||||
state: str
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_oauth_url(
|
||||
*,
|
||||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||||
scope: str = OAUTH_SCOPE,
|
||||
client_id: str = OAUTH_CLIENT_ID
|
||||
) -> OAuthStart:
|
||||
"""
|
||||
生成 OAuth 授权 URL
|
||||
|
||||
Args:
|
||||
redirect_uri: 回调地址
|
||||
scope: 权限范围
|
||||
client_id: OpenAI Client ID
|
||||
|
||||
Returns:
|
||||
OAuthStart 对象,包含授权 URL 和必要参数
|
||||
"""
|
||||
state = _random_state()
|
||||
code_verifier = _pkce_verifier()
|
||||
code_challenge = _sha256_b64url_no_pad(code_verifier)
|
||||
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"prompt": "login",
|
||||
"id_token_add_organizations": "true",
|
||||
"codex_cli_simplified_flow": "true",
|
||||
}
|
||||
auth_url = f"{OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
return OAuthStart(
|
||||
auth_url=auth_url,
|
||||
state=state,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
|
||||
def submit_callback_url(
|
||||
*,
|
||||
callback_url: str,
|
||||
expected_state: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||||
client_id: str = OAUTH_CLIENT_ID,
|
||||
token_url: str = OAUTH_TOKEN_URL
|
||||
) -> str:
|
||||
"""
|
||||
处理 OAuth 回调 URL,获取访问令牌
|
||||
|
||||
Args:
|
||||
callback_url: 回调 URL
|
||||
expected_state: 预期的 state 值
|
||||
code_verifier: PKCE code_verifier
|
||||
redirect_uri: 回调地址
|
||||
client_id: OpenAI Client ID
|
||||
token_url: Token 交换地址
|
||||
|
||||
Returns:
|
||||
包含访问令牌等信息的 JSON 字符串
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAuth 错误
|
||||
ValueError: 缺少必要参数或 state 不匹配
|
||||
"""
|
||||
cb = _parse_callback_url(callback_url)
|
||||
if cb["error"]:
|
||||
desc = cb["error_description"]
|
||||
raise RuntimeError(f"oauth error: {cb['error']}: {desc}".strip())
|
||||
|
||||
if not cb["code"]:
|
||||
raise ValueError("callback url missing ?code=")
|
||||
if not cb["state"]:
|
||||
raise ValueError("callback url missing ?state=")
|
||||
if cb["state"] != expected_state:
|
||||
raise ValueError("state mismatch")
|
||||
|
||||
token_resp = _post_form(
|
||||
token_url,
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": client_id,
|
||||
"code": cb["code"],
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
)
|
||||
|
||||
access_token = (token_resp.get("access_token") or "").strip()
|
||||
refresh_token = (token_resp.get("refresh_token") or "").strip()
|
||||
id_token = (token_resp.get("id_token") or "").strip()
|
||||
expires_in = _to_int(token_resp.get("expires_in"))
|
||||
|
||||
claims = _jwt_claims_no_verify(id_token)
|
||||
email = str(claims.get("email") or "").strip()
|
||||
auth_claims = claims.get("https://api.openai.com/auth") or {}
|
||||
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
|
||||
|
||||
now = int(time.time())
|
||||
expired_rfc3339 = time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ", time.gmtime(now + max(expires_in, 0))
|
||||
)
|
||||
now_rfc3339 = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now))
|
||||
|
||||
config = {
|
||||
"id_token": id_token,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"account_id": account_id,
|
||||
"last_refresh": now_rfc3339,
|
||||
"email": email,
|
||||
"type": "codex",
|
||||
"expired": expired_rfc3339,
|
||||
}
|
||||
|
||||
return json.dumps(config, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
class OAuthManager:
|
||||
"""OAuth 管理器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str = OAUTH_CLIENT_ID,
|
||||
auth_url: str = OAUTH_AUTH_URL,
|
||||
token_url: str = OAUTH_TOKEN_URL,
|
||||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||||
scope: str = OAUTH_SCOPE
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.auth_url = auth_url
|
||||
self.token_url = token_url
|
||||
self.redirect_uri = redirect_uri
|
||||
self.scope = scope
|
||||
|
||||
def start_oauth(self) -> OAuthStart:
|
||||
"""开始 OAuth 流程"""
|
||||
return generate_oauth_url(
|
||||
redirect_uri=self.redirect_uri,
|
||||
scope=self.scope,
|
||||
client_id=self.client_id
|
||||
)
|
||||
|
||||
def handle_callback(
|
||||
self,
|
||||
callback_url: str,
|
||||
expected_state: str,
|
||||
code_verifier: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理 OAuth 回调"""
|
||||
result_json = submit_callback_url(
|
||||
callback_url=callback_url,
|
||||
expected_state=expected_state,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=self.redirect_uri,
|
||||
client_id=self.client_id,
|
||||
token_url=self.token_url
|
||||
)
|
||||
return json.loads(result_json)
|
||||
|
||||
def extract_account_info(self, id_token: str) -> Dict[str, Any]:
|
||||
"""从 ID Token 中提取账户信息"""
|
||||
claims = _jwt_claims_no_verify(id_token)
|
||||
email = str(claims.get("email") or "").strip()
|
||||
auth_claims = claims.get("https://api.openai.com/auth") or {}
|
||||
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"account_id": account_id,
|
||||
"claims": claims
|
||||
}
|
||||
723
src/core/register.py
Normal file
723
src/core/register.py
Normal file
@@ -0,0 +1,723 @@
|
||||
"""
|
||||
注册流程引擎
|
||||
从 main.py 中提取并重构的注册流程
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from typing import Optional, Dict, Any, Tuple, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from .oauth import OAuthManager, OAuthStart
|
||||
from .http_client import OpenAIHTTPClient, HTTPClientError
|
||||
from ..services import EmailServiceFactory, BaseEmailService, EmailServiceType
|
||||
from ..database import crud
|
||||
from ..database.session import get_db
|
||||
from ..config.constants import (
|
||||
OPENAI_API_ENDPOINTS,
|
||||
DEFAULT_USER_INFO,
|
||||
OTP_CODE_PATTERN,
|
||||
DEFAULT_PASSWORD_LENGTH,
|
||||
PASSWORD_CHARSET,
|
||||
AccountStatus,
|
||||
TaskStatus,
|
||||
)
|
||||
from ..config.settings import get_settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationResult:
|
||||
"""注册结果"""
|
||||
success: bool
|
||||
email: str = ""
|
||||
account_id: str = ""
|
||||
workspace_id: str = ""
|
||||
access_token: str = ""
|
||||
refresh_token: str = ""
|
||||
id_token: str = ""
|
||||
error_message: str = ""
|
||||
logs: list = None
|
||||
metadata: dict = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"success": self.success,
|
||||
"email": self.email,
|
||||
"account_id": self.account_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"access_token": self.access_token[:20] + "..." if self.access_token else "",
|
||||
"refresh_token": self.refresh_token[:20] + "..." if self.refresh_token else "",
|
||||
"id_token": self.id_token[:20] + "..." if self.id_token else "",
|
||||
"error_message": self.error_message,
|
||||
"logs": self.logs or [],
|
||||
"metadata": self.metadata or {},
|
||||
}
|
||||
|
||||
|
||||
class RegistrationEngine:
|
||||
"""
|
||||
注册引擎
|
||||
负责协调邮箱服务、OAuth 流程和 OpenAI API 调用
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
email_service: BaseEmailService,
|
||||
proxy_url: Optional[str] = None,
|
||||
callback_logger: Optional[Callable[[str], None]] = None,
|
||||
task_uuid: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化注册引擎
|
||||
|
||||
Args:
|
||||
email_service: 邮箱服务实例
|
||||
proxy_url: 代理 URL
|
||||
callback_logger: 日志回调函数
|
||||
task_uuid: 任务 UUID(用于数据库记录)
|
||||
"""
|
||||
self.email_service = email_service
|
||||
self.proxy_url = proxy_url
|
||||
self.callback_logger = callback_logger or (lambda msg: logger.info(msg))
|
||||
self.task_uuid = task_uuid
|
||||
|
||||
# 创建 HTTP 客户端
|
||||
self.http_client = OpenAIHTTPClient(proxy_url=proxy_url)
|
||||
|
||||
# 创建 OAuth 管理器
|
||||
settings = get_settings()
|
||||
self.oauth_manager = OAuthManager(
|
||||
client_id=settings.openai_client_id,
|
||||
auth_url=settings.openai_auth_url,
|
||||
token_url=settings.openai_token_url,
|
||||
redirect_uri=settings.openai_redirect_uri,
|
||||
scope=settings.openai_scope
|
||||
)
|
||||
|
||||
# 状态变量
|
||||
self.email: Optional[str] = None
|
||||
self.email_info: Optional[Dict[str, Any]] = None
|
||||
self.oauth_start: Optional[OAuthStart] = None
|
||||
self.session: Optional[cffi_requests.Session] = None
|
||||
self.logs: list = []
|
||||
|
||||
def _log(self, message: str, level: str = "info"):
|
||||
"""记录日志"""
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
log_message = f"[{timestamp}] {message}"
|
||||
|
||||
# 添加到日志列表
|
||||
self.logs.append(log_message)
|
||||
|
||||
# 调用回调函数
|
||||
if self.callback_logger:
|
||||
self.callback_logger(log_message)
|
||||
|
||||
# 记录到数据库(如果有关联任务)
|
||||
if self.task_uuid:
|
||||
try:
|
||||
with get_db() as db:
|
||||
crud.append_task_log(db, self.task_uuid, log_message)
|
||||
except Exception as e:
|
||||
logger.warning(f"记录任务日志失败: {e}")
|
||||
|
||||
# 根据级别记录到日志系统
|
||||
if level == "error":
|
||||
logger.error(message)
|
||||
elif level == "warning":
|
||||
logger.warning(message)
|
||||
else:
|
||||
logger.info(message)
|
||||
|
||||
def _generate_password(self, length: int = DEFAULT_PASSWORD_LENGTH) -> str:
|
||||
"""生成随机密码"""
|
||||
return ''.join(secrets.choice(PASSWORD_CHARSET) for _ in range(length))
|
||||
|
||||
def _check_ip_location(self) -> Tuple[bool, Optional[str]]:
|
||||
"""检查 IP 地理位置"""
|
||||
try:
|
||||
return self.http_client.check_ip_location()
|
||||
except Exception as e:
|
||||
self._log(f"检查 IP 地理位置失败: {e}", "error")
|
||||
return False, None
|
||||
|
||||
def _create_email(self) -> bool:
|
||||
"""创建邮箱"""
|
||||
try:
|
||||
self._log(f"正在创建 {self.email_service.service_type.value} 邮箱...")
|
||||
self.email_info = self.email_service.create_email()
|
||||
|
||||
if not self.email_info or "email" not in self.email_info:
|
||||
self._log("创建邮箱失败: 返回信息不完整", "error")
|
||||
return False
|
||||
|
||||
self.email = self.email_info["email"]
|
||||
self._log(f"成功创建邮箱: {self.email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"创建邮箱失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _start_oauth(self) -> bool:
|
||||
"""开始 OAuth 流程"""
|
||||
try:
|
||||
self._log("开始 OAuth 授权流程...")
|
||||
self.oauth_start = self.oauth_manager.start_oauth()
|
||||
self._log(f"OAuth URL 已生成: {self.oauth_start.auth_url[:80]}...")
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log(f"生成 OAuth URL 失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _init_session(self) -> bool:
|
||||
"""初始化会话"""
|
||||
try:
|
||||
self.session = self.http_client.session
|
||||
return True
|
||||
except Exception as e:
|
||||
self._log(f"初始化会话失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _get_device_id(self) -> Optional[str]:
|
||||
"""获取 Device ID"""
|
||||
try:
|
||||
if not self.oauth_start:
|
||||
return None
|
||||
|
||||
response = self.session.get(
|
||||
self.oauth_start.auth_url,
|
||||
timeout=15
|
||||
)
|
||||
did = self.session.cookies.get("oai-did")
|
||||
self._log(f"Device ID: {did}")
|
||||
return did
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"获取 Device ID 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _check_sentinel(self, did: str) -> Optional[str]:
|
||||
"""检查 Sentinel 拦截"""
|
||||
try:
|
||||
sen_req_body = f'{{"p":"","id":"{did}","flow":"authorize_continue"}}'
|
||||
|
||||
response = self.http_client.post(
|
||||
OPENAI_API_ENDPOINTS["sentinel"],
|
||||
headers={
|
||||
"origin": "https://sentinel.openai.com",
|
||||
"referer": "https://sentinel.openai.com/backend-api/sentinel/frame.html?sv=20260219f9f6",
|
||||
"content-type": "text/plain;charset=UTF-8",
|
||||
},
|
||||
data=sen_req_body,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
sen_token = response.json().get("token")
|
||||
self._log(f"Sentinel token 获取成功")
|
||||
return sen_token
|
||||
else:
|
||||
self._log(f"Sentinel 检查失败: {response.status_code}", "warning")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"Sentinel 检查异常: {e}", "warning")
|
||||
return None
|
||||
|
||||
def _submit_signup_form(self, did: str, sen_token: Optional[str]) -> bool:
|
||||
"""提交注册表单"""
|
||||
try:
|
||||
signup_body = f'{{"username":{{"value":"{self.email}","kind":"email"}},"screen_hint":"signup"}}'
|
||||
|
||||
headers = {
|
||||
"referer": "https://auth.openai.com/create-account",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
if sen_token:
|
||||
sentinel = f'{{"p": "", "t": "", "c": "{sen_token}", "id": "{did}", "flow": "authorize_continue"}}'
|
||||
headers["openai-sentinel-token"] = sentinel
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["signup"],
|
||||
headers=headers,
|
||||
data=signup_body,
|
||||
)
|
||||
|
||||
self._log(f"提交注册表单状态: {response.status_code}")
|
||||
return response.status_code == 200
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"提交注册表单失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _register_password(self) -> Tuple[bool, Optional[str]]:
|
||||
"""注册密码"""
|
||||
try:
|
||||
# 生成密码
|
||||
password = self._generate_password()
|
||||
self._log(f"生成密码: {password}")
|
||||
|
||||
# 提交密码注册
|
||||
register_body = json.dumps({
|
||||
"password": password,
|
||||
"username": self.email
|
||||
})
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["register"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/create-account/password",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=register_body,
|
||||
)
|
||||
|
||||
self._log(f"提交密码状态: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self._log(f"密码注册失败: {response.text[:200]}", "warning")
|
||||
return False, None
|
||||
|
||||
return True, password
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"密码注册失败: {e}", "error")
|
||||
return False, None
|
||||
|
||||
def _send_verification_code(self) -> bool:
|
||||
"""发送验证码"""
|
||||
try:
|
||||
response = self.session.get(
|
||||
OPENAI_API_ENDPOINTS["send_otp"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/create-account/password",
|
||||
"accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
self._log(f"验证码发送状态: {response.status_code}")
|
||||
return response.status_code == 200
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"发送验证码失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _get_verification_code(self) -> Optional[str]:
|
||||
"""获取验证码"""
|
||||
try:
|
||||
self._log(f"正在等待邮箱 {self.email} 的验证码...")
|
||||
|
||||
email_id = self.email_info.get("service_id") if self.email_info else None
|
||||
code = self.email_service.get_verification_code(
|
||||
email=self.email,
|
||||
email_id=email_id,
|
||||
timeout=120,
|
||||
pattern=OTP_CODE_PATTERN
|
||||
)
|
||||
|
||||
if code:
|
||||
self._log(f"成功获取验证码: {code}")
|
||||
return code
|
||||
else:
|
||||
self._log("等待验证码超时", "error")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"获取验证码失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _validate_verification_code(self, code: str) -> bool:
|
||||
"""验证验证码"""
|
||||
try:
|
||||
code_body = f'{{"code":"{code}"}}'
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["validate_otp"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/email-verification",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=code_body,
|
||||
)
|
||||
|
||||
self._log(f"验证码校验状态: {response.status_code}")
|
||||
return response.status_code == 200
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"验证验证码失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _create_user_account(self) -> bool:
|
||||
"""创建用户账户"""
|
||||
try:
|
||||
create_account_body = json.dumps(DEFAULT_USER_INFO)
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["create_account"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/about-you",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=create_account_body,
|
||||
)
|
||||
|
||||
self._log(f"账户创建状态: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
self._log(f"账户创建失败: {response.text[:200]}", "warning")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"创建账户失败: {e}", "error")
|
||||
return False
|
||||
|
||||
def _get_workspace_id(self) -> Optional[str]:
|
||||
"""获取 Workspace ID"""
|
||||
try:
|
||||
auth_cookie = self.session.cookies.get("oai-client-auth-session")
|
||||
if not auth_cookie:
|
||||
self._log("未能获取到授权 Cookie", "error")
|
||||
return None
|
||||
|
||||
# 解码 JWT
|
||||
import base64
|
||||
import json as json_module
|
||||
|
||||
try:
|
||||
segments = auth_cookie.split(".")
|
||||
if len(segments) < 1:
|
||||
self._log("授权 Cookie 格式错误", "error")
|
||||
return None
|
||||
|
||||
# 解码第一个 segment
|
||||
payload = segments[0]
|
||||
pad = "=" * ((4 - (len(payload) % 4)) % 4)
|
||||
decoded = base64.urlsafe_b64decode((payload + pad).encode("ascii"))
|
||||
auth_json = json_module.loads(decoded.decode("utf-8"))
|
||||
|
||||
workspaces = auth_json.get("workspaces") or []
|
||||
if not workspaces:
|
||||
self._log("授权 Cookie 里没有 workspace 信息", "error")
|
||||
return None
|
||||
|
||||
workspace_id = str((workspaces[0] or {}).get("id") or "").strip()
|
||||
if not workspace_id:
|
||||
self._log("无法解析 workspace_id", "error")
|
||||
return None
|
||||
|
||||
self._log(f"Workspace ID: {workspace_id}")
|
||||
return workspace_id
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"解析授权 Cookie 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"获取 Workspace ID 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _select_workspace(self, workspace_id: str) -> Optional[str]:
|
||||
"""选择 Workspace"""
|
||||
try:
|
||||
select_body = f'{{"workspace_id":"{workspace_id}"}}'
|
||||
|
||||
response = self.session.post(
|
||||
OPENAI_API_ENDPOINTS["select_workspace"],
|
||||
headers={
|
||||
"referer": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
data=select_body,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
self._log(f"选择 workspace 失败: {response.status_code}", "error")
|
||||
self._log(f"响应: {response.text[:200]}", "warning")
|
||||
return None
|
||||
|
||||
continue_url = str((response.json() or {}).get("continue_url") or "").strip()
|
||||
if not continue_url:
|
||||
self._log("workspace/select 响应里缺少 continue_url", "error")
|
||||
return None
|
||||
|
||||
self._log(f"Continue URL: {continue_url[:100]}...")
|
||||
return continue_url
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"选择 Workspace 失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _follow_redirects(self, start_url: str) -> Optional[str]:
|
||||
"""跟随重定向链,寻找回调 URL"""
|
||||
try:
|
||||
current_url = start_url
|
||||
max_redirects = 6
|
||||
|
||||
for i in range(max_redirects):
|
||||
self._log(f"重定向 {i+1}/{max_redirects}: {current_url[:100]}...")
|
||||
|
||||
response = self.session.get(
|
||||
current_url,
|
||||
allow_redirects=False,
|
||||
timeout=15
|
||||
)
|
||||
|
||||
location = response.headers.get("Location") or ""
|
||||
|
||||
# 如果不是重定向状态码,停止
|
||||
if response.status_code not in [301, 302, 303, 307, 308]:
|
||||
self._log(f"非重定向状态码: {response.status_code}")
|
||||
break
|
||||
|
||||
if not location:
|
||||
self._log("重定向响应缺少 Location 头")
|
||||
break
|
||||
|
||||
# 构建下一个 URL
|
||||
import urllib.parse
|
||||
next_url = urllib.parse.urljoin(current_url, location)
|
||||
|
||||
# 检查是否包含回调参数
|
||||
if "code=" in next_url and "state=" in next_url:
|
||||
self._log(f"找到回调 URL: {next_url[:100]}...")
|
||||
return next_url
|
||||
|
||||
current_url = next_url
|
||||
|
||||
self._log("未能在重定向链中找到回调 URL", "error")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"跟随重定向失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]:
|
||||
"""处理 OAuth 回调"""
|
||||
try:
|
||||
if not self.oauth_start:
|
||||
self._log("OAuth 流程未初始化", "error")
|
||||
return None
|
||||
|
||||
self._log("处理 OAuth 回调...")
|
||||
token_info = self.oauth_manager.handle_callback(
|
||||
callback_url=callback_url,
|
||||
expected_state=self.oauth_start.state,
|
||||
code_verifier=self.oauth_start.code_verifier
|
||||
)
|
||||
|
||||
self._log("OAuth 授权成功")
|
||||
return token_info
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"处理 OAuth 回调失败: {e}", "error")
|
||||
return None
|
||||
|
||||
def run(self) -> RegistrationResult:
|
||||
"""
|
||||
执行完整的注册流程
|
||||
|
||||
Returns:
|
||||
RegistrationResult: 注册结果
|
||||
"""
|
||||
result = RegistrationResult(success=False, logs=self.logs)
|
||||
|
||||
try:
|
||||
self._log("=" * 60)
|
||||
self._log("开始注册流程")
|
||||
self._log("=" * 60)
|
||||
|
||||
# 1. 检查 IP 地理位置
|
||||
self._log("1. 检查 IP 地理位置...")
|
||||
ip_ok, location = self._check_ip_location()
|
||||
if not ip_ok:
|
||||
result.error_message = f"IP 地理位置不支持: {location}"
|
||||
self._log(f"IP 检查失败: {location}", "error")
|
||||
return result
|
||||
|
||||
self._log(f"IP 位置: {location}")
|
||||
|
||||
# 2. 创建邮箱
|
||||
self._log("2. 创建邮箱...")
|
||||
if not self._create_email():
|
||||
result.error_message = "创建邮箱失败"
|
||||
return result
|
||||
|
||||
result.email = self.email
|
||||
|
||||
# 3. 初始化会话
|
||||
self._log("3. 初始化会话...")
|
||||
if not self._init_session():
|
||||
result.error_message = "初始化会话失败"
|
||||
return result
|
||||
|
||||
# 4. 开始 OAuth 流程
|
||||
self._log("4. 开始 OAuth 授权流程...")
|
||||
if not self._start_oauth():
|
||||
result.error_message = "开始 OAuth 流程失败"
|
||||
return result
|
||||
|
||||
# 5. 获取 Device ID
|
||||
self._log("5. 获取 Device ID...")
|
||||
did = self._get_device_id()
|
||||
if not did:
|
||||
result.error_message = "获取 Device ID 失败"
|
||||
return result
|
||||
|
||||
# 6. 检查 Sentinel 拦截
|
||||
self._log("6. 检查 Sentinel 拦截...")
|
||||
sen_token = self._check_sentinel(did)
|
||||
if sen_token:
|
||||
self._log("Sentinel 检查通过")
|
||||
else:
|
||||
self._log("Sentinel 检查失败或未启用", "warning")
|
||||
|
||||
# 7. 提交注册表单
|
||||
self._log("7. 提交注册表单...")
|
||||
if not self._submit_signup_form(did, sen_token):
|
||||
result.error_message = "提交注册表单失败"
|
||||
return result
|
||||
|
||||
# 8. 注册密码
|
||||
self._log("8. 注册密码...")
|
||||
password_ok, password = self._register_password()
|
||||
if not password_ok:
|
||||
result.error_message = "注册密码失败"
|
||||
return result
|
||||
|
||||
# 9. 发送验证码
|
||||
self._log("9. 发送验证码...")
|
||||
if not self._send_verification_code():
|
||||
result.error_message = "发送验证码失败"
|
||||
return result
|
||||
|
||||
# 10. 获取验证码
|
||||
self._log("10. 等待验证码...")
|
||||
code = self._get_verification_code()
|
||||
if not code:
|
||||
result.error_message = "获取验证码失败"
|
||||
return result
|
||||
|
||||
# 11. 验证验证码
|
||||
self._log("11. 验证验证码...")
|
||||
if not self._validate_verification_code(code):
|
||||
result.error_message = "验证验证码失败"
|
||||
return result
|
||||
|
||||
# 12. 创建用户账户
|
||||
self._log("12. 创建用户账户...")
|
||||
if not self._create_user_account():
|
||||
result.error_message = "创建用户账户失败"
|
||||
return result
|
||||
|
||||
# 13. 获取 Workspace ID
|
||||
self._log("13. 获取 Workspace ID...")
|
||||
workspace_id = self._get_workspace_id()
|
||||
if not workspace_id:
|
||||
result.error_message = "获取 Workspace ID 失败"
|
||||
return result
|
||||
|
||||
result.workspace_id = workspace_id
|
||||
|
||||
# 14. 选择 Workspace
|
||||
self._log("14. 选择 Workspace...")
|
||||
continue_url = self._select_workspace(workspace_id)
|
||||
if not continue_url:
|
||||
result.error_message = "选择 Workspace 失败"
|
||||
return result
|
||||
|
||||
# 15. 跟随重定向链
|
||||
self._log("15. 跟随重定向链...")
|
||||
callback_url = self._follow_redirects(continue_url)
|
||||
if not callback_url:
|
||||
result.error_message = "跟随重定向链失败"
|
||||
return result
|
||||
|
||||
# 16. 处理 OAuth 回调
|
||||
self._log("16. 处理 OAuth 回调...")
|
||||
token_info = self._handle_oauth_callback(callback_url)
|
||||
if not token_info:
|
||||
result.error_message = "处理 OAuth 回调失败"
|
||||
return result
|
||||
|
||||
# 提取账户信息
|
||||
result.account_id = token_info.get("account_id", "")
|
||||
result.access_token = token_info.get("access_token", "")
|
||||
result.refresh_token = token_info.get("refresh_token", "")
|
||||
result.id_token = token_info.get("id_token", "")
|
||||
|
||||
# 17. 完成
|
||||
self._log("=" * 60)
|
||||
self._log(f"注册成功!")
|
||||
self._log(f"邮箱: {result.email}")
|
||||
self._log(f"Account ID: {result.account_id}")
|
||||
self._log(f"Workspace ID: {result.workspace_id}")
|
||||
self._log("=" * 60)
|
||||
|
||||
result.success = True
|
||||
result.metadata = {
|
||||
"email_service": self.email_service.service_type.value,
|
||||
"proxy_used": self.proxy_url,
|
||||
"registered_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"注册过程中发生未预期错误: {e}", "error")
|
||||
result.error_message = str(e)
|
||||
return result
|
||||
|
||||
def save_to_database(self, result: RegistrationResult) -> bool:
|
||||
"""
|
||||
保存注册结果到数据库
|
||||
|
||||
Args:
|
||||
result: 注册结果
|
||||
|
||||
Returns:
|
||||
是否保存成功
|
||||
"""
|
||||
if not result.success:
|
||||
return False
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
# 保存账户信息
|
||||
account = crud.create_account(
|
||||
db,
|
||||
email=result.email,
|
||||
email_service=self.email_service.service_type.value,
|
||||
email_service_id=self.email_info.get("service_id") if self.email_info else None,
|
||||
account_id=result.account_id,
|
||||
workspace_id=result.workspace_id,
|
||||
access_token=result.access_token,
|
||||
refresh_token=result.refresh_token,
|
||||
id_token=result.id_token,
|
||||
proxy_used=self.proxy_url,
|
||||
metadata=result.metadata
|
||||
)
|
||||
|
||||
self._log(f"账户已保存到数据库,ID: {account.id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._log(f"保存到数据库失败: {e}", "error")
|
||||
return False
|
||||
566
src/core/utils.py
Normal file
566
src/core/utils.py
Normal file
@@ -0,0 +1,566 @@
|
||||
"""
|
||||
通用工具函数
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
import secrets
|
||||
import hashlib
|
||||
import logging
|
||||
import base64
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Union, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from ..config.constants import PASSWORD_CHARSET, DEFAULT_PASSWORD_LENGTH
|
||||
from ..config.settings import get_settings
|
||||
|
||||
|
||||
def setup_logging(
|
||||
log_level: str = "INFO",
|
||||
log_file: Optional[str] = None,
|
||||
log_format: str = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
配置日志系统
|
||||
|
||||
Args:
|
||||
log_level: 日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_file: 日志文件路径,如果不指定则只输出到控制台
|
||||
log_format: 日志格式
|
||||
|
||||
Returns:
|
||||
根日志记录器
|
||||
"""
|
||||
# 设置日志级别
|
||||
numeric_level = getattr(logging, log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
numeric_level = logging.INFO
|
||||
|
||||
# 配置根日志记录器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(numeric_level)
|
||||
|
||||
# 清除现有的处理器
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(log_format)
|
||||
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(numeric_level)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器(如果指定了日志文件)
|
||||
if log_file:
|
||||
# 确保日志目录存在
|
||||
log_dir = os.path.dirname(log_file)
|
||||
if log_dir:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(numeric_level)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
return root_logger
|
||||
|
||||
|
||||
def generate_password(length: int = DEFAULT_PASSWORD_LENGTH) -> str:
|
||||
"""
|
||||
生成随机密码
|
||||
|
||||
Args:
|
||||
length: 密码长度
|
||||
|
||||
Returns:
|
||||
随机密码字符串
|
||||
"""
|
||||
if length < 4:
|
||||
length = 4
|
||||
|
||||
# 确保密码包含至少一个大写字母、一个小写字母和一个数字
|
||||
password = [
|
||||
secrets.choice(string.ascii_lowercase),
|
||||
secrets.choice(string.ascii_uppercase),
|
||||
secrets.choice(string.digits),
|
||||
]
|
||||
|
||||
# 添加剩余字符
|
||||
password.extend(secrets.choice(PASSWORD_CHARSET) for _ in range(length - 3))
|
||||
|
||||
# 随机打乱
|
||||
secrets.SystemRandom().shuffle(password)
|
||||
|
||||
return ''.join(password)
|
||||
|
||||
|
||||
def generate_random_string(length: int = 8) -> str:
|
||||
"""
|
||||
生成随机字符串(仅字母)
|
||||
|
||||
Args:
|
||||
length: 字符串长度
|
||||
|
||||
Returns:
|
||||
随机字符串
|
||||
"""
|
||||
chars = string.ascii_letters
|
||||
return ''.join(secrets.choice(chars) for _ in range(length))
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID 字符串"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def get_timestamp() -> int:
|
||||
"""获取当前时间戳(秒)"""
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def format_datetime(dt: Optional[datetime] = None, fmt: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""
|
||||
格式化日期时间
|
||||
|
||||
Args:
|
||||
dt: 日期时间对象,如果为 None 则使用当前时间
|
||||
fmt: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的字符串
|
||||
"""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return dt.strftime(fmt)
|
||||
|
||||
|
||||
def parse_datetime(dt_str: str, fmt: str = "%Y-%m-%d %H:%M:%S") -> Optional[datetime]:
|
||||
"""
|
||||
解析日期时间字符串
|
||||
|
||||
Args:
|
||||
dt_str: 日期时间字符串
|
||||
fmt: 格式字符串
|
||||
|
||||
Returns:
|
||||
日期时间对象,如果解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return datetime.strptime(dt_str, fmt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def human_readable_size(size_bytes: int) -> str:
|
||||
"""
|
||||
将字节大小转换为人类可读的格式
|
||||
|
||||
Args:
|
||||
size_bytes: 字节大小
|
||||
|
||||
Returns:
|
||||
人类可读的字符串
|
||||
"""
|
||||
if size_bytes < 0:
|
||||
return "0 B"
|
||||
|
||||
units = ["B", "KB", "MB", "GB", "TB", "PB"]
|
||||
unit_index = 0
|
||||
|
||||
while size_bytes >= 1024 and unit_index < len(units) - 1:
|
||||
size_bytes /= 1024
|
||||
unit_index += 1
|
||||
|
||||
return f"{size_bytes:.2f} {units[unit_index]}"
|
||||
|
||||
|
||||
def retry_with_backoff(
|
||||
func: Callable,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 1.0,
|
||||
max_delay: float = 30.0,
|
||||
backoff_factor: float = 2.0,
|
||||
exceptions: tuple = (Exception,)
|
||||
) -> Any:
|
||||
"""
|
||||
带有指数退避的重试装饰器/函数
|
||||
|
||||
Args:
|
||||
func: 要重试的函数
|
||||
max_retries: 最大重试次数
|
||||
base_delay: 基础延迟(秒)
|
||||
max_delay: 最大延迟(秒)
|
||||
backoff_factor: 退避因子
|
||||
exceptions: 要捕获的异常类型
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
Raises:
|
||||
最后一次尝试的异常
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return func()
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
|
||||
# 如果是最后一次尝试,直接抛出异常
|
||||
if attempt == max_retries:
|
||||
break
|
||||
|
||||
# 计算延迟时间
|
||||
delay = min(base_delay * (backoff_factor ** attempt), max_delay)
|
||||
|
||||
# 添加随机抖动
|
||||
delay *= (0.5 + random.random())
|
||||
|
||||
# 记录日志
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"尝试 {func.__name__} 失败 (attempt {attempt + 1}/{max_retries + 1}): {e}. "
|
||||
f"等待 {delay:.2f} 秒后重试..."
|
||||
)
|
||||
|
||||
time.sleep(delay)
|
||||
|
||||
# 所有重试都失败,抛出最后一个异常
|
||||
raise last_exception
|
||||
|
||||
|
||||
class RetryDecorator:
|
||||
"""重试装饰器类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 1.0,
|
||||
max_delay: float = 30.0,
|
||||
backoff_factor: float = 2.0,
|
||||
exceptions: tuple = (Exception,)
|
||||
):
|
||||
self.max_retries = max_retries
|
||||
self.base_delay = base_delay
|
||||
self.max_delay = max_delay
|
||||
self.backoff_factor = backoff_factor
|
||||
self.exceptions = exceptions
|
||||
|
||||
def __call__(self, func: Callable) -> Callable:
|
||||
"""装饰器调用"""
|
||||
def wrapper(*args, **kwargs):
|
||||
def func_to_retry():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return retry_with_backoff(
|
||||
func_to_retry,
|
||||
max_retries=self.max_retries,
|
||||
base_delay=self.base_delay,
|
||||
max_delay=self.max_delay,
|
||||
backoff_factor=self.backoff_factor,
|
||||
exceptions=self.exceptions
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def validate_email(email: str) -> bool:
|
||||
"""
|
||||
验证邮箱地址格式
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""
|
||||
验证 URL 格式
|
||||
|
||||
Args:
|
||||
url: URL
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
|
||||
return bool(re.match(pattern, url))
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
清理文件名,移除不安全的字符
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
清理后的文件名
|
||||
"""
|
||||
# 移除危险字符
|
||||
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
|
||||
# 移除控制字符
|
||||
filename = ''.join(char for char in filename if ord(char) >= 32)
|
||||
# 限制长度
|
||||
if len(filename) > 255:
|
||||
name, ext = os.path.splitext(filename)
|
||||
filename = name[:255 - len(ext)] + ext
|
||||
return filename
|
||||
|
||||
|
||||
def read_json_file(filepath: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取 JSON 文件
|
||||
|
||||
Args:
|
||||
filepath: 文件路径
|
||||
|
||||
Returns:
|
||||
JSON 数据,如果读取失败返回 None
|
||||
"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError, IOError) as e:
|
||||
logging.getLogger(__name__).warning(f"读取 JSON 文件失败: {filepath} - {e}")
|
||||
return None
|
||||
|
||||
|
||||
def write_json_file(filepath: str, data: Dict[str, Any], indent: int = 2) -> bool:
|
||||
"""
|
||||
写入 JSON 文件
|
||||
|
||||
Args:
|
||||
filepath: 文件路径
|
||||
data: 要写入的数据
|
||||
indent: 缩进空格数
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=indent)
|
||||
|
||||
return True
|
||||
except (IOError, TypeError) as e:
|
||||
logging.getLogger(__name__).error(f"写入 JSON 文件失败: {filepath} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""
|
||||
获取项目根目录
|
||||
|
||||
Returns:
|
||||
项目根目录 Path 对象
|
||||
"""
|
||||
# 当前文件所在目录
|
||||
current_dir = Path(__file__).parent
|
||||
|
||||
# 向上查找直到找到项目根目录(包含 pyproject.toml 或 setup.py)
|
||||
for parent in [current_dir] + list(current_dir.parents):
|
||||
if (parent / "pyproject.toml").exists() or (parent / "setup.py").exists():
|
||||
return parent
|
||||
|
||||
# 如果找不到,返回当前目录的父目录
|
||||
return current_dir.parent
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""
|
||||
获取数据目录
|
||||
|
||||
Returns:
|
||||
数据目录 Path 对象
|
||||
"""
|
||||
settings = get_settings()
|
||||
data_dir = Path(settings.database_url).parent
|
||||
|
||||
# 如果 database_url 是 SQLite URL,提取路径
|
||||
if settings.database_url.startswith("sqlite:///"):
|
||||
db_path = settings.database_url[10:] # 移除 "sqlite:///"
|
||||
data_dir = Path(db_path).parent
|
||||
|
||||
# 确保目录存在
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return data_dir
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""
|
||||
获取日志目录
|
||||
|
||||
Returns:
|
||||
日志目录 Path 对象
|
||||
"""
|
||||
settings = get_settings()
|
||||
log_file = Path(settings.log_file)
|
||||
log_dir = log_file.parent
|
||||
|
||||
# 确保目录存在
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return log_dir
|
||||
|
||||
|
||||
def format_duration(seconds: int) -> str:
|
||||
"""
|
||||
格式化持续时间
|
||||
|
||||
Args:
|
||||
seconds: 秒数
|
||||
|
||||
Returns:
|
||||
格式化的持续时间字符串
|
||||
"""
|
||||
if seconds < 60:
|
||||
return f"{seconds}秒"
|
||||
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
if minutes < 60:
|
||||
return f"{minutes}分{seconds}秒"
|
||||
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
if hours < 24:
|
||||
return f"{hours}小时{minutes}分"
|
||||
|
||||
days, hours = divmod(hours, 24)
|
||||
return f"{days}天{hours}小时"
|
||||
|
||||
|
||||
def mask_sensitive_data(data: Union[str, Dict, List], mask_char: str = "*") -> Union[str, Dict, List]:
|
||||
"""
|
||||
掩码敏感数据
|
||||
|
||||
Args:
|
||||
data: 要掩码的数据
|
||||
mask_char: 掩码字符
|
||||
|
||||
Returns:
|
||||
掩码后的数据
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
# 如果是邮箱,掩码中间部分
|
||||
if "@" in data:
|
||||
local, domain = data.split("@", 1)
|
||||
if len(local) > 2:
|
||||
masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
|
||||
else:
|
||||
masked_local = mask_char * len(local)
|
||||
return f"{masked_local}@{domain}"
|
||||
|
||||
# 如果是 token 或密钥,掩码大部分内容
|
||||
if len(data) > 10:
|
||||
return data[:4] + mask_char * (len(data) - 8) + data[-4:]
|
||||
return mask_char * len(data)
|
||||
|
||||
elif isinstance(data, dict):
|
||||
masked_dict = {}
|
||||
for key, value in data.items():
|
||||
# 敏感字段名
|
||||
sensitive_keys = ["password", "token", "secret", "key", "auth", "credential"]
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys):
|
||||
masked_dict[key] = mask_sensitive_data(value, mask_char)
|
||||
else:
|
||||
masked_dict[key] = value
|
||||
return masked_dict
|
||||
|
||||
elif isinstance(data, list):
|
||||
return [mask_sensitive_data(item, mask_char) for item in data]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def calculate_md5(data: Union[str, bytes]) -> str:
|
||||
"""
|
||||
计算 MD5 哈希
|
||||
|
||||
Args:
|
||||
data: 要哈希的数据
|
||||
|
||||
Returns:
|
||||
MD5 哈希字符串
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
|
||||
def calculate_sha256(data: Union[str, bytes]) -> str:
|
||||
"""
|
||||
计算 SHA256 哈希
|
||||
|
||||
Args:
|
||||
data: 要哈希的数据
|
||||
|
||||
Returns:
|
||||
SHA256 哈希字符串
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def base64_encode(data: Union[str, bytes]) -> str:
|
||||
"""Base64 编码"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
return base64.b64encode(data).decode('utf-8')
|
||||
|
||||
|
||||
def base64_decode(data: str) -> str:
|
||||
"""Base64 解码"""
|
||||
try:
|
||||
decoded = base64.b64decode(data)
|
||||
return decoded.decode('utf-8')
|
||||
except (base64.binascii.Error, UnicodeDecodeError):
|
||||
return ""
|
||||
|
||||
|
||||
class Timer:
|
||||
"""计时器上下文管理器"""
|
||||
|
||||
def __init__(self, name: str = "操作"):
|
||||
self.name = name
|
||||
self.start_time = None
|
||||
self.elapsed = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.elapsed = time.time() - self.start_time
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.debug(f"{self.name} 耗时: {self.elapsed:.2f} 秒")
|
||||
|
||||
def get_elapsed(self) -> float:
|
||||
"""获取经过的时间(秒)"""
|
||||
if self.elapsed is not None:
|
||||
return self.elapsed
|
||||
if self.start_time is not None:
|
||||
return time.time() - self.start_time
|
||||
return 0.0
|
||||
Reference in New Issue
Block a user