mirror of
https://github.com/cnlimiter/codex-register.git
synced 2026-05-12 02:21:07 +08:00
336 lines
9.5 KiB
Python
336 lines
9.5 KiB
Python
"""
|
||
OpenAI OAuth 授权模块
|
||
从 main.py 中提取的 OAuth 相关函数
|
||
"""
|
||
|
||
import base64
|
||
import hashlib
|
||
import json
|
||
import secrets
|
||
import time
|
||
import urllib.error
|
||
import urllib.parse
|
||
import urllib.request
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, Optional
|
||
|
||
from ..config.constants import (
|
||
OAUTH_CLIENT_ID,
|
||
OAUTH_AUTH_URL,
|
||
OAUTH_TOKEN_URL,
|
||
OAUTH_REDIRECT_URI,
|
||
OAUTH_SCOPE,
|
||
)
|
||
|
||
|
||
def _b64url_no_pad(raw: bytes) -> str:
|
||
"""Base64 URL 编码(无填充)"""
|
||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||
|
||
|
||
def _sha256_b64url_no_pad(s: str) -> str:
|
||
"""SHA256 哈希后 Base64 URL 编码"""
|
||
return _b64url_no_pad(hashlib.sha256(s.encode("ascii")).digest())
|
||
|
||
|
||
def _random_state(nbytes: int = 16) -> str:
|
||
"""生成随机 state"""
|
||
return secrets.token_urlsafe(nbytes)
|
||
|
||
|
||
def _pkce_verifier() -> str:
|
||
"""生成 PKCE code_verifier"""
|
||
return secrets.token_urlsafe(64)
|
||
|
||
|
||
def _parse_callback_url(callback_url: str) -> Dict[str, str]:
|
||
"""解析回调 URL"""
|
||
candidate = callback_url.strip()
|
||
if not candidate:
|
||
return {"code": "", "state": "", "error": "", "error_description": ""}
|
||
|
||
if "://" not in candidate:
|
||
if candidate.startswith("?"):
|
||
candidate = f"http://localhost{candidate}"
|
||
elif any(ch in candidate for ch in "/?#") or ":" in candidate:
|
||
candidate = f"http://{candidate}"
|
||
elif "=" in candidate:
|
||
candidate = f"http://localhost/?{candidate}"
|
||
|
||
parsed = urllib.parse.urlparse(candidate)
|
||
query = urllib.parse.parse_qs(parsed.query, keep_blank_values=True)
|
||
fragment = urllib.parse.parse_qs(parsed.fragment, keep_blank_values=True)
|
||
|
||
for key, values in fragment.items():
|
||
if key not in query or not query[key] or not (query[key][0] or "").strip():
|
||
query[key] = values
|
||
|
||
def get1(k: str) -> str:
|
||
v = query.get(k, [""])
|
||
return (v[0] or "").strip()
|
||
|
||
code = get1("code")
|
||
state = get1("state")
|
||
error = get1("error")
|
||
error_description = get1("error_description")
|
||
|
||
if code and not state and "#" in code:
|
||
code, state = code.split("#", 1)
|
||
|
||
if not error and error_description:
|
||
error, error_description = error_description, ""
|
||
|
||
return {
|
||
"code": code,
|
||
"state": state,
|
||
"error": error,
|
||
"error_description": error_description,
|
||
}
|
||
|
||
|
||
def _jwt_claims_no_verify(id_token: str) -> Dict[str, Any]:
|
||
"""解析 JWT ID Token(不验证签名)"""
|
||
if not id_token or id_token.count(".") < 2:
|
||
return {}
|
||
payload_b64 = id_token.split(".")[1]
|
||
pad = "=" * ((4 - (len(payload_b64) % 4)) % 4)
|
||
try:
|
||
payload = base64.urlsafe_b64decode((payload_b64 + pad).encode("ascii"))
|
||
return json.loads(payload.decode("utf-8"))
|
||
except Exception:
|
||
return {}
|
||
|
||
|
||
def _decode_jwt_segment(seg: str) -> Dict[str, Any]:
|
||
"""解码 JWT 片段"""
|
||
raw = (seg or "").strip()
|
||
if not raw:
|
||
return {}
|
||
pad = "=" * ((4 - (len(raw) % 4)) % 4)
|
||
try:
|
||
decoded = base64.urlsafe_b64decode((raw + pad).encode("ascii"))
|
||
return json.loads(decoded.decode("utf-8"))
|
||
except Exception:
|
||
return {}
|
||
|
||
|
||
def _to_int(v: Any) -> int:
|
||
"""转换为整数"""
|
||
try:
|
||
return int(v)
|
||
except (TypeError, ValueError):
|
||
return 0
|
||
|
||
|
||
def _post_form(url: str, data: Dict[str, str], timeout: int = 30) -> Dict[str, Any]:
|
||
"""发送 POST 表单请求"""
|
||
body = urllib.parse.urlencode(data).encode("utf-8")
|
||
req = urllib.request.Request(
|
||
url,
|
||
data=body,
|
||
method="POST",
|
||
headers={
|
||
"Content-Type": "application/x-www-form-urlencoded",
|
||
"Accept": "application/json",
|
||
},
|
||
)
|
||
try:
|
||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||
raw = resp.read()
|
||
if resp.status != 200:
|
||
raise RuntimeError(
|
||
f"token exchange failed: {resp.status}: {raw.decode('utf-8', 'replace')}"
|
||
)
|
||
return json.loads(raw.decode("utf-8"))
|
||
except urllib.error.HTTPError as exc:
|
||
raw = exc.read()
|
||
raise RuntimeError(
|
||
f"token exchange failed: {exc.code}: {raw.decode('utf-8', 'replace')}"
|
||
) from exc
|
||
|
||
|
||
@dataclass(frozen=True)
|
||
class OAuthStart:
|
||
"""OAuth 开始信息"""
|
||
auth_url: str
|
||
state: str
|
||
code_verifier: str
|
||
redirect_uri: str
|
||
|
||
|
||
def generate_oauth_url(
|
||
*,
|
||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||
scope: str = OAUTH_SCOPE,
|
||
client_id: str = OAUTH_CLIENT_ID
|
||
) -> OAuthStart:
|
||
"""
|
||
生成 OAuth 授权 URL
|
||
|
||
Args:
|
||
redirect_uri: 回调地址
|
||
scope: 权限范围
|
||
client_id: OpenAI Client ID
|
||
|
||
Returns:
|
||
OAuthStart 对象,包含授权 URL 和必要参数
|
||
"""
|
||
state = _random_state()
|
||
code_verifier = _pkce_verifier()
|
||
code_challenge = _sha256_b64url_no_pad(code_verifier)
|
||
|
||
params = {
|
||
"client_id": client_id,
|
||
"response_type": "code",
|
||
"redirect_uri": redirect_uri,
|
||
"scope": scope,
|
||
"state": state,
|
||
"code_challenge": code_challenge,
|
||
"code_challenge_method": "S256",
|
||
"prompt": "login",
|
||
"id_token_add_organizations": "true",
|
||
"codex_cli_simplified_flow": "true",
|
||
}
|
||
auth_url = f"{OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||
return OAuthStart(
|
||
auth_url=auth_url,
|
||
state=state,
|
||
code_verifier=code_verifier,
|
||
redirect_uri=redirect_uri,
|
||
)
|
||
|
||
|
||
def submit_callback_url(
|
||
*,
|
||
callback_url: str,
|
||
expected_state: str,
|
||
code_verifier: str,
|
||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||
client_id: str = OAUTH_CLIENT_ID,
|
||
token_url: str = OAUTH_TOKEN_URL
|
||
) -> str:
|
||
"""
|
||
处理 OAuth 回调 URL,获取访问令牌
|
||
|
||
Args:
|
||
callback_url: 回调 URL
|
||
expected_state: 预期的 state 值
|
||
code_verifier: PKCE code_verifier
|
||
redirect_uri: 回调地址
|
||
client_id: OpenAI Client ID
|
||
token_url: Token 交换地址
|
||
|
||
Returns:
|
||
包含访问令牌等信息的 JSON 字符串
|
||
|
||
Raises:
|
||
RuntimeError: OAuth 错误
|
||
ValueError: 缺少必要参数或 state 不匹配
|
||
"""
|
||
cb = _parse_callback_url(callback_url)
|
||
if cb["error"]:
|
||
desc = cb["error_description"]
|
||
raise RuntimeError(f"oauth error: {cb['error']}: {desc}".strip())
|
||
|
||
if not cb["code"]:
|
||
raise ValueError("callback url missing ?code=")
|
||
if not cb["state"]:
|
||
raise ValueError("callback url missing ?state=")
|
||
if cb["state"] != expected_state:
|
||
raise ValueError("state mismatch")
|
||
|
||
token_resp = _post_form(
|
||
token_url,
|
||
{
|
||
"grant_type": "authorization_code",
|
||
"client_id": client_id,
|
||
"code": cb["code"],
|
||
"redirect_uri": redirect_uri,
|
||
"code_verifier": code_verifier,
|
||
},
|
||
)
|
||
|
||
access_token = (token_resp.get("access_token") or "").strip()
|
||
refresh_token = (token_resp.get("refresh_token") or "").strip()
|
||
id_token = (token_resp.get("id_token") or "").strip()
|
||
expires_in = _to_int(token_resp.get("expires_in"))
|
||
|
||
claims = _jwt_claims_no_verify(id_token)
|
||
email = str(claims.get("email") or "").strip()
|
||
auth_claims = claims.get("https://api.openai.com/auth") or {}
|
||
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
|
||
|
||
now = int(time.time())
|
||
expired_rfc3339 = time.strftime(
|
||
"%Y-%m-%dT%H:%M:%SZ", time.gmtime(now + max(expires_in, 0))
|
||
)
|
||
now_rfc3339 = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now))
|
||
|
||
config = {
|
||
"id_token": id_token,
|
||
"access_token": access_token,
|
||
"refresh_token": refresh_token,
|
||
"account_id": account_id,
|
||
"last_refresh": now_rfc3339,
|
||
"email": email,
|
||
"type": "codex",
|
||
"expired": expired_rfc3339,
|
||
}
|
||
|
||
return json.dumps(config, ensure_ascii=False, separators=(",", ":"))
|
||
|
||
|
||
class OAuthManager:
|
||
"""OAuth 管理器"""
|
||
|
||
def __init__(
|
||
self,
|
||
client_id: str = OAUTH_CLIENT_ID,
|
||
auth_url: str = OAUTH_AUTH_URL,
|
||
token_url: str = OAUTH_TOKEN_URL,
|
||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||
scope: str = OAUTH_SCOPE
|
||
):
|
||
self.client_id = client_id
|
||
self.auth_url = auth_url
|
||
self.token_url = token_url
|
||
self.redirect_uri = redirect_uri
|
||
self.scope = scope
|
||
|
||
def start_oauth(self) -> OAuthStart:
|
||
"""开始 OAuth 流程"""
|
||
return generate_oauth_url(
|
||
redirect_uri=self.redirect_uri,
|
||
scope=self.scope,
|
||
client_id=self.client_id
|
||
)
|
||
|
||
def handle_callback(
|
||
self,
|
||
callback_url: str,
|
||
expected_state: str,
|
||
code_verifier: str
|
||
) -> Dict[str, Any]:
|
||
"""处理 OAuth 回调"""
|
||
result_json = submit_callback_url(
|
||
callback_url=callback_url,
|
||
expected_state=expected_state,
|
||
code_verifier=code_verifier,
|
||
redirect_uri=self.redirect_uri,
|
||
client_id=self.client_id,
|
||
token_url=self.token_url
|
||
)
|
||
return json.loads(result_json)
|
||
|
||
def extract_account_info(self, id_token: str) -> Dict[str, Any]:
|
||
"""从 ID Token 中提取账户信息"""
|
||
claims = _jwt_claims_no_verify(id_token)
|
||
email = str(claims.get("email") or "").strip()
|
||
auth_claims = claims.get("https://api.openai.com/auth") or {}
|
||
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
|
||
|
||
return {
|
||
"email": email,
|
||
"account_id": account_id,
|
||
"claims": claims
|
||
} |