FastAPI高级特性与扩展:解锁框架的全部潜力
当你已经掌握了FastAPI的基础和中级功能,接下来就是探索它的高级特性和扩展能力的时候了。这些特性将让你的应用从"能用"升级到"卓越"。
11.1 自定义路由类:超越标准路由的灵活性
为什么需要自定义路由类?
标准的路由器(APIRouter)已经很强大了,但有些场景需要更精细的控制。自定义路由类允许你:
- 统一添加中间件到特定路由组
- 自动注入依赖项
- 实现路由级别的缓存策略
- 自定义路由匹配逻辑
基础自定义路由实现
python
# app/routing/custom_router.py
from fastapi import APIRouter, Depends, Request
from fastapi.routing import APIRoute
from typing import Callable, List, Optional, Dict, Any
import time
import logging
from functools import wraps
logger = logging.getLogger(__name__)
class LoggingRoute(APIRoute):
"""带有详细日志记录的自定义路由"""
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Any:
# 请求前日志
start_time = time.time()
logger.info(
f"开始处理请求: {request.method} {request.url.path}",
extra={
"method": request.method,
"path": request.url.path,
"client_ip": request.client.host if request.client else None
}
)
try:
# 处理请求
response = await original_route_handler(request)
duration = time.time() - start_time
# 请求后日志
logger.info(
f"请求处理完成: {request.method} {request.url.path} - {response.status_code}",
extra={
"method": request.method,
"path": request.url.path,
"status_code": response.status_code,
"duration": duration
}
)
# 添加性能头
response.headers["X-Response-Time"] = f"{duration:.3f}s"
return response
except Exception as exc:
duration = time.time() - start_time
logger.error(
f"请求处理失败: {request.method} {request.url.path}",
extra={
"method": request.method,
"path": request.url.path,
"duration": duration,
"error": str(exc)
},
exc_info=True
)
raise
return custom_route_handler
class CacheRoute(APIRoute):
"""支持缓存的自定义路由"""
def __init__(self, *args, cache_ttl: int = 300, **kwargs):
super().__init__(*args, **kwargs)
self.cache_ttl = cache_ttl
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def cached_route_handler(request: Request) -> Any:
from app.core.cache import RedisCache
# 构建缓存键
cache_key = self._build_cache_key(request)
# 尝试从缓存获取
cache = RedisCache()
cached_response = await cache.get(cache_key)
if cached_response is not None:
# 返回缓存的响应
from fastapi.responses import Response
return Response(
content=cached_response["content"],
status_code=cached_response["status_code"],
headers=cached_response["headers"],
media_type=cached_response["media_type"]
)
# 执行原始处理
response = await original_route_handler(request)
# 缓存响应
if 200 <= response.status_code < 300:
cache_data = {
"content": response.body,
"status_code": response.status_code,
"headers": dict(response.headers),
"media_type": response.media_type
}
await cache.set(cache_key, cache_data, self.cache_ttl)
return response
return cached_route_handler
def _build_cache_key(self, request: Request) -> str:
"""构建缓存键"""
import hashlib
import json
key_parts = [
request.method,
request.url.path,
request.url.query,
json.dumps(dict(request.headers), sort_keys=True)
]
key_string = ":".join(str(p) for p in key_parts)
return f"route_cache:{hashlib.md5(key_string.encode()).hexdigest()}"使用自定义路由类
python
# app/api/advanced.py
from fastapi import APIRouter, FastAPI
from app.routing.custom_router import LoggingRoute, CacheRoute
# 创建支持日志的路由器
logging_router = APIRouter(route_class=LoggingRoute)
@logging_router.get("/logged")
async def logged_endpoint():
"""这个端点会自动记录详细的日志"""
return {"message": "This request is logged"}
# 创建支持缓存的路由器
cache_router = APIRouter(route_class=CacheRoute)
@cache_router.get("/cached", cache_ttl=60)
async def cached_endpoint():
"""这个端点会被缓存60秒"""
import time
time.sleep(1) # 模拟耗时操作
return {"message": f"Cached response at {time.time()}"}
# 高级路由:需要认证的自定义路由
from fastapi import HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
security = HTTPBearer()
class AuthenticatedRoute(APIRoute):
"""需要认证的自定义路由"""
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def authenticated_route_handler(
request: Request,
credentials: HTTPAuthorizationCredentials = Security(security)
) -> Any:
# 验证token
token = credentials.credentials
# 这里应该是实际的token验证逻辑
if token != "secret-token":
raise HTTPException(
status_code=401,
detail="Invalid authentication credentials"
)
# 将用户信息添加到请求状态
request.state.user = {"id": 1, "username": "admin"}
return await original_route_handler(request)
return authenticated_route_handler
# 创建应用并注册路由
app = FastAPI()
# 动态路由注册
def register_routes_with_prefix(router: APIRouter, prefix: str):
"""为路由自动添加前缀的装饰器"""
for route in router.routes:
route.path = prefix + route.path
return router
# 使用示例
v1_router = APIRouter()
v1_router.include_router(logging_router, prefix="/v1")
v1_router.include_router(cache_router, prefix="/v1")
app.include_router(v1_router)11.2 响应编码与内容协商
内容协商基础
内容协商允许客户端和服务器就响应的格式进行协商。FastAPI原生支持通过Accept头进行内容协商。
python
# app/api/content_negotiation.py
from fastapi import APIRouter, Request
from fastapi.responses import (
Response,
JSONResponse,
HTMLResponse,
PlainTextResponse,
ORJSONResponse,
UJSONResponse
)
from typing import Optional, Dict, Any
import json
import xml.etree.ElementTree as ET
from xml.dom import minidom
router = APIRouter()
# 基础内容协商
@router.get("/negotiate")
async def negotiate_content(request: Request):
"""根据Accept头返回不同格式的响应"""
accept_header = request.headers.get("Accept", "")
data = {
"message": "Hello World",
"timestamp": "2024-01-15T10:30:00Z",
"version": "1.0"
}
if "application/xml" in accept_header:
# XML响应
root = ET.Element("response")
for key, value in data.items():
child = ET.SubElement(root, key)
child.text = str(value)
xml_str = ET.tostring(root, encoding="unicode")
return Response(
content=xml_str,
media_type="application/xml"
)
elif "text/plain" in accept_header:
# 纯文本响应
text_content = "\n".join(f"{k}: {v}" for k, v in data.items())
return PlainTextResponse(content=text_content)
elif "text/html" in accept_header:
# HTML响应
html_content = f"""
<!DOCTYPE html>
<html>
<head><title>API Response</title></head>
<body>
<h1>API Response</h1>
<ul>
<li>Message: {data['message']}</li>
<li>Timestamp: {data['timestamp']}</li>
<li>Version: {data['version']}</li>
</ul>
</body>
</html>
"""
return HTMLResponse(content=html_content)
else:
# 默认JSON响应
return JSONResponse(content=data)
# 自定义响应编码器
class CustomEncoder(json.JSONEncoder):
"""自定义JSON编码器"""
def default(self, obj):
import datetime
from decimal import Decimal
from uuid import UUID
if isinstance(obj, datetime.datetime):
return obj.isoformat()
elif isinstance(obj, datetime.date):
return obj.isoformat()
elif isinstance(obj, Decimal):
return float(obj)
elif isinstance(obj, UUID):
return str(obj)
elif hasattr(obj, "__dict__"):
return obj.__dict__
return super().default(obj)
# 使用自定义编码器的响应
@router.get("/custom-json")
async def custom_json_response():
"""使用自定义JSON编码器的响应"""
import datetime
from decimal import Decimal
from uuid import uuid4
data = {
"id": uuid4(),
"name": "Product",
"price": Decimal("99.99"),
"created_at": datetime.datetime.now(),
"metadata": {
"category": "electronics",
"tags": ["new", "popular"]
}
}
# 手动编码
json_str = json.dumps(data, cls=CustomEncoder, indent=2)
return Response(
content=json_str,
media_type="application/json",
headers={"X-Custom-Encoded": "true"}
)高级内容协商策略
python
# app/negotiation/strategies.py
from typing import List, Tuple, Optional
from fastapi.responses import Response
import json
class ContentNegotiator:
"""智能内容协商器"""
def __init__(self):
self.supported_types = [
("application/json", 1.0),
("application/xml", 0.9),
("text/html", 0.8),
("text/plain", 0.7),
("application/yaml", 0.6)
]
def negotiate(self, accept_header: str) -> str:
"""协商最佳内容类型"""
if not accept_header:
return "application/json"
# 解析Accept头
client_preferences = self._parse_accept_header(accept_header)
# 找到最佳匹配
best_match = None
best_score = -1.0
for client_type, client_q in client_preferences:
for server_type, server_q in self.supported_types:
if self._type_matches(client_type, server_type):
score = client_q * server_q
if score > best_score:
best_score = score
best_match = server_type
return best_match or "application/json"
def _parse_accept_header(self, header: str) -> List[Tuple[str, float]]:
"""解析Accept头"""
result = []
for part in header.split(","):
part = part.strip()
if ";" in part:
type_part, q_part = part.split(";", 1)
type_part = type_part.strip()
# 提取q值
q_value = 1.0
for param in q_part.split(";"):
param = param.strip()
if param.startswith("q="):
try:
q_value = float(param[2:])
except ValueError:
q_value = 1.0
result.append((type_part, q_value))
else:
result.append((part.strip(), 1.0))
# 按q值降序排序
result.sort(key=lambda x: x[1], reverse=True)
return result
def _type_matches(self, client_type: str, server_type: str) -> bool:
"""检查类型是否匹配"""
if client_type == "*/*" or server_type == "*/*":
return True
if "/" in client_type and "/" in server_type:
c_main, c_sub = client_type.split("/")
s_main, s_sub = server_type.split("/")
if c_sub == "*" or s_sub == "*":
return c_main == s_main
else:
return client_type == server_type
return client_type == server_type
# 使用智能协商器的端点
@router.get("/smart-negotiate")
async def smart_negotiate(request: Request):
"""智能内容协商端点"""
negotiator = ContentNegotiator()
best_type = negotiator.negotiate(request.headers.get("Accept", ""))
data = {
"message": "Smart content negotiation",
"negotiated_type": best_type,
"available_types": [t for t, _ in negotiator.supported_types]
}
if best_type == "application/xml":
# 转换为XML
import xml.etree.ElementTree as ET
root = ET.Element("response")
for key, value in data.items():
child = ET.SubElement(root, key)
child.text = str(value)
xml_str = ET.tostring(root, encoding="unicode")
return Response(content=xml_str, media_type=best_type)
elif best_type == "text/yaml":
# 转换为YAML
import yaml
yaml_str = yaml.dump(data, default_flow_style=False)
return Response(content=yaml_str, media_type=best_type)
else:
# 默认JSON
return Response(
content=json.dumps(data, indent=2),
media_type="application/json"
)流式响应和服务器推送
python
# app/api/streaming.py
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
import asyncio
import json
from typing import AsyncGenerator
router = APIRouter()
# SSE (Server-Sent Events) 示例
@router.get("/sse")
async def server_sent_events():
"""服务器推送事件流"""
async def event_generator():
for i in range(10):
# 生成事件数据
event_data = {
"id": i,
"event": "message",
"data": f"Event {i} at {time.time()}",
"retry": 3000 # 重连时间
}
# SSE格式: "event: {event}\ndata: {data}\n\n"
yield f"id: {event_data['id']}\n"
yield f"event: {event_data['event']}\n"
yield f"data: {json.dumps(event_data['data'])}\n"
yield f"retry: {event_data['retry']}\n\n"
await asyncio.sleep(1) # 每秒发送一个事件
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用Nginx缓冲
}
)
# 大文件流式下载
@router.get("/stream-large-file")
async def stream_large_file():
"""流式下载大文件"""
async def file_chunk_generator():
chunk_size = 1024 * 1024 # 1MB chunks
# 模拟生成大文件
for i in range(100): # 100MB文件
chunk = b"x" * chunk_size
yield chunk
await asyncio.sleep(0.01) # 稍微延迟,模拟I/O
headers = {
"Content-Disposition": "attachment; filename=large_file.bin",
"Content-Length": str(100 * 1024 * 1024) # 100MB
}
return StreamingResponse(
file_chunk_generator(),
media_type="application/octet-stream",
headers=headers
)
# 实时日志流
@router.get("/logs/stream")
async def stream_logs():
"""实时日志流"""
import sys
import io
class TeeStream(io.StringIO):
"""同时写入到多个流的类"""
def __init__(self, *streams):
super().__init__()
self.streams = streams
def write(self, text):
super().write(text)
for stream in self.streams:
stream.write(text)
return len(text)
# 创建双向流
async def log_stream():
import logging
# 设置日志流
log_capture_string = io.StringIO()
# 创建日志处理器
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
# 创建格式化器
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
# 添加到根日志器
logging.getLogger().addHandler(ch)
try:
# 持续读取日志
while True:
content = log_capture_string.getvalue()
if content:
yield content
log_capture_string.truncate(0)
log_capture_string.seek(0)
await asyncio.sleep(0.1)
finally:
# 清理
logging.getLogger().removeHandler(ch)
return StreamingResponse(
log_stream(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"X-Content-Type-Options": "nosniff"
}
)11.3 WebHooks实现:构建事件驱动的API
WebHooks基础架构
python
# app/webhooks/core.py
from typing import Dict, List, Optional, Callable, Any
from pydantic import BaseModel, HttpUrl
from datetime import datetime
import asyncio
import hashlib
import hmac
import json
from enum import Enum
class WebhookEvent(str, Enum):
"""Webhook事件类型"""
USER_CREATED = "user.created"
USER_UPDATED = "user.updated"
ORDER_CREATED = "order.created"
ORDER_UPDATED = "order.updated"
PAYMENT_SUCCESS = "payment.success"
PAYMENT_FAILED = "payment.failed"
class WebhookSubscription(BaseModel):
"""Webhook订阅模型"""
id: str
url: HttpUrl
events: List[WebhookEvent]
secret: Optional[str] = None
enabled: bool = True
created_at: datetime
updated_at: datetime
retry_count: int = 0
last_delivery: Optional[datetime] = None
class WebhookPayload(BaseModel):
"""Webhook负载"""
event: WebhookEvent
data: Dict[str, Any]
timestamp: datetime
webhook_id: str
class WebhookManager:
"""Webhook管理器"""
def __init__(self):
self.subscriptions: Dict[str, WebhookSubscription] = {}
self.queues: Dict[str, asyncio.Queue] = {}
self.workers: Dict[str, asyncio.Task] = {}
async def subscribe(self, subscription: WebhookSubscription):
"""添加订阅"""
self.subscriptions[subscription.id] = subscription
# 为每个订阅创建处理队列和工作线程
if subscription.id not in self.queues:
self.queues[subscription.id] = asyncio.Queue(maxsize=1000)
self.workers[subscription.id] = asyncio.create_task(
self._process_webhook_queue(subscription.id)
)
async def unsubscribe(self, webhook_id: str):
"""取消订阅"""
if webhook_id in self.subscriptions:
del self.subscriptions[webhook_id]
# 停止工作线程
if webhook_id in self.workers:
self.workers[webhook_id].cancel()
del self.workers[webhook_id]
del self.queues[webhook_id]
async def trigger(self, event: WebhookEvent, data: Dict[str, Any]):
"""触发Webhook事件"""
payload = WebhookPayload(
event=event,
data=data,
timestamp=datetime.now(),
webhook_id=self._generate_webhook_id()
)
# 找到订阅了此事件的所有Webhook
for subscription in self.subscriptions.values():
if event in subscription.events and subscription.enabled:
await self.queues[subscription.id].put(payload)
async def _process_webhook_queue(self, webhook_id: str):
"""处理Webhook队列"""
import httpx
while True:
try:
payload = await self.queues[webhook_id].get()
subscription = self.subscriptions[webhook_id]
# 发送Webhook
success = await self._deliver_webhook(
subscription,
payload
)
if success:
# 更新最后发送时间
subscription.last_delivery = datetime.now()
subscription.retry_count = 0
else:
# 失败重试
await self._retry_delivery(subscription, payload)
self.queues[webhook_id].task_done()
except asyncio.CancelledError:
break
except Exception as e:
print(f"Webhook processing error: {e}")
await asyncio.sleep(1)
async def _deliver_webhook(
self,
subscription: WebhookSubscription,
payload: WebhookPayload
) -> bool:
"""发送Webhook"""
import httpx
try:
async with httpx.AsyncClient(timeout=30.0) as client:
# 构建请求头
headers = {
"User-Agent": "FastAPI-Webhooks/1.0",
"Content-Type": "application/json",
"X-Webhook-Event": payload.event,
"X-Webhook-ID": payload.webhook_id,
"X-Webhook-Timestamp": payload.timestamp.isoformat()
}
# 添加签名
if subscription.secret:
signature = self._generate_signature(
subscription.secret,
payload.json()
)
headers["X-Webhook-Signature"] = signature
# 发送请求
response = await client.post(
str(subscription.url),
json=payload.dict(),
headers=headers
)
# 检查响应
return 200 <= response.status_code < 300
except Exception as e:
print(f"Webhook delivery failed: {e}")
return False
def _generate_signature(self, secret: str, payload: str) -> str:
"""生成Webhook签名"""
return hmac.new(
secret.encode(),
payload.encode(),
hashlib.sha256
).hexdigest()
def _generate_webhook_id(self) -> str:
"""生成Webhook ID"""
import uuid
return str(uuid.uuid4())
async def _retry_delivery(
self,
subscription: WebhookSubscription,
payload: WebhookPayload
):
"""重试发送"""
max_retries = 3
retry_delays = [1, 5, 15] # 秒
if subscription.retry_count < max_retries:
subscription.retry_count += 1
delay = retry_delays[subscription.retry_count - 1]
await asyncio.sleep(delay)
await self.queues[subscription.id].put(payload)WebHooks API端点
python
# app/api/webhooks.py
from fastapi import APIRouter, HTTPException, BackgroundTasks
from typing import List
import uuid
from datetime import datetime
router = APIRouter()
webhook_manager = WebhookManager()
# Webhook注册端点
@router.post("/webhooks", response_model=WebhookSubscription)
async def create_webhook(
url: str,
events: List[WebhookEvent],
secret: str = None
):
"""注册Webhook"""
webhook_id = str(uuid.uuid4())
subscription = WebhookSubscription(
id=webhook_id,
url=url,
events=events,
secret=secret,
created_at=datetime.now(),
updated_at=datetime.now(),
enabled=True
)
await webhook_manager.subscribe(subscription)
return subscription
@router.get("/webhooks", response_model=List[WebhookSubscription])
async def list_webhooks():
"""列出所有Webhook"""
return list(webhook_manager.subscriptions.values())
@router.delete("/webhooks/{webhook_id}")
async def delete_webhook(webhook_id: str):
"""删除Webhook"""
await webhook_manager.unsubscribe(webhook_id)
return {"message": "Webhook deleted"}
# Webhook测试端点
@router.post("/webhooks/{webhook_id}/test")
async def test_webhook(webhook_id: str):
"""测试Webhook"""
if webhook_id not in webhook_manager.subscriptions:
raise HTTPException(status_code=404, detail="Webhook not found")
# 触发测试事件
await webhook_manager.trigger(
WebhookEvent.USER_CREATED,
{
"user_id": 1,
"username": "test_user",
"email": "test@example.com",
"test": True
}
)
return {"message": "Test webhook triggered"}
# Webhook接收端点(用于接收其他服务的Webhook)
@router.post("/webhooks/receive")
async def receive_webhook(
background_tasks: BackgroundTasks,
x_webhook_signature: str = None,
x_webhook_event: str = None
):
"""接收外部Webhook"""
from fastapi import Request
import json
async def verify_and_process(request: Request):
body = await request.body()
# 验证签名(如果有)
if x_webhook_signature:
# 这里应该验证签名
pass
# 处理Webhook
data = json.loads(body)
# 根据事件类型处理
event_type = x_webhook_event or data.get("event")
if event_type == "payment.success":
# 处理支付成功
pass
elif event_type == "payment.failed":
# 处理支付失败
pass
return {"received": True}
# 在后台处理Webhook
background_tasks.add_task(verify_and_process)
return {"message": "Webhook received"}11.4 GraphQL集成:拥抱现代API查询语言
使用Strawberry集成GraphQL
python
# app/graphql/schema.py
import strawberry
from strawberry.fastapi import GraphQLRouter
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel
# GraphQL类型定义
@strawberry.type
class User:
id: strawberry.ID
username: str
email: str
created_at: datetime
is_active: bool = True
@strawberry.field
def posts(self, limit: Optional[int] = 10) -> List["Post"]:
"""获取用户的文章"""
# 这里应该是数据库查询
from .resolvers import get_user_posts
return get_user_posts(self.id, limit)
@strawberry.type
class Post:
id: strawberry.ID
title: str
content: str
author: User
created_at: datetime
updated_at: Optional[datetime] = None
@strawberry.field
def excerpt(self, length: Optional[int] = 100) -> str:
"""文章摘要"""
if len(self.content) <= length:
return self.content
return self.content[:length] + "..."
@strawberry.type
class Comment:
id: strawberry.ID
content: str
author: User
post: Post
created_at: datetime
# 输入类型(用于Mutation)
@strawberry.input
class UserInput:
username: str
email: str
password: str
@strawberry.input
class PostInput:
title: str
content: str
author_id: strawberry.ID
# 查询类型
@strawberry.type
class Query:
@strawberry.field
async def user(self, id: strawberry.ID) -> Optional[User]:
"""获取单个用户"""
from .resolvers import get_user_by_id
return await get_user_by_id(id)
@strawberry.field
async def users(
self,
skip: Optional[int] = 0,
limit: Optional[int] = 10,
search: Optional[str] = None
) -> List[User]:
"""获取用户列表"""
from .resolvers import get_users
return await get_users(skip, limit, search)
@strawberry.field
async def post(self, id: strawberry.ID) -> Optional[Post]:
"""获取单个文章"""
from .resolvers import get_post_by_id
return await get_post_by_id(id)
@strawberry.field
async def posts(
self,
skip: Optional[int] = 0,
limit: Optional[int] = 10,
author_id: Optional[strawberry.ID] = None
) -> List[Post]:
"""获取文章列表"""
from .resolvers import get_posts
return await get_posts(skip, limit, author_id)
# 变更类型
@strawberry.type
class Mutation:
@strawberry.mutation
async def create_user(self, user_input: UserInput) -> User:
"""创建用户"""
from .resolvers import create_user
return await create_user(user_input)
@strawberry.mutation
async def update_user(
self,
id: strawberry.ID,
username: Optional[str] = None,
email: Optional[str] = None
) -> User:
"""更新用户"""
from .resolvers import update_user
return await update_user(id, username, email)
@strawberry.mutation
async def create_post(self, post_input: PostInput) -> Post:
"""创建文章"""
from .resolvers import create_post
return await create_post(post_input)
@strawberry.mutation
async def delete_post(self, id: strawberry.ID) -> bool:
"""删除文章"""
from .resolvers import delete_post
return await delete_post(id)
# 创建Schema
schema = strawberry.Schema(query=Query, mutation=Mutation)
# 创建GraphQL路由
graphql_router = GraphQLRouter(
schema,
graphiql=True, # 启用GraphiQL界面
allow_queries_via_get=True # 允许通过GET请求查询
)
# 解析器实现
# app/graphql/resolvers.py
from typing import List, Optional
import asyncio
from datetime import datetime
import uuid
# 模拟数据库
users_db = {}
posts_db = {}
async def get_user_by_id(id: str) -> Optional[User]:
"""根据ID获取用户"""
await asyncio.sleep(0.1) # 模拟数据库延迟
return users_db.get(id)
async def get_users(skip: int = 0, limit: int = 10, search: str = None) -> List[User]:
"""获取用户列表"""
await asyncio.sleep(0.1)
users = list(users_db.values())
if search:
users = [u for u in users if search.lower() in u.username.lower()]
return users[skip:skip+limit]
async def create_user(user_input: UserInput) -> User:
"""创建用户"""
user_id = str(uuid.uuid4())
user = User(
id=user_id,
username=user_input.username,
email=user_input.email,
created_at=datetime.now()
)
users_db[user_id] = user
return user
async def get_post_by_id(id: str) -> Optional[Post]:
"""根据ID获取文章"""
await asyncio.sleep(0.1)
return posts_db.get(id)
async def get_posts(skip: int = 0, limit: int = 10, author_id: str = None) -> List[Post]:
"""获取文章列表"""
await asyncio.sleep(0.1)
posts = list(posts_db.values())
if author_id:
posts = [p for p in posts if p.author.id == author_id]
return posts[skip:skip+limit]
async def create_post(post_input: PostInput) -> Post:
"""创建文章"""
post_id = str(uuid.uuid4())
author = users_db.get(post_input.author_id)
if not author:
raise ValueError("Author not found")
post = Post(
id=post_id,
title=post_input.title,
content=post_input.content,
author=author,
created_at=datetime.now()
)
posts_db[post_id] = post
return post高级GraphQL特性
python
# app/graphql/advanced.py
import strawberry
from typing import Optional, Annotated
from strawberry.types import Info
from strawberry.permission import BasePermission
from strawberry.extensions import Extension
# GraphQL权限控制
class IsAuthenticated(BasePermission):
"""认证检查权限"""
message = "User is not authenticated"
def has_permission(self, source: Any, info: Info, **kwargs) -> bool:
# 从context中获取用户
user = info.context.get("user")
return user is not None
class IsAdmin(BasePermission):
"""管理员权限"""
message = "User is not admin"
def has_permission(self, source: Any, info: Info, **kwargs) -> bool:
user = info.context.get("user")
return user and user.get("role") == "admin"
@strawberry.type
class AdminQuery:
"""管理员查询"""
@strawberry.field(permission_classes=[IsAdmin])
async def all_users(self) -> List[User]:
"""获取所有用户(仅管理员)"""
from .resolvers import get_all_users
return await get_all_users()
@strawberry.field(permission_classes=[IsAdmin])
async def system_stats(self) -> Dict[str, Any]:
"""系统统计(仅管理员)"""
return {
"total_users": len(users_db),
"total_posts": len(posts_db),
"active_users": sum(1 for u in users_db.values() if u.is_active)
}
# GraphQL扩展(用于监控、日志等)
class QueryLoggingExtension(Extension):
"""查询日志扩展"""
def on_operation(self):
# 操作开始时
start_time = time.time()
self.execution_context.context["start_time"] = start_time
yield
# 操作结束时
end_time = time.time()
duration = end_time - start_time
# 记录日志
query = self.execution_context.query
variables = self.execution_context.variables
logger.info(
f"GraphQL query executed in {duration:.3f}s",
extra={
"query": query,
"variables": variables,
"duration": duration
}
)
class RateLimitExtension(Extension):
"""速率限制扩展"""
def __init__(self, max_operations: int = 100, window_seconds: int = 60):
self.max_operations = max_operations
self.window_seconds = window_seconds
self.operation_counts = {}
def on_operation(self):
client_ip = self.execution_context.context.get("client_ip")
if not client_ip:
yield
return
# 检查速率限制
current_time = time.time()
window_start = current_time - self.window_seconds
# 清理过期的记录
self.operation_counts[client_ip] = [
ts for ts in self.operation_counts.get(client_ip, [])
if ts > window_start
]
# 检查是否超过限制
if len(self.operation_counts[client_ip]) >= self.max_operations:
raise Exception("Rate limit exceeded")
# 记录操作
self.operation_counts[client_ip].append(current_time)
yield
# 联合类型和接口
@strawberry.interface
class SearchResult:
"""搜索结果接口"""
id: strawberry.ID
title: str
relevance: float
@strawberry.type
class UserSearchResult(SearchResult):
"""用户搜索结果"""
username: str
email: str
@strawberry.type
class PostSearchResult(SearchResult):
"""文章搜索结果"""
excerpt: str
author_username: str
@strawberry.type
class SearchQuery:
"""搜索查询"""
@strawberry.field
async def search(
self,
query: str,
limit: Optional[int] = 10
) -> List[SearchResult]:
"""搜索"""
results = []
# 搜索用户
for user in users_db.values():
if query.lower() in user.username.lower():
results.append(
UserSearchResult(
id=user.id,
title=user.username,
relevance=0.8,
username=user.username,
email=user.email
)
)
# 搜索文章
for post in posts_db.values():
if query.lower() in post.title.lower() or query.lower() in post.content.lower():
results.append(
PostSearchResult(
id=post.id,
title=post.title,
relevance=0.6,
excerpt=post.content[:100],
author_username=post.author.username
)
)
# 按相关性排序
results.sort(key=lambda x: x.relevance, reverse=True)
return results[:limit]
# 创建包含高级特性的Schema
advanced_schema = strawberry.Schema(
query=strawberry.type(
"AdvancedQuery",
(Query, AdminQuery, SearchQuery)
),
mutation=Mutation,
extensions=[
QueryLoggingExtension,
lambda: RateLimitExtension(max_operations=60)
],
types=[User, Post, Comment, UserSearchResult, PostSearchResult]
)11.5 自定义OpenAPI文档:打造专属API文档体验
自定义文档配置
python
# app/docs/customization.py
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from typing import Dict, Any, Optional
def custom_openapi(app: FastAPI) -> Dict[str, Any]:
"""自定义OpenAPI文档"""
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="My Awesome API",
version="2.0.0",
description="""
# My Awesome API
## 特性
- 🔥 高性能
- 🔐 安全认证
- 📊 实时监控
- 🔄 Webhook支持
## 快速开始
1. 获取API密钥
2. 使用Bearer认证
3. 开始调用API
[查看完整文档](https://docs.example.com)
""",
routes=app.routes,
)
# 自定义服务器
openapi_schema["servers"] = [
{
"url": "https://api.example.com",
"description": "生产环境"
},
{
"url": "https://staging-api.example.com",
"description": "测试环境"
},
{
"url": "http://localhost:8000",
"description": "开发环境"
}
]
# 添加安全方案
openapi_schema["components"]["securitySchemes"] = {
"BearerAuth": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "输入你的JWT令牌,格式: Bearer <token>"
},
"ApiKeyAuth": {
"type": "apiKey",
"in": "header",
"name": "X-API-Key",
"description": "API密钥认证"
}
}
# 添加标签元数据
openapi_schema["tags"] = [
{
"name": "users",
"description": "用户管理操作",
"externalDocs": {
"description": "了解更多",
"url": "https://docs.example.com/users"
}
},
{
"name": "products",
"description": "商品管理操作",
"externalDocs": {
"description": "商品API文档",
"url": "https://docs.example.com/products"
}
},
{
"name": "orders",
"description": "订单管理操作"
}
]
# 添加扩展
openapi_schema["x-logo"] = {
"url": "https://fastapi.tiangolo.com/img/logo-margin/logo-teal.png",
"backgroundColor": "#FFFFFF",
"altText": "API Logo"
}
openapi_schema["x-tagGroups"] = [
{
"name": "核心API",
"tags": ["users", "products", "orders"]
},
{
"name": "扩展API",
"tags": ["webhooks", "graphql", "files"]
}
]
app.openapi_schema = openapi_schema
return app.openapi_schema
# 自定义Swagger UI
def get_custom_swagger_ui_html(
*,
openapi_url: str,
title: str,
swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Optional[str] = None,
init_oauth: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> HTMLResponse:
"""自定义Swagger UI界面"""
html = f"""
<!DOCTYPE html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="{swagger_css_url}">
<link rel="shortcut icon" href="{swagger_favicon_url}">
<title>{title}</title>
<style>
html {{
box-sizing: border-box;
overflow: -moz-scrollbars-vertical;
overflow-y: scroll;
}}
*,
*:before,
*:after {{
box-sizing: inherit;
}}
body {{
margin: 0;
background: #fafafa;
}}
.topbar {{
background-color: #1e1e1e;
padding: 10px 0;
text-align: center;
color: white;
font-size: 1.2em;
}}
.topbar a {{
color: #61dafb;
text-decoration: none;
}}
.version {{
background-color: #61dafb;
color: #1e1e1e;
padding: 2px 8px;
border-radius: 10px;
font-size: 0.8em;
margin-left: 10px;
}}
</style>
</head>
<body>
<div class="topbar">
<span>{title} <span class="version">v2.0</span></span>
<span style="margin-left: 20px;">
<a href="/docs">Swagger UI</a> |
<a href="/redoc">ReDoc</a> |
<a href="/graphql">GraphQL</a>
</span>
</div>
<div id="swagger-ui"></div>
<script src="{swagger_js_url}"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({{
url: '{openapi_url}',
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
oauth2RedirectUrl: '{oauth2_redirect_url or ''}',
initOAuth: {json.dumps(init_oauth) if init_oauth else '{}'},
onComplete: function() {{
// 自定义完成回调
console.log("Swagger UI loaded");
// 添加自定义CSS
const style = document.createElement('style');
style.innerHTML = `
.opblock-tag {{
font-size: 16px;
margin: 0 0 10px;
font-family: sans-serif;
}}
.opblock-tag-section {{
background-color: #f8f9fa;
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
}}
`;
document.head.appendChild(style);
// 添加自定义按钮
const tryOutBtn = document.createElement('button');
tryOutBtn.innerHTML = '🚀 快速测试';
tryOutBtn.className = 'btn try-out__btn';
tryOutBtn.style.marginLeft = '10px';
tryOutBtn.onclick = function() {{
// 自动填充示例数据
const inputs = document.querySelectorAll('input[type="text"]');
inputs.forEach(input => {{
if (input.placeholder.includes('ID')) {{
input.value = '123';
}} else if (input.placeholder.includes('name')) {{
input.value = '示例名称';
}} else if (input.placeholder.includes('email')) {{
input.value = 'example@test.com';
}}
}});
}};
// 找到合适的位置添加按钮
const executeBtn = document.querySelector('.execute');
if (executeBtn) {{
executeBtn.parentNode.insertBefore(tryOutBtn, executeBtn);
}}
}}
}});
// 自定义主题
ui.initOAuth({{
clientId: 'your-client-id',
clientSecret: 'your-client-secret-if-required',
realm: 'your-realms',
appName: 'Swagger UI',
scopeSeparator: ' ',
additionalQueryStringParams: {{}},
useBasicAuthenticationWithAccessCodeGrant: false,
usePkceWithAuthorizationCodeGrant: false,
}});
</script>
</body>
</html>
"""
return HTMLResponse(html)API文档生成与导出
python
# app/docs/generator.py
from typing import Dict, Any, List
import json
import yaml
from datetime import datetime
from pathlib import Path
class APIDocumentationGenerator:
"""API文档生成器"""
def __init__(self, app: FastAPI):
self.app = app
self.schema = custom_openapi(app)
def generate_markdown(self) -> str:
"""生成Markdown格式的文档"""
md_lines = []
md_lines.append(f"# {self.schema['info']['title']}")
md_lines.append(f"\n{self.schema['info']['description']}\n")
# 服务器信息
md_lines.append("## 服务器")
for server in self.schema.get("servers", []):
md_lines.append(f"- **{server['description']}**: `{server['url']}`")
# 安全认证
md_lines.append("\n## 认证")
if "securitySchemes" in self.schema.get("components", {}):
for name, scheme in self.schema["components"]["securitySchemes"].items():
md_lines.append(f"\n### {name}")
md_lines.append(f"- 类型: {scheme['type']}")
md_lines.append(f"- 描述: {scheme.get('description', '')}")
# 标签分组
md_lines.append("\n## API端点")
# 按标签分组
endpoints_by_tag = {}
for path, methods in self.schema["paths"].items():
for method, details in methods.items():
tags = details.get("tags", ["default"])
for tag in tags:
if tag not in endpoints_by_tag:
endpoints_by_tag[tag] = []
endpoints_by_tag[tag].append({
"path": path,
"method": method.upper(),
"details": details
})
# 生成每个标签的文档
for tag, endpoints in endpoints_by_tag.items():
md_lines.append(f"\n### {tag}")
for endpoint in endpoints:
details = endpoint["details"]
md_lines.append(f"\n#### {endpoint['method']} {endpoint['path']}")
md_lines.append(f"\n{details.get('summary', '')}")
if "description" in details:
md_lines.append(f"\n{details['description']}")
# 参数
if "parameters" in details:
md_lines.append("\n**参数:**")
md_lines.append("| 名称 | 位置 | 类型 | 必填 | 描述 |")
md_lines.append("|------|------|------|------|------|")
for param in details["parameters"]:
md_lines.append(
f"| {param['name']} | {param['in']} | "
f"{param.get('schema', {}).get('type', 'string')} | "
f"{'是' if param.get('required', False) else '否'} | "
f"{param.get('description', '')} |"
)
# 请求体
if "requestBody" in details:
md_lines.append("\n**请求体:**")
content = details["requestBody"]["content"]
for media_type, schema_info in content.items():
md_lines.append(f"\n*{media_type}*:")
schema = schema_info.get("schema", {})
if "$ref" in schema:
ref_name = schema["$ref"].split("/")[-1]
md_lines.append(f" 使用: `{ref_name}`")
# 响应
if "responses" in details:
md_lines.append("\n**响应:**")
for status_code, response_info in details["responses"].items():
description = response_info.get("description", "")
md_lines.append(f"\n- **{status_code}**: {description}")
return "\n".join(md_lines)
def generate_postman_collection(self) -> Dict[str, Any]:
"""生成Postman集合"""
collection = {
"info": {
"name": self.schema["info"]["title"],
"description": self.schema["info"]["description"],
"schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json"
},
"item": []
}
# 按标签分组
endpoints_by_tag = {}
for path, methods in self.schema["paths"].items():
for method, details in methods.items():
tags = details.get("tags", ["default"])
for tag in tags:
if tag not in endpoints_by_tag:
endpoints_by_tag[tag] = []
endpoints_by_tag[tag].append({
"path": path,
"method": method.upper(),
"details": details
})
# 创建文件夹结构
for tag, endpoints in endpoints_by_tag.items():
folder = {
"name": tag,
"item": []
}
for endpoint in endpoints:
item = {
"name": endpoint["details"].get("summary", endpoint["path"]),
"request": {
"method": endpoint["method"],
"url": {
"raw": f"{{{{base_url}}}}{endpoint['path']}",
"host": ["{{base_url}}"],
"path": endpoint["path"].strip("/").split("/")
}
}
}
# 添加认证
security = endpoint["details"].get("security", [])
if security:
item["request"]["auth"] = {
"type": "bearer",
"bearer": [{"key": "token", "value": "{{api_token}}", "type": "string"}]
}
# 添加参数
if "parameters" in endpoint["details"]:
query_params = []
path_vars = []
for param in endpoint["details"]["parameters"]:
if param["in"] == "query":
query_params.append({
"key": param["name"],
"value": "",
"description": param.get("description", ""),
"disabled": not param.get("required", False)
})
elif param["in"] == "path":
path_vars.append({
"key": param["name"],
"value": "example",
"description": param.get("description", "")
})
if query_params:
item["request"]["url"]["query"] = query_params
# 更新路径变量
if path_vars:
for path_var in path_vars:
item["request"]["url"]["path"] = [
segment.replace(
f"{{{path_var['key']}}}",
path_var["value"]
)
for segment in item["request"]["url"]["path"]
]
folder["item"].append(item)
collection["item"].append(folder)
return collection
def export_all_formats(self, output_dir: str = "./docs"):
"""导出所有格式的文档"""
Path(output_dir).mkdir(exist_ok=True)
# 导出OpenAPI JSON
json_path = Path(output_dir) / "openapi.json"
with open(json_path, "w", encoding="utf-8") as f:
json.dump(self.schema, f, indent=2, ensure_ascii=False)
# 导出OpenAPI YAML
yaml_path = Path(output_dir) / "openapi.yaml"
with open(yaml_path, "w", encoding="utf-8") as f:
yaml.dump(self.schema, f, allow_unicode=True)
# 导出Markdown
md_path = Path(output_dir) / "api_documentation.md"
with open(md_path, "w", encoding="utf-8") as f:
f.write(self.generate_markdown())
# 导出Postman集合
postman_path = Path(output_dir) / "postman_collection.json"
with open(postman_path, "w", encoding="utf-8") as f:
json.dump(self.generate_postman_collection(), f, indent=2)
# 导出HTML
html_path = Path(output_dir) / "api_documentation.html"
self._generate_html_documentation(html_path)
return {
"json": str(json_path),
"yaml": str(yaml_path),
"markdown": str(md_path),
"postman": str(postman_path),
"html": str(html_path)
}
def _generate_html_documentation(self, output_path: Path):
"""生成HTML文档"""
import markdown
from markdown.extensions.toc import TocExtension
md_content = self.generate_markdown()
# 转换为HTML
html = markdown.markdown(
md_content,
extensions=[
TocExtension(toc_depth="2-4"),
'fenced_code',
'tables',
'codehilite'
],
output_format='html5'
)
# 添加HTML模板
full_html = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{self.schema['info']['title']} - API文档</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/github-markdown-css/5.1.0/github-markdown.min.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.7.0/styles/github.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.7.0/highlight.min.js"></script>
<script>hljs.highlightAll();</script>
<style>
body {{
margin: 0;
padding: 0;
background: #f6f8fa;
}}
.markdown-body {{
box-sizing: border-box;
min-width: 200px;
max-width: 980px;
margin: 0 auto;
padding: 45px;
}}
@media (max-width: 767px) {{
.markdown-body {{
padding: 15px;
}}
}}
.header {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
margin-bottom: 2rem;
}}
.header h1 {{
margin: 0;
font-size: 2.5rem;
}}
.header .version {{
background: rgba(255, 255, 255, 0.2);
padding: 0.2rem 0.8rem;
border-radius: 20px;
font-size: 0.9rem;
margin-left: 1rem;
}}
.download-links {{
background: white;
padding: 1rem;
border-radius: 8px;
margin: 1rem 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
.download-links a {{
margin-right: 1rem;
text-decoration: none;
color: #0366d6;
}}
</style>
</head>
<body>
<div class="header">
<h1>{self.schema['info']['title']} <span class="version">v{self.schema['info']['version']}</span></h1>
<p>{self.schema['info'].get('description', '').split('\\n')[0]}</p>
</div>
<div class="markdown-body">
<div class="download-links">
<strong>下载格式:</strong>
<a href="openapi.json">OpenAPI JSON</a>
<a href="openapi.yaml">OpenAPI YAML</a>
<a href="api_documentation.md">Markdown</a>
<a href="postman_collection.json">Postman</a>
</div>
{html}
<hr>
<footer style="text-align: center; color: #666; margin-top: 3rem; padding-top: 1rem; border-top: 1px solid #eee;">
<p>文档生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
<p>使用 <a href="https://fastapi.tiangolo.com" target="_blank">FastAPI</a> 生成</p>
</footer>
</div>
</body>
</html>
"""
with open(output_path, "w", encoding="utf-8") as f:
f.write(full_html)11.6 插件系统开发:构建可扩展的应用架构
插件系统架构
python
# app/plugins/core.py
from typing import Dict, List, Any, Optional, Callable, Type
from abc import ABC, abstractmethod
from fastapi import FastAPI, APIRouter
from pydantic import BaseModel
import importlib
import inspect
from pathlib import Path
class PluginConfig(BaseModel):
"""插件配置"""
name: str
version: str
enabled: bool = True
dependencies: List[str] = []
settings: Dict[str, Any] = {}
class PluginManifest(BaseModel):
"""插件清单"""
name: str
version: str
description: Optional[str] = None
author: Optional[str] = None
license: Optional[str] = None
dependencies: List[str] = []
routers: List[str] = []
middleware: List[str] = []
events: Dict[str, List[str]] = {}
commands: Dict[str, str] = {}
class BasePlugin(ABC):
"""插件基类"""
def __init__(self, app: FastAPI, config: PluginConfig):
self.app = app
self.config = config
self.router = APIRouter(prefix=f"/plugin/{config.name}")
@abstractmethod
def setup(self):
"""插件安装"""
pass
@abstractmethod
def teardown(self):
"""插件卸载"""
pass
def register_routes(self):
"""注册路由"""
self.app.include_router(self.router)
def register_middleware(self, middleware_class: Type):
"""注册中间件"""
self.app.add_middleware(middleware_class)
def register_event_handler(self, event_type: str, handler: Callable):
"""注册事件处理器"""
self.app.add_event_handler(event_type, handler)
class PluginManager:
"""插件管理器"""
def __init__(self, app: FastAPI):
self.app = app
self.plugins: Dict[str, BasePlugin] = {}
self.loaded_plugins: Dict[str, Any] = {}
# 插件目录
self.plugin_dir = Path("plugins")
self.plugin_dir.mkdir(exist_ok=True)
def load_plugin(self, plugin_name: str) -> bool:
"""加载插件"""
try:
# 导入插件模块
module_name = f"plugins.{plugin_name}"
plugin_module = importlib.import_module(module_name)
# 查找插件类
plugin_class = None
for name, obj in inspect.getmembers(plugin_module):
if (inspect.isclass(obj) and
issubclass(obj, BasePlugin) and
obj != BasePlugin):
plugin_class = obj
break
if not plugin_class:
print(f"未找到插件类: {plugin_name}")
return False
# 读取配置
config_path = self.plugin_dir / plugin_name / "config.json"
if config_path.exists():
import json
config_data = json.loads(config_path.read_text())
config = PluginConfig(**config_data)
else:
config = PluginConfig(name=plugin_name, version="1.0.0")
# 创建插件实例
plugin_instance = plugin_class(self.app, config)
# 检查依赖
for dep in config.dependencies:
if dep not in self.plugins:
print(f"插件 {plugin_name} 缺少依赖: {dep}")
return False
# 安装插件
plugin_instance.setup()
plugin_instance.register_routes()
self.plugins[plugin_name] = plugin_instance
self.loaded_plugins[plugin_name] = plugin_module
print(f"插件加载成功: {plugin_name}")
return True
except Exception as e:
print(f"插件加载失败 {plugin_name}: {e}")
return False
def unload_plugin(self, plugin_name: str) -> bool:
"""卸载插件"""
if plugin_name not in self.plugins:
return False
try:
plugin = self.plugins[plugin_name]
plugin.teardown()
del self.plugins[plugin_name]
del self.loaded_plugins[plugin_name]
print(f"插件卸载成功: {plugin_name}")
return True
except Exception as e:
print(f"插件卸载失败 {plugin_name}: {e}")
return False
def load_all_plugins(self):
"""加载所有插件"""
for plugin_dir in self.plugin_dir.iterdir():
if plugin_dir.is_dir() and (plugin_dir / "__init__.py").exists():
self.load_plugin(plugin_dir.name)
def get_plugin(self, plugin_name: str) -> Optional[BasePlugin]:
"""获取插件实例"""
return self.plugins.get(plugin_name)
def list_plugins(self) -> List[Dict[str, Any]]:
"""列出所有插件"""
result = []
for name, plugin in self.plugins.items():
result.append({
"name": name,
"version": plugin.config.version,
"enabled": plugin.config.enabled,
"dependencies": plugin.config.dependencies
})
return result插件示例:认证插件
python
# plugins/auth_plugin/__init__.py
from app.plugins.core import BasePlugin, PluginConfig
from fastapi import Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional
import jwt
class AuthPlugin(BasePlugin):
"""认证插件"""
def __init__(self, app, config: PluginConfig):
super().__init__(app, config)
self.secret_key = config.settings.get("secret_key", "default-secret-key")
self.security = HTTPBearer()
def setup(self):
"""安装插件"""
print(f"安装认证插件: {self.config.name}")
# 注册路由
@self.router.post("/login")
async def login(username: str, password: str):
# 验证用户
user = self.authenticate_user(username, password)
if not user:
raise HTTPException(status_code=401, detail="Invalid credentials")
# 生成token
token = self.generate_token(user)
return {"access_token": token, "token_type": "bearer"}
@self.router.get("/me")
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(self.security)
):
token = credentials.credentials
user = self.verify_token(token)
if not user:
raise HTTPException(status_code=401, detail="Invalid token")
return user
def teardown(self):
"""卸载插件"""
print(f"卸载认证插件: {self.config.name}")
def authenticate_user(self, username: str, password: str) -> Optional[dict]:
"""验证用户"""
# 这里应该是数据库查询
users = {
"admin": {"id": 1, "username": "admin", "role": "admin"},
"user": {"id": 2, "username": "user", "role": "user"}
}
if username in users and password == "password":
return users[username]
return None
def generate_token(self, user: dict) -> str:
"""生成JWT令牌"""
import time
payload = {
"sub": user["id"],
"username": user["username"],
"role": user["role"],
"iat": int(time.time()),
"exp": int(time.time()) + 3600 # 1小时过期
}
return jwt.encode(payload, self.secret_key, algorithm="HS256")
def verify_token(self, token: str) -> Optional[dict]:
"""验证JWT令牌"""
try:
payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
return {
"id": payload["sub"],
"username": payload["username"],
"role": payload["role"]
}
except jwt.PyJWTError:
return None
# 插件配置
# plugins/auth_plugin/config.json
{
"name": "auth",
"version": "1.0.0",
"description": "用户认证插件",
"enabled": true,
"dependencies": [],
"settings": {
"secret_key": "your-secret-key-here",
"token_expire_hours": 24
}
}插件示例:监控插件
python
# plugins/monitoring_plugin/__init__.py
from app.plugins.core import BasePlugin, PluginConfig
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
import psutil
import os
class MonitoringPlugin(BasePlugin):
"""监控插件"""
def __init__(self, app, config: PluginConfig):
super().__init__(app, config)
self.metrics = {
"request_count": 0,
"error_count": 0,
"total_response_time": 0
}
def setup(self):
"""安装插件"""
print(f"安装监控插件: {self.config.name}")
# 添加监控中间件
@self.app.middleware("http")
async def monitor_middleware(request: Request, call_next):
start_time = time.time()
try:
response = await call_next(request)
duration = time.time() - start_time
# 更新指标
self.metrics["request_count"] += 1
self.metrics["total_response_time"] += duration
# 添加性能头
response.headers["X-Response-Time"] = f"{duration:.3f}s"
return response
except Exception as e:
self.metrics["error_count"] += 1
raise
# 注册监控端点
@self.router.get("/metrics")
async def get_metrics():
"""获取监控指标"""
process = psutil.Process(os.getpid())
return {
"requests": {
"total": self.metrics["request_count"],
"errors": self.metrics["error_count"],
"avg_response_time": (
self.metrics["total_response_time"] /
max(self.metrics["request_count"], 1)
)
},
"system": {
"cpu_percent": process.cpu_percent(),
"memory_percent": process.memory_percent(),
"memory_rss": process.memory_info().rss,
"threads": process.num_threads()
},
"uptime": time.time() - process.create_time()
}
@self.router.get("/health")
async def health_check():
"""健康检查"""
return {
"status": "healthy",
"timestamp": time.time(),
"plugin": self.config.name
}
def teardown(self):
"""卸载插件"""
print(f"卸载监控插件: {self.config.name}")11.7 源码阅读与贡献指南
FastAPI源码结构解析
fastapi/
├── __init__.py # 主要导出
├── applications.py # FastAPI类定义
├── routing.py # 路由相关
├── datastructures.py # 数据结构
├── params.py # 参数处理
├── dependencies.py # 依赖注入系统
├── security.py # 安全相关
├── background.py # 后台任务
├── responses.py # 响应类
├── staticfiles.py # 静态文件
├── templating.py # 模板渲染
├── middleware # 中间件目录
│ ├── __init__.py
│ ├── cors.py
│ └── gzip.py
├── exceptions.py # 异常处理
├── openapi # OpenAPI相关
│ ├── __init__.py
│ ├── docs.py
│ ├── models.py
│ └── utils.py
└── types.py # 类型定义关键源码阅读
1. FastAPI类核心实现
python
# fastapi/applications.py (简化版)
class FastAPI(Starlette):
"""FastAPI主类"""
def __init__(
self,
debug: bool = False,
routes: List[BaseRoute] = None,
**kwargs
):
super().__init__(debug=debug, routes=routes, **kwargs)
# 路由相关的属性
self.router: routing.APIRouter = routing.APIRouter()
self.openapi_version = "3.0.2"
self.openapi_schema: Optional[Dict[str, Any]] = None
# 添加默认文档路由
self.setup_docs()
def setup_docs(self):
"""设置文档路由"""
if self.openapi_url:
self.add_route(
self.openapi_url,
self.openapi,
include_in_schema=False,
)
if self.docs_url:
self.add_route(
self.docs_url,
self.swagger_ui_html,
include_in_schema=False,
)
def get(self, path: str, **kwargs):
"""GET装饰器"""
return self.router.get(path, **kwargs)
def post(self, path: str, **kwargs):
"""POST装饰器"""
return self.router.post(path, **kwargs)
# ... 其他HTTP方法装饰器
def include_router(self, router: routing.APIRouter, **kwargs):
"""包含路由器"""
self.router.include_router(router, **kwargs)2. 依赖注入系统解析
python
# fastapi/dependencies/utils.py (简化版)
async def solve_dependencies(
*,
dependant: Dependant,
body: Optional[Dict[str, Any]] = None,
background_tasks: Optional[BackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides: Optional[Dict[Callable, Callable]] = None,
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
"""解决依赖关系"""
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
# 遍历依赖树
for sub_dependant in dependant.dependencies:
sub_values, sub_errors = await solve_dependencies(
dependant=sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides=dependency_overrides,
)
if sub_errors:
errors.extend(sub_errors)
continue
# 调用依赖函数
call = sub_dependant.call
if dependency_overrides and call in dependency_overrides:
call = dependency_overrides[call]
try:
# 执行依赖函数
solved = await call(**sub_values)
values.update({sub_dependant.name: solved})
except Exception as e:
errors.append(ErrorWrapper(e, loc=("dependency",)))
return values, errors贡献指南
如何开始贡献
Fork仓库:
bash# 访问 https://github.com/tiangolo/fastapi # 点击右上角的 "Fork" 按钮克隆你的分支:
bashgit clone https://github.com/YOUR_USERNAME/fastapi.git cd fastapi设置开发环境:
bash# 创建虚拟环境 python -m venv venv source venv/bin/activate # Linux/Mac # venv\Scripts\activate # Windows # 安装开发依赖 pip install -e .[dev] pip install pytest pytest-cov运行测试:
bash# 运行所有测试 pytest # 运行特定测试文件 pytest tests/test_main.py # 带覆盖率报告 pytest --cov=fastapi tests/
贡献流程
创建新分支:
bashgit checkout -b feature/your-feature-name编写代码并测试:
python# 确保添加测试用例 # tests/test_your_feature.py def test_your_feature(): from fastapi import FastAPI from fastapi.testclient import TestClient app = FastAPI() @app.get("/") async def root(): return {"message": "Hello World"} client = TestClient(app) response = client.get("/") assert response.status_code == 200代码风格检查:
bash# 格式化代码 black fastapi/ tests/ # 排序imports isort fastapi/ tests/ # 类型检查 mypy fastapi/提交更改:
bashgit add . git commit -m "feat: add your feature description" git push origin feature/your-feature-name创建Pull Request:
- 访问 https://github.com/tiangolo/fastapi
- 点击 "New Pull Request"
- 选择你的分支
- 填写PR描述
贡献规范
提交消息格式:
type(scope): description body footer类型说明:
feat: 新功能fix: bug修复docs: 文档更新style: 代码风格调整refactor: 重构test: 测试相关chore: 构建/工具更新
测试要求:
- 新功能必须包含测试
- 修复bug需要添加回归测试
- 测试覆盖率不应降低
文档要求:
- 新功能需要更新文档
- 公共API需要类型提示和docstring
- 复杂功能需要示例代码
学习资源
官方文档:
社区资源:
相关项目:
实战练习:实现一个简单的特性
任务:为FastAPI添加一个@retry装饰器,当API调用失败时自动重试。
python
# 实现示例
def retry(max_retries: int = 3, delay: float = 1.0):
"""重试装饰器"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
for attempt in range(max_retries + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
if attempt == max_retries:
raise e
await asyncio.sleep(delay * (2 ** attempt)) # 指数退避
return None
return wrapper
return decorator
# 使用示例
@app.get("/unstable")
@retry(max_retries=3, delay=0.5)
async def unstable_endpoint():
import random
if random.random() < 0.5:
raise HTTPException(status_code=500, detail="随机失败")
return {"status": "ok"}总结
通过本章的学习,你已经深入了解了FastAPI的高级特性和扩展能力。这些高级特性包括:
- 自定义路由类:实现更灵活的路由控制
- 响应编码与内容协商:支持多种响应格式
- WebHooks实现:构建事件驱动的API系统
- GraphQL集成:提供灵活的查询能力
- 自定义OpenAPI文档:打造专业的API文档
- 插件系统开发:构建可扩展的应用架构
- 源码阅读与贡献:深入理解框架并参与贡献
这些高级特性让FastAPI不仅是一个Web框架,更是一个完整的API开发平台。掌握这些技能后,你将能够:
- 构建更强大、更灵活的API系统
- 提供更好的开发者体验
- 实现更复杂的业务需求
- 为开源社区做出贡献
记住:技术只是工具,真正的价值在于用这些工具解决实际问题。不断实践,不断优化,你将成为FastAPI专家。