FastAPI测试策略与质量保证:从单元测试到CI/CD的完整指南
在2023年StackOverflow开发者调查中,测试覆盖率不足位列技术债务榜首。我曾见证一个线上事故:由于缺少集成测试,一个API更改导致整个支付系统瘫痪,损失数百万。本文将带你构建坚如磐石的FastAPI测试体系。
引言:为什么测试如此重要?
测试不只是QA的工作,而是开发者的责任。一个完整的测试策略能够:
- 减少70%以上的生产环境bug
- 提升代码可维护性和可重构性
- 加快新功能开发速度
- 增强团队对代码变更的信心
FastAPI基于Python类型提示和Pydantic,天生适合测试。但如何有效测试异步代码、数据库操作和外部依赖?本文将带你构建完整的测试体系。
1. 测试金字塔:单元、集成、E2E
测试金字塔理论
python
# 测试金字塔结构
"""
       /\
      /  \          少量E2E测试(用户场景)
     /    \
    /______\        更多集成测试(组件交互)
   /        \
  /__________\      大量单元测试(独立函数/类)
 /            \
/______________\
比例建议:70%单元测试,20%集成测试,10%E2E测试
"""
class TestPyramid:
    """Registry and runner for test callables grouped by test-pyramid layer.

    Callables are registered per layer (unit / integration / e2e) and executed
    by :meth:`run_all`, which prints a per-test line, a per-layer share
    summary, and returns the collected results.
    """

    # The ``Test`` prefix makes pytest try to collect this helper class (and
    # warn because it defines ``__init__``); opt out of collection explicitly.
    __test__ = False

    def __init__(self):
        self.unit_tests = []
        self.integration_tests = []
        self.e2e_tests = []

    def add_unit_test(self, test_func):
        """Register a unit-test callable."""
        self.unit_tests.append(test_func)

    def add_integration_test(self, test_func):
        """Register an integration-test callable."""
        self.integration_tests.append(test_func)

    def add_e2e_test(self, test_func):
        """Register an end-to-end-test callable."""
        self.e2e_tests.append(test_func)

    def run_all(self):
        """Run every registered test and print the layer distribution.

        Returns:
            dict with keys ``"unit"``, ``"integration"`` and ``"e2e"``; each
            value is the list of result dicts produced by :meth:`_run_tests`.
        """
        results = {
            "unit": self._run_tests(self.unit_tests, "单元测试"),
            "integration": self._run_tests(self.integration_tests, "集成测试"),
            "e2e": self._run_tests(self.e2e_tests, "E2E测试"),
        }

        layers = (
            ("单元测试", self.unit_tests),
            ("集成测试", self.integration_tests),
            ("E2E测试", self.e2e_tests),
        )
        total = sum(len(tests) for _, tests in layers)

        print("\n测试金字塔统计:")
        for label, tests in layers:
            # With nothing registered the share is reported as 0% instead of
            # raising ZeroDivisionError.
            share = len(tests) / total if total else 0.0
            print(f"{label}: {len(tests)} ({share:.1%})")
        return results

    def _run_tests(self, tests, category):
        """Run ``tests`` in order and report each outcome.

        Args:
            tests: iterable of zero-argument callables.
            category: label printed in the section header.

        Returns:
            list of dicts with ``"test"`` and ``"status"`` (``"PASS"`` or
            ``"FAIL"``); failures additionally carry ``"error"``.
        """
        print(f"\n运行{category}:")
        results = []
        for test in tests:
            try:
                test()
            except Exception as e:
                # A failing test must not stop the remaining ones.
                results.append({"test": test.__name__, "status": "FAIL", "error": str(e)})
                print(f"  ✗ {test.__name__}: {e}")
            else:
                results.append({"test": test.__name__, "status": "PASS"})
                print(f"  ✓ {test.__name__}")
        return results
# Example tests
def test_addition():
    """Unit-test example: integer addition."""
    total = 1 + 1
    assert total == 2
def test_user_registration_flow():
    """Integration-test example (placeholder for the full registration flow)."""
    return None
def test_full_payment_process():
    """E2E-test example (placeholder for the complete payment process)."""
    return None
# Build a pyramid with one example test per layer and run it
pyramid = TestPyramid()
pyramid.add_unit_test(test_addition)
pyramid.add_integration_test(test_user_registration_flow)
pyramid.add_e2e_test(test_full_payment_process)
pyramid.run_all()

FastAPI测试策略
python
# tests/__init__.py
"""
测试目录结构:
tests/
├── unit/ # 单元测试
│ ├── test_models.py
│ ├── test_services.py
│ └── test_utils.py
├── integration/ # 集成测试
│ ├── test_api.py
│ ├── test_database.py
│ └── test_auth.py
├── e2e/ # 端到端测试
│ ├── test_user_flow.py
│ └── test_payment_flow.py
├── conftest.py # 共享fixture
└── fixtures/ # 测试数据
    ├── users.json
    └── products.json
"""
# conftest.py - 共享测试配置
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.main import app
from app.database import Base, get_db
from app.core.config import settings
# Engine dedicated to the test database
test_engine = create_engine(
    settings.TEST_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine
TestingSessionLocal = sessionmaker(
    bind=test_engine,
    autoflush=False,
    autocommit=False,
)
# Create the tables once before any test runs
@pytest.fixture(scope="session", autouse=True)
def setup_database():
    """Create all tables on the test engine; drop them when the session ends."""
    metadata = Base.metadata
    metadata.create_all(bind=test_engine)
    yield
    metadata.drop_all(bind=test_engine)
@pytest.fixture
def db_session():
    """Yield a database session whose work is rolled back at teardown.

    The session is bound to a single connection that has an outer transaction
    open. A SAVEPOINT is kept active inside it and re-opened whenever it ends,
    and the outer transaction is rolled back after the test.
    """
    # ``event`` is used below but is not among the module-level imports.
    from sqlalchemy import event

    connection = test_engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(bind=connection)

    # Nested transaction (savepoint) so that work can be rolled back
    nested = connection.begin_nested()

    @event.listens_for(session, "after_transaction_end")
    def end_savepoint(session, transaction):
        # Start a fresh savepoint as soon as the previous one is finished.
        nonlocal nested
        if not nested.is_active:
            nested = connection.begin_nested()

    yield session

    session.close()
    transaction.rollback()
    connection.close()
@pytest.fixture
def client(db_session):
    """Yield a ``TestClient`` whose ``get_db`` dependency serves ``db_session``."""

    def override_get_db():
        # Hand out the fixture-managed session; nothing to clean up here.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    test_client = TestClient(app)
    yield test_client
    app.dependency_overrides.clear()
@pytest.fixture
def authenticated_client(client, db_session):
"""创建已认证的测试客户端"""
from app.models.user import User
from app.core.security import create_access_token
# 创建测试用户
user = User(
email="test@example.com",
hashed_password="hashed_password",
is_active=True
)
db_session.add(user)
db_session.commit()
# 创建访问令牌
token = create_access_token(data={"sub": user.email})
# 设置认证头
client.headers.update({"Authorization": f"Bearer {token}"})
yield client, user2. Pytest框架深度使用
高级Fixtures
python
# tests/conftest.py - 扩展fixtures
import pytest
import tempfile
import json
from pathlib import Path
from typing import Dict, Any, Generator
from unittest.mock import Mock, AsyncMock
# 1. Parametrized fixture
@pytest.fixture(params=["user1", "user2", "admin"])
def user_type(request):
    """Return the current user type; dependants run once per parameter."""
    current = request.param
    return current
@pytest.fixture
def user_data(user_type):
    """Build a user payload for ``user_type``; admins get superuser fields."""
    payload = {
        "email": f"{user_type}@example.com",
        "password": "secure_password123",
        "is_active": True,
    }
    if user_type != "admin":
        return payload
    payload.update({"is_superuser": True, "role": "admin"})
    return payload
# 2. Factory fixture
@pytest.fixture
def user_factory():
    """Return a callable that builds ``User`` instances with overridable defaults."""
    from app.models.user import User
    from app.core.security import get_password_hash

    def create_user(**kwargs):
        """Build a ``User``; keyword arguments override the defaults."""
        fields = {
            "email": "user@example.com",
            "hashed_password": get_password_hash("password"),
            "is_active": True,
            "full_name": "Test User",
            **kwargs,
        }
        return User(**fields)

    return create_user
# 3. Temporary-file fixture
@pytest.fixture
def temp_config_file():
    """Yield the path of a temporary JSON config file and delete it afterwards."""
    # NOTE(review): indentation was lost in SOURCE; the yield is assumed to sit
    # inside the ``with`` block (it directly follows the flush) - confirm.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(
            {
                "database": {"url": "sqlite:///:memory:"},
                "security": {"secret_key": "test_secret"},
                "logging": {"level": "DEBUG"},
            },
            f,
        )
        f.flush()
        config_path = Path(f.name)
        yield config_path

    # Remove the temporary file
    config_path.unlink(missing_ok=True)
# 4. Async fixture
@pytest.fixture
async def async_client():
    """Yield an ``httpx.AsyncClient`` that calls the ASGI app in-process."""
    # NOTE(review): an async fixture under plain ``@pytest.fixture`` presumably
    # relies on pytest-asyncio's auto mode - confirm the project configuration.
    from httpx import ASGITransport, AsyncClient

    # The ``app=`` shortcut on AsyncClient is deprecated in newer httpx
    # releases; an explicit ASGITransport is the supported way to wire the app.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
# 5. Autouse fixture
@pytest.fixture(autouse=True)
def setup_test_environment(monkeypatch):
    """Set test environment variables and freeze ``datetime.datetime.now``."""
    import datetime

    # Fake environment variables
    for name, value in (
        ("ENVIRONMENT", "testing"),
        ("DATABASE_URL", "sqlite:///:memory:"),
    ):
        monkeypatch.setenv(name, value)

    # Fake clock: ``now()`` always returns this instant, whatever ``tz`` is
    fixed_time = datetime.datetime(2023, 1, 1, 12, 0, 0)

    class MockDateTime(datetime.datetime):
        @classmethod
        def now(cls, tz=None):
            return fixed_time

    monkeypatch.setattr(datetime, "datetime", MockDateTime)

    # Hand control to the test; code below runs as teardown
    yield

    print("\n测试完成,清理环境...")
# 6. Scope control
@pytest.fixture(scope="module")
def module_scoped_resource():
    """Module-scoped fixture: yields a dict flagged as initialized."""
    state = {"initialized": False}
    print("初始化模块级资源")
    state["initialized"] = True

    yield state

    print("清理模块级资源")
@pytest.fixture(scope="class")
def class_scoped_resource():
    """Class-scoped fixture returning a small marker dict."""
    return dict([("class", "scoped")])
@pytest.fixture(scope="session")
def session_scoped_database():
    """Session-scoped in-memory SQLite engine, disposed at session end."""
    import sqlalchemy

    engine = sqlalchemy.create_engine("sqlite:///:memory:")
    yield engine
    engine.dispose()
# 7. Dynamic fixture
@pytest.fixture
def dynamic_fixture(request):
    """Return the ``value`` kwarg of the closest ``fixture_data`` marker.

    Falls back to ``"default_value"`` when the test has no such marker or the
    marker carries no ``value`` keyword.
    """
    marker = request.node.get_closest_marker("fixture_data")
    if not marker or "value" not in marker.kwargs:
        return "default_value"
    return marker.kwargs["value"]
# 8. Fixture dependencies
@pytest.fixture
def raw_user_data():
    """Raw user payload containing a plain-text password."""
    return dict(email="test@example.com", password="password")
@pytest.fixture
def hashed_user_data(raw_user_data):
    """Copy of ``raw_user_data`` with ``password`` replaced by ``hashed_password``."""
    from app.core.security import get_password_hash

    data = dict(raw_user_data)
    plain_password = data.pop("password")
    data["hashed_password"] = get_password_hash(plain_password)
    return data
@pytest.fixture
def user_in_db(db_session, hashed_user_data):
    """Store a ``User`` built from ``hashed_user_data`` and return it refreshed."""
    from app.models.user import User

    persisted = User(**hashed_user_data)
    db_session.add(persisted)
    db_session.commit()
    db_session.refresh(persisted)
    return persisted
# Usage examples
@pytest.mark.fixture_data(value="custom_value")
def test_dynamic_fixture(dynamic_fixture):
    """The marker's ``value`` is what the dynamic fixture returns."""
    expected = "custom_value"
    assert dynamic_fixture == expected
class TestUserAPI:
    """Test class whose methods use the class-scoped fixture."""

    def test_user_creation(self, class_scoped_resource):
        """The class-scoped resource carries the expected marker value."""
        assert class_scoped_resource["class"] == "scoped"

    def test_user_deletion(self, class_scoped_resource):
        """Placeholder; receives the same fixture instance as its siblings."""
        # Tests in the same class share one instance of the class-scoped fixture
        pass

高级断言和插件
python
# tests/test_advanced_assertions.py
import re
import sys
from datetime import datetime, timedelta

import pytest
from pytest import approx, raises, warns, MonkeyPatch
def test_advanced_assertions():
    """Tour of assertion techniques: approx, raises, warns, regex, containers."""
    # 1. Approximate float equality
    result = 0.1 + 0.2
    assert result == approx(0.3)
    assert result == approx(0.3, rel=1e-3)  # relative tolerance
    assert result == approx(0.3, abs=1e-10)  # absolute tolerance

    # 2. Exception assertions
    with raises(ValueError) as exc_info:
        int("not a number")
    assert "invalid literal" in str(exc_info.value)
    assert exc_info.type == ValueError

    # 3. Warning assertions
    import warnings
    with warns(UserWarning) as warning_list:
        warnings.warn("Deprecated!", UserWarning)
    assert len(warning_list) == 1
    assert "Deprecated" in str(warning_list[0].message)

    # 4. Regular-expression matching
    text = "Hello, world!"
    assert re.match(r"^Hello", text)
    assert re.search(r"world", text)

    # 5. Set assertions
    set1 = {1, 2, 3}
    set2 = {3, 2, 1}
    assert set1 == set2
    assert 1 in set1
    assert {1, 2}.issubset(set1)

    # 6. Dict assertions
    dict1 = {"a": 1, "b": 2}
    dict2 = {"b": 2, "a": 1}
    assert dict1 == dict2
    assert dict1.keys() == {"a", "b"}
    assert dict1["a"] == 1

    # 7. Time assertions
    now = datetime.now()
    future = now + timedelta(days=1)
    assert future > now
    assert (future - now).days == 1

    # 8. Custom assertion message. The mismatch is deliberate (it shows the
    # message text), so capture the failure instead of failing this test.
    value = 5
    expected = 10
    with raises(AssertionError, match="Expected 10, got 5"):
        assert value == expected, f"Expected {expected}, got {value}"
# Using pytest plugins
def test_with_plugins():
    """Import the plugin packages used by the suite (fails if one is missing)."""
    # pytest-mock: provides the ``mocker`` fixture
    import pytest_mock

    # pytest-asyncio: support for async tests
    import pytest_asyncio

    # Other plugins mentioned for reference only:
    #   pytest-cov    - coverage reporting
    #   pytest-xdist  - parallel test execution
    #   pytest-django / pytest-flask / pytest-fastapi - framework support
# Custom assertions
class CustomAssertions:
    """Reusable assertion helpers for users and HTTP-like responses."""

    @staticmethod
    def assert_user_valid(user):
        """Assert ``user`` exists, has an e-mail containing ``@`` and a positive id.

        Returns:
            ``True`` when every check passes.
        """
        assert user is not None
        assert hasattr(user, "email")
        email = user.email
        assert "@" in email
        assert hasattr(user, "id")
        user_id = user.id
        assert user_id > 0
        return True

    @staticmethod
    def assert_response_success(response, status_code=200):
        """Assert the status code matches and the JSON body has no ``error`` key.

        Returns:
            The decoded JSON body.
        """
        assert response.status_code == status_code
        body = response.json()
        assert "error" not in body
        return body

    @staticmethod
    def assert_pagination(response, expected_count=None):
        """Assert the JSON body has the pagination keys and, optionally, the total.

        Returns:
            The decoded JSON body.
        """
        body = response.json()
        for required_key in ("items", "total", "page", "pages"):
            assert required_key in body
        if expected_count is None:
            return body
        assert body["total"] == expected_count
        return body
# Using the custom assertions
def test_custom_assertions():
    """Exercise ``CustomAssertions`` against simple mock objects."""
    # ``Mock`` is not imported at the top of this module; import it here.
    from unittest.mock import Mock

    user = Mock(email="test@example.com", id=1)
    assert CustomAssertions.assert_user_valid(user)

    response = Mock(status_code=200, json=lambda: {"data": "success"})
    data = CustomAssertions.assert_response_success(response)
    assert data["data"] == "success"
# Parametrized test
@pytest.mark.parametrize(
    "input_value,expected",
    [
        (1, 2),
        (2, 4),
        (3, 6),
        (0, 0),
        (-1, -2),
    ]
)
def test_double_function(input_value, expected):
    """Doubling ``input_value`` yields ``expected``."""
    double = lambda x: x * 2
    assert double(input_value) == expected
# Stacked parametrization (cartesian product)
@pytest.mark.parametrize("x", [0, 1])
@pytest.mark.parametrize("y", [2, 3])
def test_combinations(x, y):
    """Runs once for every ``(x, y)`` combination."""
    pair = [x, y]
    assert x + y == sum(pair)
# Parametrization through a fixture
@pytest.fixture(params=[1, 2, 3])
def number_fixture(request):
    """Return the current parameter (1, 2 or 3)."""
    number = request.param
    return number
def test_with_fixture_param(number_fixture):
    """The fixture value is always one of its declared parameters."""
    allowed = [1, 2, 3]
    assert number_fixture in allowed
# Marked tests
@pytest.mark.slow
def test_slow_operation():
    """Slow test: sleeps for two seconds."""
    from time import sleep

    sleep(2)
    assert True
@pytest.mark.skip(reason="功能暂未实现")
def test_unimplemented():
    """Always skipped; the body would fail if it ever ran."""
    not_ready = False
    assert not_ready
# The skip condition reads ``sys.version_info``; ``sys`` must be imported at
# module level for this decorator to evaluate.
@pytest.mark.skipif(
    sys.version_info < (3, 8),
    reason="需要Python 3.8或更高版本"
)
def test_python38_feature():
    """Conditionally skipped on interpreters older than Python 3.8."""
    assert True
@pytest.mark.xfail(reason="已知问题,期待失败")
def test_known_bug():
    """Expected-failure example."""
    left, right = 1, 2
    assert left == right  # fails on purpose
# Custom markers
@pytest.mark.integration
def test_integration():
    """Placeholder test carrying the ``integration`` marker."""
    return None
@pytest.mark.e2e
def test_e2e():
    """Placeholder test carrying the ``e2e`` marker."""
    return None
@pytest.mark.security
def test_security():
"""安全测试标记"""
pass3. 异步代码测试技巧
AsyncIO测试基础
python
# tests/test_async_basic.py
import pytest
import asyncio
from httpx import AsyncClient
from unittest.mock import AsyncMock, patch
import time
# Basic async test
@pytest.mark.asyncio
async def test_async_function():
    """Await a small coroutine and check its result."""

    async def async_add(a, b):
        # Simulate asynchronous work
        await asyncio.sleep(0.1)
        return a + b

    total = await async_add(1, 2)
    assert total == 3
# Async timeout test
@pytest.mark.asyncio
async def test_async_timeout():
    """``wait_for`` raises ``TimeoutError`` when the coroutine is too slow."""

    async def slow_operation():
        await asyncio.sleep(2)
        return "done"

    pending = slow_operation()
    with pytest.raises(asyncio.TimeoutError):
        await asyncio.wait_for(pending, timeout=0.1)
# Async context-manager test
@pytest.mark.asyncio
async def test_async_context():
    """``__aenter__`` / ``__aexit__`` run on entering and leaving ``async with``."""

    class AsyncResource:
        async def __aenter__(self):
            self.value = "initialized"
            return self

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            self.value = "cleaned up"

    async with AsyncResource() as resource:
        inside_value = resource.value
        assert inside_value == "initialized"

    # After the block the exit hook has run
    assert resource.value == "cleaned up"
# FastAPI async endpoint test
@pytest.mark.asyncio
async def test_async_endpoint(async_client):
    """``GET /api/async-data`` answers 200 with a list under ``data``."""
    response = await async_client.get("/api/async-data")
    assert response.status_code == 200

    payload = response.json()
    assert "data" in payload
    assert isinstance(payload["data"], list)
# Async fixture
@pytest.fixture
async def async_user_data():
    """Async fixture returning a user dict after a short simulated setup."""
    await asyncio.sleep(0.1)  # simulate asynchronous initialisation
    return dict(id=1, name="Async User")
@pytest.mark.asyncio
async def test_with_async_fixture(async_user_data):
    """The async fixture's resolved value is injected into the test."""
    user_id = async_user_data["id"]
    user_name = async_user_data["name"]
    assert user_id == 1
    assert user_name == "Async User"
# Async mock test
@pytest.mark.asyncio
async def test_async_mock():
    """``AsyncMock`` methods are awaitable and record their awaits."""
    expected = {"data": "test"}

    # Build the async mock and configure its return value
    mock_service = AsyncMock()
    mock_service.get_data.return_value = {"data": "test"}

    # Await the mocked method
    result = await mock_service.get_data()

    # Verify the result and that it was awaited exactly once
    assert result == expected
    mock_service.get_data.assert_awaited_once()
# Async database test
@pytest.mark.asyncio
async def test_async_database():
    """Create a table, insert a row and read it back through an async session."""
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy import text

    # Async engine on an in-memory SQLite database
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    try:
        async with engine.begin() as conn:
            await conn.execute(text("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"))

        # Async session factory
        AsyncSessionLocal = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )

        async with AsyncSessionLocal() as session:
            # The INSERT result is not needed, so it is not kept
            await session.execute(
                text("INSERT INTO test (name) VALUES (:name)"),
                {"name": "test"}
            )
            await session.commit()

            # Verify the insert
            result = await session.execute(text("SELECT * FROM test"))
            rows = result.fetchall()

        assert len(rows) == 1
        assert rows[0].name == "test"
    finally:
        # Release the engine's connections even when an assertion fails
        await engine.dispose()
# Concurrent async test
@pytest.mark.asyncio
async def test_concurrent_async():
    """``asyncio.gather`` gives the same results as a sequential loop, faster."""

    async def process_item(item):
        await asyncio.sleep(0.1)
        return item * 2

    items = [1, 2, 3, 4, 5]

    # Sequential processing. Durations use the monotonic perf_counter rather
    # than the wall clock, which may jump while the test is running.
    start = time.perf_counter()
    results_sequential = [await process_item(item) for item in items]
    sequential_time = time.perf_counter() - start

    # Concurrent processing
    start = time.perf_counter()
    results_concurrent = await asyncio.gather(*(process_item(item) for item in items))
    concurrent_time = time.perf_counter() - start

    # Same results, and the concurrent run should be quicker
    assert results_sequential == results_concurrent
    assert concurrent_time < sequential_time
# Async exception test
@pytest.mark.asyncio
async def test_async_exception():
    """An exception raised inside a coroutine surfaces at the ``await``."""

    async def failing_operation():
        await asyncio.sleep(0.1)
        raise ValueError("Async error")

    with pytest.raises(ValueError) as exc_info:
        await failing_operation()

    message = str(exc_info.value)
    assert "Async error" in message
# Working with the event loop under pytest-asyncio
@pytest.mark.asyncio
async def test_asyncio_event_loop():
    """The test runs inside a live loop that can schedule callbacks."""
    # Inside a coroutine the running loop is obtained with get_running_loop();
    # get_event_loop() is the legacy accessor.
    loop = asyncio.get_running_loop()
    assert loop.is_running()

    # Schedule a callback that resolves the future, then await it
    future = loop.create_future()
    loop.call_soon(future.set_result, "done")
    result = await future
    assert result == "done"
# Async WebSocket test
@pytest.mark.asyncio
async def test_websocket():
    """Talk to the WebSocket endpoint at ``ws://localhost:8000/ws``."""
    import websockets

    async with websockets.connect("ws://localhost:8000/ws") as websocket:
        # Send a message and read the reply
        await websocket.send("Hello")
        response = await websocket.recv()
        assert response == "Hello back"

        # Ping/pong round trip
        pong_waiter = await websocket.ping()
        await pong_waiter

        # Close the connection
        await websocket.close()

FastAPI异步测试实战
python
# tests/test_async_api.py
import pytest
import asyncio
from httpx import AsyncClient
from unittest.mock import patch, AsyncMock
from datetime import datetime, timedelta
import json
@pytest.mark.asyncio
class TestAsyncAPI:
"""异步API测试类"""
async def test_async_endpoint(self, async_client: AsyncClient):
"""测试异步端点"""
response = await async_client.get("/api/async/data")
assert response.status_code == 200
data = response.json()
assert "timestamp" in data
assert "data" in data
async def test_websocket_endpoint(self, async_client: AsyncClient):
"""测试WebSocket端点"""
async with async_client.websocket_connect("/ws/chat") as websocket:
# 发送消息
await websocket.send_json({
"type": "message",
"content": "Hello",
"user": "test"
})
# 接收响应
data = await websocket.receive_json()
assert data["type"] == "message"
assert data["content"] == "Hello"
assert data["user"] == "test"
async def test_concurrent_requests(self, async_client: AsyncClient):
"""测试并发请求"""
async def make_request():
response = await async_client.get("/api/async/data")
return response.status_code
# 创建10个并发请求
tasks = [make_request() for _ in range(10)]
results = await asyncio.gather(*tasks)
# 所有请求都应该成功
assert all(status == 200 for status in results)
async def test_async_dependency_injection(self, async_client: AsyncClient):
"""测试异步依赖注入"""
response = await async_client.get("/api/async/with-deps")
assert response.status_code == 200
data = response.json()
assert "dependency_result" in data
assert data["dependency_result"] == "async_dependency_worked"
async def test_rate_limiting(self, async_client: AsyncClient):
"""测试速率限制"""
responses = []
# 快速发送多个请求
for i in range(10):
response = await async_client.get("/api/rate-limited")
responses.append(response.status_code)
# 前5个应该成功,后面的可能被限制
success_count = sum(1 for status in responses if status == 200)
rate_limited_count = sum(1 for status in responses if status == 429)
assert success_count >= 1
assert rate_limited_count >= 0
async def test_async_file_upload(self, async_client: AsyncClient):
"""测试异步文件上传"""
# 创建测试文件
files = {"file": ("test.txt", b"Hello, World!", "text/plain")}
response = await async_client.post(
"/api/upload",
files=files
)
assert response.status_code == 200
data = response.json()
assert "filename" in data
assert "size" in data
assert data["size"] == 13
async def test_async_background_tasks(self, async_client: AsyncClient):
"""测试后台任务"""
with patch("app.tasks.background_task.delay") as mock_task:
response = await async_client.post(
"/api/trigger-task",
json={"data": "test"}
)
assert response.status_code == 202
mock_task.assert_called_once_with(data="test")
async def test_async_cache(self, async_client: AsyncClient):
"""测试异步缓存"""
# 第一次请求
response1 = await async_client.get("/api/cached-data")
assert response1.status_code == 200
data1 = response1.json()
assert "timestamp" in data1
# 短暂等待
await asyncio.sleep(0.5)
# 第二次请求(应该从缓存获取)
response2 = await async_client.get("/api/cached-data")
assert response2.status_code == 200
data2 = response2.json()
# 时间戳应该相同(缓存生效)
assert data1["timestamp"] == data2["timestamp"]
async def test_async_error_handling(self, async_client: AsyncClient):
"""测试异步错误处理"""
# 测试400错误
response = await async_client.post(
"/api/async/error",
json={"trigger_error": True}
)
assert response.status_code == 400
data = response.json()
assert "detail" in data
# 测试500错误
response = await async_client.get("/api/async/server-error")
assert response.status_code == 500
async def test_async_validation(self, async_client: AsyncClient):
"""测试异步验证"""
# 无效数据
response = await async_client.post(
"/api/async/validate",
json={"email": "invalid-email", "age": -5}
)
assert response.status_code == 422
data = response.json()
assert "detail" in data
# 验证错误详情
errors = data["detail"]
assert len(errors) >= 2
error_fields = [error["loc"][-1] for error in errors]
assert "email" in error_fields
assert "age" in error_fields
async def test_async_middleware(self, async_client: AsyncClient):
"""测试异步中间件"""
response = await async_client.get("/api/async/data")
# 检查中间件添加的头部
assert "X-Process-Time" in response.headers
assert "X-Request-ID" in response.headers
process_time = float(response.headers["X-Process-Time"])
assert process_time > 0
async def test_async_database_transaction(self, async_client: AsyncClient, db_session):
"""测试异步数据库事务"""
from app.models.user import User
# 获取初始用户数
initial_count = db_session.query(User).count()
# 创建用户
response = await async_client.post(
"/api/async/users",
json={
"email": "transaction_test@example.com",
"password": "secure123"
}
)
assert response.status_code == 201
# 验证用户已创建
final_count = db_session.query(User).count()
assert final_count == initial_count + 1
# 验证用户数据
user = db_session.query(User).filter(
User.email == "transaction_test@example.com"
).first()
assert user is not None
assert user.is_active is True
@pytest.mark.timeout(5) # 设置测试超时
async def test_async_timeout_handling(self, async_client: AsyncClient):
"""测试异步超时处理"""
# 这个端点应该快速响应
response = await async_client.get("/api/async/fast")
assert response.status_code == 200
# 这个端点可能较慢
try:
response = await async_client.get("/api/async/slow", timeout=1.0)
except Exception as e:
# 应该超时
assert "timeout" in str(e).lower()
async def test_async_circuit_breaker(self, async_client: AsyncClient):
"""测试熔断器模式"""
failures = []
# 模拟多次失败
with patch("app.services.external_api.call", side_effect=Exception("Service down")):
for i in range(5):
try:
response = await async_client.get("/api/with-external")
if response.status_code != 200:
failures.append(response.status_code)
except Exception:
failures.append("exception")
# 验证熔断器触发
assert len(failures) > 0
async def test_async_bulk_operations(self, async_client: AsyncClient):
"""测试批量操作"""
# 批量创建用户
users = [
{"email": f"user{i}@example.com", "password": "password123"}
for i in range(10)
]
response = await async_client.post(
"/api/async/users/bulk",
json={"users": users}
)
assert response.status_code == 201
data = response.json()
assert "created_count" in data
assert data["created_count"] == 10
assert "failed_count" in data
assert data["failed_count"] == 0
async def test_async_streaming_response(self, async_client: AsyncClient):
"""测试流式响应"""
response = await async_client.get("/api/async/stream")
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream"
# 验证流式数据
async for chunk in response.aiter_bytes():
assert chunk # 应该有数据
async def test_async_health_check(self, async_client: AsyncClient):
"""测试健康检查"""
response = await async_client.get("/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert data["status"] == "healthy"
assert "timestamp" in data
assert "version" in data4. 数据库测试:事务管理
数据库测试策略
python
# tests/test_database.py
import pytest
from sqlalchemy import text, select, delete
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
import time
class TestDatabase:
    """Database tests running against the ``db_session`` fixture."""

    def test_database_connection(self, db_session: Session):
        """A trivial ``SELECT 1`` round-trips through the session."""
        result = db_session.execute(text("SELECT 1"))
        assert result.scalar() == 1

    def test_create_and_read(self, db_session: Session):
        """A stored user can be read back with identical fields."""
        from app.models.user import User

        user = User(
            email="test@example.com",
            hashed_password="hashed",
            full_name="Test User"
        )
        db_session.add(user)
        db_session.commit()
        db_session.refresh(user)

        saved_user = db_session.query(User).filter_by(email="test@example.com").first()
        assert saved_user is not None
        assert saved_user.id == user.id
        assert saved_user.email == user.email
        assert saved_user.full_name == user.full_name

    def test_update(self, db_session: Session):
        """Changing an attribute and committing persists the new value."""
        from app.models.user import User

        user = User(
            email="update@example.com",
            hashed_password="hashed",
            full_name="Original Name"
        )
        db_session.add(user)
        db_session.commit()

        user.full_name = "Updated Name"
        db_session.commit()
        db_session.refresh(user)

        assert user.full_name == "Updated Name"

    def test_delete(self, db_session: Session):
        """A deleted user can no longer be loaded by primary key."""
        from app.models.user import User

        user = User(
            email="delete@example.com",
            hashed_password="hashed"
        )
        db_session.add(user)
        db_session.commit()
        user_id = user.id

        db_session.delete(user)
        db_session.commit()

        # Session.get() replaces the legacy Query.get()
        deleted_user = db_session.get(User, user_id)
        assert deleted_user is None

    def test_unique_constraint(self, db_session: Session):
        """A second user with the same e-mail violates the unique constraint."""
        from app.models.user import User

        user1 = User(
            email="unique@example.com",
            hashed_password="hashed"
        )
        db_session.add(user1)
        db_session.commit()

        user2 = User(
            email="unique@example.com",  # same e-mail on purpose
            hashed_password="hashed2"
        )
        db_session.add(user2)
        with pytest.raises(IntegrityError):
            db_session.commit()

        # Roll back the failed commit so the session is usable again
        db_session.rollback()

    def test_foreign_key_constraint(self, db_session: Session):
        """An article whose author does not exist is rejected."""
        # NOTE(review): SQLite presumably needs ``PRAGMA foreign_keys=ON`` for
        # this constraint to be enforced - confirm the test engine setup.
        from app.models.user import User
        from app.models.article import Article

        user = User(
            email="author@example.com",
            hashed_password="hashed"
        )
        db_session.add(user)
        db_session.commit()

        # Valid foreign key
        article = Article(
            title="Test Article",
            content="Content",
            author_id=user.id
        )
        db_session.add(article)
        db_session.commit()

        # Invalid foreign key: no user has this id
        invalid_article = Article(
            title="Invalid Article",
            content="Content",
            author_id=99999
        )
        db_session.add(invalid_article)
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()

    def test_transaction_rollback(self, db_session: Session):
        """Flushed-but-uncommitted rows disappear after a rollback."""
        from app.models.user import User

        initial_count = db_session.query(User).count()

        try:
            user1 = User(
                email="rollback1@example.com",
                hashed_password="hashed"
            )
            db_session.add(user1)
            db_session.flush()  # assigns the id without committing

            user2 = User(
                email="rollback2@example.com",
                hashed_password="hashed"
            )
            db_session.add(user2)
            db_session.flush()

            # Deliberate failure to trigger the rollback path
            raise ValueError("Something went wrong")
        except ValueError:
            db_session.rollback()

        final_count = db_session.query(User).count()
        assert final_count == initial_count

        user1_exists = db_session.query(User).filter_by(
            email="rollback1@example.com"
        ).first()
        assert user1_exists is None

    def test_nested_transactions(self, db_session: Session):
        """Rolling back a savepoint keeps the outer transaction's work."""
        from app.models.user import User

        initial_count = db_session.query(User).count()

        # Outer transaction
        try:
            user1 = User(
                email="outer@example.com",
                hashed_password="hashed"
            )
            db_session.add(user1)

            # Inner transaction (savepoint)
            nested = db_session.begin_nested()
            try:
                user2 = User(
                    email="inner@example.com",
                    hashed_password="hashed"
                )
                db_session.add(user2)
                # Deliberate failure inside the savepoint
                raise ValueError("Inner transaction failed")
            except ValueError:
                nested.rollback()

            # Commit the outer work
            db_session.commit()
        except Exception:
            db_session.rollback()
            raise

        # Only the outer user was saved
        final_count = db_session.query(User).count()
        assert final_count == initial_count + 1

        outer_user = db_session.query(User).filter_by(
            email="outer@example.com"
        ).first()
        assert outer_user is not None

        inner_user = db_session.query(User).filter_by(
            email="inner@example.com"
        ).first()
        assert inner_user is None

    def test_concurrent_transactions(self, db_session: Session):
        """Ten threads each increment a counter row under ``FOR UPDATE``."""
        # NOTE(review): workers use ``app.database.SessionLocal`` while the
        # counter is created through ``db_session``; whether both see the same
        # data depends on configuration outside this file - confirm.
        from app.models.counter import Counter
        import threading

        counter = Counter(name="test", value=0)
        db_session.add(counter)
        db_session.commit()
        db_session.refresh(counter)

        results = []
        errors = []
        lock = threading.Lock()

        def increment_counter():
            """Increment the counter in an independent session."""
            from app.database import SessionLocal
            session = SessionLocal()
            try:
                counter_obj = session.query(Counter).filter_by(name="test").with_for_update().first()
                time.sleep(0.01)  # simulate processing time
                counter_obj.value += 1
                session.commit()
                with lock:
                    results.append(counter_obj.value)
            except Exception as exc:
                # Exceptions do not propagate out of threads; record them so
                # the main thread can fail with the real cause.
                with lock:
                    errors.append(exc)
            finally:
                session.close()

        threads = [threading.Thread(target=increment_counter) for _ in range(10)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        assert not errors, f"worker threads failed: {errors!r}"

        final_counter = db_session.query(Counter).filter_by(name="test").first()
        assert final_counter.value == 10  # one increment per thread

    def test_database_isolation_levels(self, db_session: Session):
        """Request connections at three isolation levels (no assertions)."""
        # NOTE(review): which levels are accepted depends on the database
        # backend; verify these names against the configured test database.
        # Read Committed
        db_session.connection(execution_options={"isolation_level": "READ COMMITTED"})
        # Repeatable Read
        db_session.connection(execution_options={"isolation_level": "REPEATABLE READ"})
        # Serializable
        db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"})

    def test_bulk_operations(self, db_session: Session):
        """Bulk-insert 100 users, then bulk-deactivate them."""
        from app.models.user import User

        users = [
            {
                "email": f"bulk{i}@example.com",
                "hashed_password": "hashed",
                "full_name": f"User {i}"
            }
            for i in range(100)
        ]

        # bulk_insert_mappings inserts plain dicts in one call
        db_session.bulk_insert_mappings(User, users)
        db_session.commit()

        count = db_session.query(User).filter(
            User.email.like("bulk%@example.com")
        ).count()
        assert count == 100

        # Bulk update
        db_session.query(User).filter(
            User.email.like("bulk%@example.com")
        ).update({"is_active": False})
        db_session.commit()

        inactive_count = db_session.query(User).filter(
            User.email.like("bulk%@example.com"),
            User.is_active.is_(False)
        ).count()
        assert inactive_count == 100

    def test_query_optimization(self, db_session: Session):
        """Contrast an N+1 access pattern with a join; inspect the query plan."""
        from app.models.user import User
        from app.models.article import Article

        user = User(
            email="query_test@example.com",
            hashed_password="hashed"
        )
        db_session.add(user)
        db_session.commit()

        for i in range(10):
            db_session.add(Article(
                title=f"Article {i}",
                content=f"Content {i}",
                author_id=user.id
            ))
        db_session.commit()

        # Anti-pattern shown on purpose: one author lookup per article (N+1)
        articles = db_session.query(Article).filter_by(author_id=user.id).all()
        for article in articles:
            db_session.get(User, article.author_id)

        # Preferred: fetch articles and authors together with a join
        articles_with_author = db_session.query(Article, User).join(
            User, Article.author_id == User.id
        ).filter(Article.author_id == user.id).all()
        assert len(articles_with_author) == 10

        # Inspect the query plan for the e-mail lookup
        explain = db_session.execute(
            text("EXPLAIN QUERY PLAN SELECT * FROM users WHERE email = :email"),
            {"email": "query_test@example.com"}
        ).fetchall()
        explain_str = str(explain).lower()
        assert "scan" in explain_str or "search" in explain_str

    def test_database_migrations(self, db_session: Session):
        """The expected tables and ``users`` columns exist."""
        tables = db_session.execute(
            text("SELECT name FROM sqlite_master WHERE type='table'")
        ).fetchall()
        table_names = [t[0] for t in tables]
        assert "users" in table_names
        assert "articles" in table_names

        columns = db_session.execute(
            text("PRAGMA table_info(users)")
        ).fetchall()
        # Index 1 of each row is read as the column name
        column_names = [c[1] for c in columns]
        assert "email" in column_names
        assert "hashed_password" in column_names
        assert "created_at" in column_names

    def test_connection_pool(self, db_session: Session):
        """Check the pool counters when the engine uses a ``QueuePool``."""
        from sqlalchemy.pool import QueuePool

        # The session may be bound to a Connection instead of an Engine; both
        # expose ``.engine``, which owns the pool.
        pool = db_session.get_bind().engine.pool
        # size()/checkedin()/checkedout() are QueuePool methods; other pool
        # classes do not provide them.
        if not isinstance(pool, QueuePool):
            pytest.skip("pool counters are only available on QueuePool")

        assert pool.size() <= 20  # upper bound on pooled connections
        assert pool.checkedin() >= 0
        assert pool.checkedout() >= 0

    def test_database_cleanup(self, db_session: Session):
        """Delete every row from every table and verify all are empty."""
        tables = db_session.execute(
            text("SELECT name FROM sqlite_master WHERE type='table'")
        ).fetchall()
        # Table names come from sqlite_master, not from user input.
        # ``sqlite_sequence`` (autoincrement bookkeeping) is skipped.
        table_names = [row[0] for row in tables if row[0] != "sqlite_sequence"]

        # Remove all data but keep the schema
        for table_name in table_names:
            db_session.execute(text(f"DELETE FROM {table_name}"))
        db_session.commit()

        for table_name in table_names:
            count = db_session.execute(
                text(f"SELECT COUNT(*) FROM {table_name}")
            ).scalar()
            assert count == 0
# Database factory fixtures
@pytest.fixture
def user_factory(db_session):
    """Return a callable that creates and persists ``User`` rows."""
    from app.models.user import User
    import uuid

    def create_user(**kwargs):
        """Persist a user; keyword arguments override the defaults."""
        fields = {
            "email": f"user_{uuid.uuid4().hex[:8]}@example.com",
            "hashed_password": "hashed_password",
            "is_active": True,
            **kwargs,
        }
        new_user = User(**fields)
        db_session.add(new_user)
        db_session.commit()
        db_session.refresh(new_user)
        return new_user

    return create_user
@pytest.fixture
def article_factory(db_session, user_factory):
    """Return a callable that creates and persists ``Article`` rows.

    When no ``author_id`` is passed, a new user is created through
    ``user_factory`` and used as the author.
    """
    from app.models.article import Article
    # ``uuid`` is needed for the default title and is not imported at module
    # level in this file.
    import uuid

    def create_article(**kwargs):
        """Persist an article; keyword arguments override the defaults."""
        if "author_id" not in kwargs:
            user = user_factory()
            kwargs["author_id"] = user.id

        defaults = {
            "title": f"Test Article {uuid.uuid4().hex[:8]}",
            "content": "Test content",
            "is_published": True
        }
        defaults.update(kwargs)

        article = Article(**defaults)
        db_session.add(article)
        db_session.commit()
        db_session.refresh(article)
        return article

    return create_article
# 使用工厂的测试
def test_with_factories(user_factory, article_factory):
    """Build a user and an article through the factory fixtures and check both."""
    author = user_factory(
        email="factory@example.com",
        full_name="Factory User",
    )
    assert author.id is not None
    assert author.email == "factory@example.com"

    # Link the article to the user created above.
    post = article_factory(title="Factory Article", author_id=author.id)
    assert post.id is not None
    assert post.author_id == author.id
    assert post.title == "Factory Article"


# Article section heading: 5. 模拟外部依赖
Mock和Patch高级用法
python
# tests/test_mocking.py
import pytest
from unittest.mock import Mock, patch, MagicMock, AsyncMock, call, PropertyMock
import requests
from datetime import datetime
import asyncio
class TestMocking:
"""模拟外部依赖测试"""
def test_basic_mock(self):
"""基础Mock测试"""
# 创建Mock对象
mock_service = Mock()
# 设置返回值
mock_service.get_data.return_value = {"data": "test"}
# 调用Mock方法
result = mock_service.get_data()
# 验证调用
assert result == {"data": "test"}
mock_service.get_data.assert_called_once()
def test_mock_with_side_effect(self):
"""使用side_effect的Mock"""
mock_func = Mock()
# 设置side_effect为函数
def side_effect_func(x):
return x * 2
mock_func.side_effect = side_effect_func
assert mock_func(5) == 10
assert mock_func(10) == 20
# 设置side_effect为异常
mock_func.side_effect = ValueError("Error occurred")
with pytest.raises(ValueError):
mock_func(1)
# 设置side_effect为序列
mock_func.side_effect = [1, 2, 3]
assert mock_func() == 1
assert mock_func() == 2
assert mock_func() == 3
# 后续调用会抛出StopIteration
with pytest.raises(StopIteration):
mock_func()
def test_mock_attributes(self):
"""测试Mock属性"""
mock_obj = Mock()
# 设置属性
mock_obj.name = "Test Object"
mock_obj.value = 42
assert mock_obj.name == "Test Object"
assert mock_obj.value == 42
# 动态属性
mock_obj.dynamic_attribute = "dynamic"
assert mock_obj.dynamic_attribute == "dynamic"
def test_magic_mock(self):
"""测试MagicMock"""
magic_mock = MagicMock()
# MagicMock支持魔术方法
magic_mock.__len__.return_value = 5
assert len(magic_mock) == 5
magic_mock.__getitem__.return_value = "item"
assert magic_mock[0] == "item"
assert magic_mock["key"] == "item"
# 调用验证
magic_mock.__getitem__.assert_any_call(0)
magic_mock.__getitem__.assert_any_call("key")
def test_property_mock(self):
"""测试属性Mock"""
class MyClass:
@property
def value(self):
return "real value"
obj = MyClass()
with patch.object(MyClass, 'value', new_callable=PropertyMock) as mock_prop:
mock_prop.return_value = "mocked value"
assert obj.value == "mocked value"
mock_prop.assert_called_once()
def test_patch_decorator(self):
"""测试patch装饰器"""
# 要测试的函数
def get_external_data():
import requests
response = requests.get("https://api.example.com/data")
return response.json()
# 使用patch模拟requests.get
@patch('requests.get')
def test_get_external_data(mock_get):
# 设置Mock响应
mock_response = Mock()
mock_response.json.return_value = {"data": "mocked"}
mock_get.return_value = mock_response
# 调用被测试函数
result = get_external_data()
# 验证
assert result == {"data": "mocked"}
mock_get.assert_called_once_with("https://api.example.com/data")
test_get_external_data()
def test_patch_context_manager(self):
"""测试patch上下文管理器"""
# 使用上下文管理器
with patch('datetime.datetime') as mock_datetime:
# 设置now()返回固定时间
fixed_time = datetime(2023, 1, 1, 12, 0, 0)
mock_datetime.now.return_value = fixed_time
# 在上下文中,datetime.now()返回模拟值
now = datetime.now()
assert now == fixed_time
# 上下文外,datetime.now()恢复正常
assert datetime.now() != fixed_time
def test_patch_multiple(self):
"""测试同时patch多个对象"""
def complex_function():
import requests
import time
data = requests.get("https://api.example.com/data").json()
time.sleep(1) # 模拟处理时间
return data
with patch('requests.get') as mock_get, \
patch('time.sleep') as mock_sleep:
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_get.return_value = mock_response
result = complex_function()
assert result == {"result": "success"}
mock_get.assert_called_once()
mock_sleep.assert_called_once_with(1)
def test_async_mock(self):
"""测试异步Mock"""
async_mock = AsyncMock()
# 设置异步方法的返回值
async_mock.fetch_data.return_value = {"data": "async"}
async def test_async():
result = await async_mock.fetch_data()
return result
# 运行异步测试
import asyncio
result = asyncio.run(test_async())
assert result == {"data": "async"}
async_mock.fetch_data.assert_awaited_once()
def test_mock_call_args(self):
"""测试调用参数验证"""
mock_func = Mock()
# 调用多次
mock_func(1, 2, 3)
mock_func(a=1, b=2)
mock_func(1, b=2)
# 验证调用
assert mock_func.call_count == 3
# 获取所有调用
calls = mock_func.call_args_list
assert len(calls) == 3
# 验证特定调用
mock_func.assert_any_call(1, 2, 3)
mock_func.assert_any_call(a=1, b=2)
# 验证最后一次调用
mock_func.assert_called_with(1, b=2)
def test_mock_reset(self):
"""测试重置Mock"""
mock_obj = Mock()
mock_obj.method()
mock_obj.another_method()
assert mock_obj.method.call_count == 1
assert mock_obj.another_method.call_count == 1
# 重置Mock
mock_obj.reset_mock()
assert mock_obj.method.call_count == 0
assert mock_obj.another_method.call_count == 0
def test_patch_object(self):
"""测试patch.object"""
class RealClass:
def method(self):
return "real"
obj = RealClass()
# 替换对象的方法
with patch.object(obj, 'method', return_value="mocked"):
assert obj.method() == "mocked"
# 方法恢复
assert obj.method() == "real"
def test_patch_dict(self):
"""测试patch.dict"""
config = {"host": "localhost", "port": 8080}
# 临时修改字典
with patch.dict(config, {"port": 9000, "debug": True}, clear=False):
assert config["host"] == "localhost" # 保留原值
assert config["port"] == 9000 # 修改的值
assert config["debug"] == True # 新增的值
# 恢复原状
assert config == {"host": "localhost", "port": 8080}
# 使用clear=True清空字典
with patch.dict(config, {"new": "value"}, clear=True):
assert config == {"new": "value"}
assert config == {"host": "localhost", "port": 8080}
# 实际场景:测试外部API调用
class TestExternalAPI:
    """Tests for ``fetch_user_data`` with ``requests.get`` replaced by a mock."""

    @pytest.fixture
    def mock_requests(self):
        """Yield a mock that stands in for ``requests.get`` while the test runs."""
        patcher = patch('requests.get')
        fake_get = patcher.start()
        try:
            yield fake_get
        finally:
            patcher.stop()

    def test_fetch_user_data(self, mock_requests):
        """A 200 response is returned as the decoded user payload."""
        from app.services.external import fetch_user_data

        payload = {
            "id": 1,
            "name": "John Doe",
            "email": "john@example.com"
        }
        fake_response = Mock()
        fake_response.status_code = 200
        fake_response.json.return_value = payload
        mock_requests.return_value = fake_response

        user_data = fetch_user_data(1)

        assert user_data["id"] == 1
        assert user_data["name"] == "John Doe"
        # Exactly one request, to the expected URL, with a 10 second timeout.
        mock_requests.assert_called_once_with(
            "https://api.external.com/users/1",
            timeout=10
        )

    def test_fetch_user_data_error(self, mock_requests):
        """A request timeout propagates to the caller."""
        from app.services.external import fetch_user_data

        mock_requests.side_effect = requests.exceptions.Timeout("Request timeout")

        with pytest.raises(requests.exceptions.Timeout):
            fetch_user_data(1)

    def test_fetch_user_data_not_found(self, mock_requests):
        """A 404 response yields ``None``."""
        from app.services.external import fetch_user_data

        not_found = Mock()
        not_found.status_code = 404
        mock_requests.return_value = not_found

        assert fetch_user_data(999) is None
# 测试数据库依赖
class TestDatabaseDependencies:
    """Tests for services that depend on a database session, using a mocked session."""

    @pytest.fixture
    def mock_db_session(self):
        """Patch ``app.database.get_db`` and yield the mock session it returns."""
        with patch('app.database.get_db') as fake_get_db:
            session = Mock()
            fake_get_db.return_value = session
            yield session

    def test_user_service(self, mock_db_session):
        """``get_user_by_email`` queries the session and returns the matching user."""
        from app.services.user import UserService

        # Stub the chain session.query(...).filter_by(...).first().
        stored_user = Mock()
        stored_user.id = 1
        stored_user.email = "test@example.com"
        query = Mock()
        query.filter_by.return_value.first.return_value = stored_user
        mock_db_session.query.return_value = query

        found = UserService(mock_db_session).get_user_by_email("test@example.com")

        assert found.id == 1
        assert found.email == "test@example.com"
        # The service must have gone through the mocked query exactly once.
        mock_db_session.query.assert_called_once()
        query.filter_by.assert_called_once_with(email="test@example.com")
# 测试第三方服务
class TestThirdPartyServices:
    """Tests for the e-mail, payment and cache wrappers with their backends mocked."""

    def test_email_service(self):
        """``send_email`` opens an SMTP session, upgrades to TLS, logs in and sends."""
        with patch('smtplib.SMTP') as smtp_cls:
            # The object bound by ``with smtplib.SMTP(...) as server``.
            server = Mock()
            smtp_cls.return_value.__enter__.return_value = server

            from app.services.email import send_email

            send_email(
                to="recipient@example.com",
                subject="Test",
                body="Test email"
            )

            smtp_cls.assert_called_once_with('smtp.gmail.com', 587)
            server.starttls.assert_called_once()
            server.login.assert_called_once()
            server.sendmail.assert_called_once()

    def test_payment_gateway(self):
        """``process_payment`` forwards the charge to Stripe and returns its result."""
        with patch('app.services.payment.stripe.Charge.create') as create_charge:
            create_charge.return_value = {
                'id': 'ch_123',
                'amount': 1000,
                'status': 'succeeded'
            }

            from app.services.payment import process_payment

            outcome = process_payment(
                token="tok_123",
                amount=1000,
                currency="usd"
            )

            assert outcome['status'] == 'succeeded'
            expected_charge = {
                'amount': 1000,
                'currency': 'usd',
                'source': 'tok_123',
                'description': 'Payment',
            }
            create_charge.assert_called_once_with(**expected_charge)

    def test_cache_service(self):
        """``CacheService`` delegates set/get to the Redis client and decodes hits."""
        with patch('redis.Redis') as redis_cls:
            client = Mock()
            redis_cls.return_value = client

            from app.services.cache import CacheService
            cache = CacheService()

            # Writing a value with a 3600 second expiry.
            client.set.return_value = True
            stored = cache.set("key", "value", 3600)
            assert stored is True
            client.set.assert_called_once_with("key", "value", ex=3600)

            # Cache hit: the client returns bytes, the service returns text.
            client.get.return_value = b"cached value"
            hit = cache.get("key")
            assert hit == "cached value"
            client.get.assert_called_once_with("key")

            # Cache miss.
            client.get.return_value = None
            miss = cache.get("nonexistent")
            assert miss is None
# 测试文件系统操作
class TestFileSystem:
"""测试文件系统操作"""
def test_file_operations(self, tmp_path):
"""测试文件操作"""
import os
import shutil
# 创建临时目录
test_dir = tmp_path / "test_dir"
test_dir.mkdir()
test_file = test_dir / "test.txt"
test_file.write_text("Hello, World!")
# 测试文件存在
assert test_file.exists()
assert test_file.read_text() == "Hello, World!"
# 测试文件操作
from app.services.file import FileService
service = FileService()
# 使用patch模拟os.path.getsize
with patch('os.path.getsize', return_value=13):
size = service.get_file_size(str(test_file))
assert size == 13
# 使用patch模拟shutil.copy
with patch('shutil.copy') as mock_copy:
dest_file = test_dir / "copy.txt"
service.copy_file(str(test_file), str(dest_file))
mock_copy.assert_called_once_with(str(test_file), str(dest_file))
def test_temp_files(self):
"""测试临时文件"""
from tempfile import NamedTemporaryFile
# 使用NamedTemporaryFile创建临时文件
with NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write("Temporary content")
temp_path = f.name
# 验证文件内容
with open(temp_path, 'r') as f:
content = f.read()
assert content == "Temporary content"
# 清理
import os
os.unlink(temp_path)
assert not os.path.exists(temp_path)
# 集成测试:模拟多个外部依赖
class TestIntegrationWithMocks:
    """Integration-style tests in which every external collaborator is mocked."""

    def test_complete_order_flow(self):
        """Process one order with payment, inventory, e-mail and shipping mocked."""
        with patch('app.services.payment.process_payment') as mock_payment, \
                patch('app.services.inventory.check_stock') as mock_stock, \
                patch('app.services.email.send_order_confirmation') as mock_email, \
                patch('app.services.shipping.create_shipment') as mock_shipping:
            # Happy path: item in stock, payment accepted, shipment created.
            mock_stock.return_value = True
            mock_payment.return_value = {'status': 'succeeded', 'id': 'pay_123'}
            mock_shipping.return_value = {'tracking_number': 'TRK123'}

            from app.services.order import OrderService
            order = {
                'user_id': 1,
                'items': [{'product_id': 101, 'quantity': 2}],
                'shipping_address': '123 Main St'
            }
            result = OrderService().process_order(order)

            assert result['success'] is True
            for expected_key in ('order_id', 'payment_id', 'tracking_number'):
                assert expected_key in result

            # Every collaborator was invoked (the relative order is not checked).
            for collaborator in (mock_stock, mock_payment, mock_shipping, mock_email):
                assert collaborator.called

            mock_payment.assert_called_once()
            mock_shipping.assert_called_once_with(
                address='123 Main St',
                items=[{'product_id': 101, 'quantity': 2}]
            )


# Article section heading: 6. 性能测试:Locust实战
Locust性能测试框架
python
# locustfile.py - FastAPI性能测试
from locust import HttpUser, task, between, TaskSet, events
from locust.runners import MasterRunner, WorkerRunner
import random
import json
from datetime import datetime
# 自定义事件监听器
@events.init.add_listener
def on_locust_init(environment, **kwargs):
    """Print the target host and the runner mode when Locust initialises."""
    print(f"Locust初始化: {environment.host}")
    runner = environment.runner
    if isinstance(runner, MasterRunner):
        mode_message = "运行在Master模式"
    elif isinstance(runner, WorkerRunner):
        mode_message = "运行在Worker模式"
    else:
        # Neither master nor worker: stand-alone run.
        mode_message = "运行在独立模式"
    print(mode_message)
@events.test_start.add_listener
def on_test_start(environment, **kwargs):
    """Print the wall-clock time at which the test run starts."""
    started_at = datetime.now()
    print(f"测试开始: {started_at}")
@events.test_stop.add_listener
def on_test_stop(environment, **kwargs):
    """Print the wall-clock time at which the test run stops."""
    stopped_at = datetime.now()
    print(f"测试结束: {stopped_at}")
# Custom metrics: override two module-level settings in locust.stats
from locust import stats
stats.CSV_STATS_INTERVAL_SEC = 5 # interval between CSV statistics writes (seconds, per the setting's name)
stats.CURRENT_RESPONSE_TIME_PERCENTILE_WINDOW = 10 # window used for the current response-time percentile
# 基础测试类
class FastAPIUser(HttpUser):
    """Base class for simulated users: logs in on start and logs out on stop."""

    # This class defines no tasks of its own, so it is marked abstract to keep
    # Locust from spawning it directly; concrete subclasses supply ``tasks``.
    abstract = True
    wait_time = between(1, 3)  # pause 1-3 seconds between tasks
    host = "http://localhost:8000"

    def on_start(self):
        """Reset the auth state and log in when the simulated user starts."""
        self.auth_token = None
        self.user_id = None
        self.login()

    def on_stop(self):
        """Log out when the simulated user stops."""
        self.logout()

    def login(self):
        """Log in with the fixed test account and remember the returned token.

        ``self.headers`` is only created when a token was actually obtained;
        the task sets check for that attribute before sending authenticated
        requests.
        """
        response = self.client.post("/api/auth/login", json={
            "email": "test@example.com",
            "password": "password123"
        })
        if response.status_code != 200:
            return
        data = response.json()
        self.auth_token = data.get("access_token")
        self.user_id = data.get("user_id")
        if self.auth_token:
            self.headers = {"Authorization": f"Bearer {self.auth_token}"}

    def logout(self):
        """Log out, but only if a token was obtained during login."""
        if self.auth_token:
            self.client.post("/api/auth/logout", headers=self.headers)
# API端点测试
class APITests(TaskSet):
    """Task set covering the public and authenticated REST endpoints."""

    @task(3)  # weight 3: picked most often
    def get_public_data(self):
        """Fetch the public data endpoint and mark the request passed or failed."""
        with self.client.get("/api/public/data", catch_response=True) as response:
            if response.status_code != 200:
                response.failure(f"获取公开数据失败: {response.status_code}")
            else:
                response.success()

    @task(2)
    def get_user_profile(self):
        """Fetch the current user's profile when auth headers are available."""
        if not hasattr(self.user, 'headers'):
            return
        self.client.get("/api/users/me", headers=self.user.headers)

    @task(1)
    def create_item(self):
        """Create an item with a random name and price; 201 counts as success."""
        if not hasattr(self.user, 'headers'):
            return
        item_data = {
            "name": f"Item_{random.randint(1, 1000)}",
            "description": "Test item",
            "price": random.uniform(10, 1000)
        }
        request = self.client.post(
            "/api/items",
            json=item_data,
            headers=self.user.headers,
            catch_response=True
        )
        with request as response:
            if response.status_code != 201:
                response.failure(f"创建项目失败: {response.status_code}")
            else:
                response.success()

    @task(1)
    def update_item(self):
        """Update a randomly chosen item id (1-100) with a new name and price."""
        if not hasattr(self.user, 'headers'):
            return
        item_id = random.randint(1, 100)
        update_data = {
            "name": f"Updated_{random.randint(1, 1000)}",
            "price": random.uniform(10, 1000)
        }
        self.client.put(
            f"/api/items/{item_id}",
            json=update_data,
            headers=self.user.headers
        )
# 数据库操作测试
class DatabaseTests(TaskSet):
    """Task set hitting the database-backed endpoints (weights 5/3/2/1)."""

    @task(5)
    def simple_query(self):
        """Call the simple-query endpoint."""
        self.client.get("/api/db/simple-query")

    @task(3)
    def complex_query(self):
        """Call the complex-query endpoint."""
        self.client.get("/api/db/complex-query")

    @task(2)
    def write_operation(self):
        """Post a small record with a random label and value."""
        label = f"test_{random.randint(1, 10000)}"
        value = random.randint(1, 100)
        self.client.post("/api/db/write", json={"data": label, "value": value})

    @task(1)
    def transaction_test(self):
        """Call the transaction endpoint."""
        self.client.post("/api/db/transaction")
# WebSocket测试
class WebSocketTests(TaskSet):
    """Task set that opens a WebSocket chat connection and sends one message."""

    @task
    def websocket_chat(self):
        """Connect to ``/ws/chat``, send a message, hold the connection 5 s, then close."""
        import websocket
        import threading
        import time

        # Swap only the leading scheme (http -> ws, https -> wss); an
        # unbounded replace would also rewrite "http" inside the host name.
        ws_url = self.user.host.replace("http", "ws", 1) + "/ws/chat"

        def on_message(ws, message):
            print(f"收到消息: {message}")

        def on_error(ws, error):
            print(f"WebSocket错误: {error}")

        def on_close(ws, close_status_code, close_msg):
            print(f"WebSocket关闭: {close_status_code} - {close_msg}")

        def on_open(ws):
            print("WebSocket连接已打开")
            # Send a single test message as soon as the connection is up.
            ws.send(json.dumps({
                "type": "message",
                "content": f"Hello from Locust {random.randint(1, 100)}",
                "timestamp": datetime.now().isoformat()
            }))

        ws = websocket.WebSocketApp(
            ws_url,
            on_open=on_open,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close
        )

        # run_forever blocks, so it runs in a background thread.
        wst = threading.Thread(target=ws.run_forever)
        wst.daemon = True
        wst.start()

        # Keep the connection open for a while, then shut it down.
        time.sleep(5)
        ws.close()
        # Give the worker a bounded amount of time to finish after close().
        wst.join(timeout=5)
# 完整用户模拟
class RegularUser(FastAPIUser):
    """Ordinary user: runs the API and database task sets."""

    weight = 3  # relative weight of this user class
    tasks = [APITests, DatabaseTests]
class PowerUser(FastAPIUser):
    """Power user: runs the API, database and WebSocket task sets."""

    weight = 1  # relative weight of this user class
    tasks = [APITests, DatabaseTests, WebSocketTests]
# 自定义负载形状
from locust import LoadTestShape
class CustomLoadShape(LoadTestShape):
    """Stepped load profile: ramp up to 100 users, then back down to 10."""

    # ``duration`` is compared against the total run time in tick(), so it is
    # the second at which a stage ENDS, not the length of the stage.
    stages = [
        {"duration": 60, "users": 10, "spawn_rate": 10},  # 0-60 s: 10 users
        {"duration": 120, "users": 50, "spawn_rate": 5},  # 60-120 s: 50 users
        {"duration": 180, "users": 100, "spawn_rate": 10},  # 120-180 s: 100 users
        {"duration": 240, "users": 50, "spawn_rate": 5},  # 180-240 s: 50 users
        {"duration": 300, "users": 10, "spawn_rate": 10},  # 240-300 s: 10 users
    ]

    def tick(self):
        """Return ``(users, spawn_rate)`` for the current stage, or ``None`` after the last one."""
        elapsed = self.get_run_time()
        # First stage whose end time has not been reached yet.
        current = next(
            (stage for stage in self.stages if elapsed < stage["duration"]),
            None,
        )
        if current is None:
            return None  # past the final stage: the test ends
        return (current["users"], current["spawn_rate"])
# 自定义用户等待时间
class RandomWaitUser(HttpUser):
    """User base class with a randomised, mostly short wait between tasks."""

    # No tasks are defined here; subclass it and add tasks to use it.
    abstract = True

    def wait_time(self):
        """Return the number of seconds to wait before the next task.

        Defined as a plain method rather than a property because Locust
        calls ``wait_time()`` to obtain the delay.
        """
        # 90% of the waits last 1-2 seconds, the remaining 10% last 5-10 seconds.
        if random.random() < 0.9:
            return random.uniform(1, 2)
        return random.uniform(5, 10)
# 分布式测试配置
"""
# 主节点启动命令
locust -f locustfile.py --master --host=http://localhost:8000
# 工作节点启动命令
locust -f locustfile.py --worker --master-host=192.168.1.100
# Web界面
locust -f locustfile.py --web-host=0.0.0.0 --web-port=8089
"""
# 命令行运行
"""
# 无Web界面运行测试
locust -f locustfile.py --headless -u 100 -r 10 -t 5m
# 参数说明:
# --headless: 无头模式
# -u: 用户数
# -r: 生成率(每秒生成用户数)
# -t: 测试时间
# --csv: 保存CSV结果
# --html: 生成HTML报告
"""
# 测试结果分析
class TestResultAnalyzer:
    """Summarise a Locust statistics CSV and render a Markdown report.

    The CSV may contain a summary row named ``Aggregated`` that totals all
    endpoints. It is left out of every sum and of the per-endpoint section so
    requests are not counted twice.
    """

    # Not a pytest test class despite the "Test" prefix.
    __test__ = False
    # Name of the summary row that totals all endpoints.
    _AGGREGATED_ROW = "Aggregated"

    @staticmethod
    def _endpoint_rows(df):
        """Return *df* without the summary row."""
        return df[df['Name'] != TestResultAnalyzer._AGGREGATED_ROW]

    @staticmethod
    def _totals(rows):
        """Return (requests, failures, failure %, request-weighted mean response time)."""
        requests = int(rows['Request Count'].sum())
        failures = int(rows['Failure Count'].sum())
        if requests == 0:
            # Nothing was sent: report zeros instead of dividing by zero.
            return requests, failures, 0.0, 0.0
        weighted_ms = (rows['Average Response Time'] * rows['Request Count']).sum()
        return requests, failures, failures / requests * 100, float(weighted_ms / requests)

    @staticmethod
    def analyze_csv(csv_file):
        """Print a summary of the statistics in *csv_file* and return the DataFrame.

        Args:
            csv_file: Path to a Locust statistics CSV file.

        Returns:
            The DataFrame exactly as read from *csv_file* (summary row included).
        """
        import pandas as pd

        df = pd.read_csv(csv_file)
        rows = TestResultAnalyzer._endpoint_rows(df)
        requests, failures, _, avg_ms = TestResultAnalyzer._totals(rows)
        print("测试结果分析:")
        print(f"总请求数: {requests}")
        print(f"失败请求数: {failures}")
        print(f"平均响应时间: {avg_ms:.2f}ms")
        print(f"95%响应时间: {rows['95%'].max():.2f}ms")
        print(f"最大响应时间: {rows['Max Response Time'].max():.2f}ms")
        # Overall throughput is the sum of the per-endpoint request rates.
        rps = float(rows['Requests/s'].sum())
        print(f"平均RPS: {rps:.2f}")
        return df

    @staticmethod
    def generate_report(df, output_file="performance_report.md"):
        """Write a Markdown performance report for *df* to *output_file*.

        Args:
            df: DataFrame in the layout returned by ``analyze_csv``.
            output_file: Destination path; the file is written as UTF-8.
        """
        rows = TestResultAnalyzer._endpoint_rows(df)
        requests, failures, failure_pct, avg_ms = TestResultAnalyzer._totals(rows)
        worst_p95 = rows['95%'].max()

        # UTF-8 is set explicitly because the report text is not ASCII.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write("# 性能测试报告\n\n")
            f.write(f"生成时间: {datetime.now()}\n\n")
            f.write("## 总体统计\n")
            f.write(f"- 总请求数: {requests}\n")
            f.write(f"- 失败请求数: {failures}\n")
            f.write(f"- 失败率: {failure_pct:.2f}%\n")
            f.write(f"- 平均响应时间: {avg_ms:.2f}ms\n")
            f.write(f"- 95%响应时间: {worst_p95:.2f}ms\n\n")

            f.write("## 端点性能\n")
            for endpoint in rows['Name'].unique():
                endpoint_rows = rows[rows['Name'] == endpoint]
                ep_requests, _, ep_failure_pct, ep_avg_ms = TestResultAnalyzer._totals(endpoint_rows)
                f.write(f"### {endpoint}\n")
                f.write(f"- 请求数: {ep_requests}\n")
                f.write(f"- 平均响应时间: {ep_avg_ms:.2f}ms\n")
                f.write(f"- 失败率: {ep_failure_pct:.2f}%\n\n")

            f.write("## 建议\n")
            advice = []
            if failures > 0:
                advice.append("存在失败请求,需要检查API稳定性")
            if worst_p95 > 1000:  # slower than one second
                advice.append("部分端点响应时间较长,需要优化")
            if avg_ms > 500:  # mean above 500 ms
                advice.append("整体响应时间偏高,建议进行性能优化")
            # Number the items as they are written so the list has no gaps.
            for number, line in enumerate(advice, start=1):
                f.write(f"{number}. {line}\n")


# Article section heading: 性能测试实战
python
# tests/performance/test_api_performance.py
import pytest
import time
import statistics
from locust import HttpUser, task, between, events
from locust.env import Environment
from locust.stats import stats_history, StatsEntry
import pandas as pd
from datetime import datetime
class PerformanceTestSuite:
    """Drive Locust programmatically: single-endpoint, load, stress and endurance runs.

    Every ``run_*`` method blocks for the length of the run and returns plain
    dicts/lists of metrics. ``generate_performance_report`` renders the result
    of ``run_load_test`` as an HTML page.
    """

    # HTTP verbs accepted by run_single_endpoint_test; only POST/PUT carry the payload.
    _SUPPORTED_METHODS = ("GET", "POST", "PUT", "DELETE")
    _METHODS_WITH_BODY = ("POST", "PUT")

    def __init__(self, base_url="http://localhost:8000"):
        """Remember the target *base_url* and start with an empty result cache."""
        self.base_url = base_url
        self.results = {}

    def _create_environment(self, user_class):
        """Return a Locust environment with a local runner for *user_class*."""
        env = Environment(user_classes=[user_class], host=self.base_url)
        env.create_local_runner()
        return env

    @staticmethod
    def _run(env, user_count, spawn_rate, seconds):
        """Start *user_count* users, wait *seconds*, and always stop the runner."""
        env.runner.start(user_count, spawn_rate=spawn_rate)
        try:
            time.sleep(seconds)
        finally:
            env.runner.stop()

    def run_single_endpoint_test(self, endpoint, method="GET", payload=None,
                                 users=10, spawn_rate=10, run_time=30):
        """Load a single endpoint and return its statistics.

        Args:
            endpoint: Path of the endpoint, e.g. ``"/api/users"``.
            method: One of GET, POST, PUT or DELETE.
            payload: JSON body sent with POST and PUT requests.
            users: Number of simulated users (default 10).
            spawn_rate: Users started per second (default 10).
            run_time: Length of the run in seconds (default 30).

        Returns:
            Dict of metrics for the endpoint. ``failure_rate`` is Locust's
            ``fail_ratio``, i.e. a fraction between 0 and 1, not a percentage.

        Raises:
            ValueError: If *method* is not a supported HTTP verb.
        """
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unsupported HTTP method: {method!r}")
        request_kwargs = {"json": payload} if method in self._METHODS_WITH_BODY else {}

        class SingleEndpointUser(HttpUser):
            host = self.base_url
            wait_time = between(0.1, 0.5)

            @task
            def test_endpoint(self):
                self.client.request(method, endpoint, **request_kwargs)

        env = self._create_environment(SingleEndpointUser)
        self._run(env, users, spawn_rate, run_time)

        # Statistics are keyed by (name, method), so look up the verb that was sent.
        entry = env.stats.get(endpoint, method)
        self.results[endpoint] = {
            "requests": entry.num_requests,
            "failures": entry.num_failures,
            "avg_response_time": entry.avg_response_time,
            "median_response_time": entry.median_response_time,
            "min_response_time": entry.min_response_time,
            "max_response_time": entry.max_response_time,
            "requests_per_second": entry.total_rps,
            "failure_rate": entry.fail_ratio,
        }
        return self.results.get(endpoint)

    def run_load_test(self, user_count=100, duration=300):
        """Run a mixed-endpoint load test and aggregate every recorded request.

        Args:
            user_count: Number of simulated users.
            duration: Length of the run in seconds.

        Returns:
            Dict with ``overall`` metrics (``failure_rate`` in percent),
            per-endpoint metrics under ``endpoints`` and the raw samples as a
            DataFrame under ``raw_data``.
        """
        class LoadTestUser(HttpUser):
            host = self.base_url
            wait_time = between(1, 3)

            @task(5)
            def get_public_data(self):
                self.client.get("/api/public/data")

            @task(3)
            def get_users(self):
                self.client.get("/api/users")

            @task(2)
            def create_item(self):
                self.client.post("/api/items", json={
                    "name": "Test Item",
                    "description": "Performance test item"
                })

            @task(1)
            def heavy_computation(self):
                self.client.get("/api/compute-heavy")

        env = self._create_environment(LoadTestUser)

        test_data = []

        def on_request(request_type, name, response_time, response_length, exception, **kwargs):
            test_data.append({
                "timestamp": datetime.now(),
                "request_type": request_type,
                "endpoint": name,
                "response_time": response_time,
                "response_length": response_length,
                "exception": exception
            })

        # Register on this environment's own event hooks: the environment was
        # created without an ``events`` argument, so the global ``locust.events``
        # object is not the one its requests are reported to.
        env.events.request.add_listener(on_request)

        print(f"开始负载测试: {user_count}用户, {duration}秒")
        self._run(env, user_count, 10, duration)

        df = pd.DataFrame(test_data)
        if df.empty:
            # No request completed: return zeroed metrics instead of failing
            # on missing columns or dividing by zero.
            overall_stats = {
                "total_requests": 0,
                "failed_requests": 0,
                "failure_rate": 0,
                "avg_response_time": 0,
                "median_response_time": 0,
                "p95_response_time": 0,
                "p99_response_time": 0,
                "requests_per_second": 0,
            }
            return {"overall": overall_stats, "endpoints": {}, "raw_data": df}

        failed = df['exception'].notna().sum()
        overall_stats = {
            "total_requests": len(df),
            "failed_requests": failed,
            "failure_rate": failed / len(df) * 100,
            "avg_response_time": df['response_time'].mean(),
            "median_response_time": df['response_time'].median(),
            "p95_response_time": df['response_time'].quantile(0.95),
            "p99_response_time": df['response_time'].quantile(0.99),
            "requests_per_second": len(df) / duration if duration > 0 else 0,
        }

        endpoint_stats = {}
        for endpoint in df['endpoint'].unique():
            endpoint_df = df[df['endpoint'] == endpoint]
            endpoint_stats[endpoint] = {
                "requests": len(endpoint_df),
                "failures": endpoint_df['exception'].notna().sum(),
                "avg_response_time": endpoint_df['response_time'].mean(),
                "p95_response_time": endpoint_df['response_time'].quantile(0.95),
            }

        return {
            "overall": overall_stats,
            "endpoints": endpoint_stats,
            "raw_data": df
        }

    def run_stress_test(self, ramp_up_users=50, max_users=500, duration=600,
                        step_duration=60):
        """Raise the load step by step and record the metrics of each level.

        Args:
            ramp_up_users: First load level and the increment between levels.
            max_users: Highest load level (inclusive).
            duration: Accepted for interface compatibility; the run length is
                determined by the number of levels times *step_duration*.
            step_duration: Seconds spent at each level (default 60).

        Returns:
            List with one dict per level; ``failure_rate`` is in percent.

        Raises:
            ValueError: If *ramp_up_users* or *step_duration* is not positive.
        """
        if ramp_up_users <= 0:
            # A non-positive increment would never reach max_users.
            raise ValueError("ramp_up_users must be positive")
        if step_duration <= 0:
            raise ValueError("step_duration must be positive")

        class StressTestUser(HttpUser):
            host = self.base_url
            wait_time = between(0.5, 1.5)

            @task(10)
            def light_endpoint(self):
                self.client.get("/api/light")

            @task(5)
            def medium_endpoint(self):
                self.client.get("/api/medium")

            @task(1)
            def heavy_endpoint(self):
                self.client.get("/api/heavy")

        env = self._create_environment(StressTestUser)

        results = []
        current_users = ramp_up_users
        while current_users <= max_users:
            print(f"当前用户数: {current_users}")
            self._run(env, current_users, 100, step_duration)

            # ``total`` is the aggregate entry across all endpoints.
            total = env.stats.total
            total_requests = total.num_requests
            total_failures = total.num_failures
            results.append({
                "users": current_users,
                "requests": total_requests,
                "failures": total_failures,
                "failure_rate": total_failures / total_requests * 100 if total_requests > 0 else 0,
                "avg_response_time": total.avg_response_time if total_requests > 0 else 0,
                "rps": total_requests / step_duration,
            })
            current_users += ramp_up_users
        return results

    def run_endurance_test(self, user_count=100, duration=3600, sample_interval=3600):
        """Run a long, steady load and take a statistics sample after every interval.

        Args:
            user_count: Number of simulated users.
            duration: Total length of the run in seconds.
            sample_interval: Seconds between samples (default one hour).
                Only completed intervals produce a sample.

        Returns:
            List of sample dicts with the keys ``hour`` (1-based sample index),
            ``total_requests``, ``total_failures``, ``failure_rate`` (percent)
            and ``memory_usage``.

        Raises:
            ValueError: If *sample_interval* is not positive.
        """
        if sample_interval <= 0:
            raise ValueError("sample_interval must be positive")

        class EnduranceTestUser(HttpUser):
            host = self.base_url
            wait_time = between(2, 5)  # longer pauses between tasks

            @task(3)
            def browse_data(self):
                self.client.get("/api/data")

            @task(1)
            def submit_form(self):
                self.client.post("/api/submit", json={
                    "field1": "value1",
                    "field2": "value2"
                })

        env = self._create_environment(EnduranceTestUser)

        hourly_stats = []
        full_intervals = int(duration // sample_interval)
        env.runner.start(user_count, spawn_rate=10)
        try:
            # Sampling happens in this loop rather than a background thread, so
            # the final sample is always recorded before the method returns.
            for i in range(full_intervals):
                time.sleep(sample_interval)
                total = env.stats.total
                requests = total.num_requests
                failures = total.num_failures
                current_stats = {
                    "hour": i + 1,
                    "total_requests": requests,
                    "total_failures": failures,
                    "failure_rate": failures / requests * 100 if requests > 0 else 0,
                    "memory_usage": self._get_memory_usage(),
                }
                hourly_stats.append(current_stats)
                print(f"第{i+1}小时统计: {current_stats}")
            remainder = duration - full_intervals * sample_interval
            if remainder > 0:
                time.sleep(remainder)
        finally:
            env.runner.stop()
        return hourly_stats

    def _get_memory_usage(self):
        """Return the resident memory of the current process in MB.

        This measures the process running the test (``os.getpid()``), not the
        server under test.
        """
        import psutil
        import os
        process = psutil.Process(os.getpid())
        return process.memory_info().rss / 1024 / 1024

    def generate_performance_report(self, test_results, output_file="performance_report.html"):
        """Render *test_results* (as returned by ``run_load_test``) to an HTML file.

        Args:
            test_results: Dict with ``overall`` and ``endpoints`` entries.
            output_file: Destination path; the file is written as UTF-8.
        """
        from jinja2 import Template
        template_str = """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>性能测试报告</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; }
.metric { margin: 20px 0; padding: 15px; border: 1px solid #ddd; }
.good { background-color: #d4edda; }
.warning { background-color: #fff3cd; }
.critical { background-color: #f8d7da; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
th { background-color: #f2f2f2; }
</style>
</head>
<body>
<h1>性能测试报告</h1>
<p>生成时间: {{ timestamp }}</p>
<h2>总体性能指标</h2>
<div class="metric {{ 'good' if overall.failure_rate < 1 else 'warning' if overall.failure_rate < 5 else 'critical' }}">
<h3>失败率: {{ overall.failure_rate|round(2) }}%</h3>
<p>总请求数: {{ overall.total_requests }}</p>
<p>失败请求数: {{ overall.failed_requests }}</p>
</div>
<div class="metric {{ 'good' if overall.avg_response_time < 100 else 'warning' if overall.avg_response_time < 500 else 'critical' }}">
<h3>响应时间</h3>
<p>平均响应时间: {{ overall.avg_response_time|round(2) }}ms</p>
<p>中位数响应时间: {{ overall.median_response_time|round(2) }}ms</p>
<p>P95响应时间: {{ overall.p95_response_time|round(2) }}ms</p>
<p>P99响应时间: {{ overall.p99_response_time|round(2) }}ms</p>
</div>
<div class="metric">
<h3>吞吐量</h3>
<p>平均RPS: {{ overall.requests_per_second|round(2) }}</p>
</div>
<h2>端点性能详情</h2>
<table>
<tr>
<th>端点</th>
<th>请求数</th>
<th>失败数</th>
<th>失败率</th>
<th>平均响应时间</th>
<th>P95响应时间</th>
</tr>
{% for endpoint, stats in endpoints.items() %}
<tr>
<td>{{ endpoint }}</td>
<td>{{ stats.requests }}</td>
<td>{{ stats.failures }}</td>
<td>{{ (stats.failures / stats.requests * 100)|round(2) if stats.requests > 0 else 0 }}%</td>
<td>{{ stats.avg_response_time|round(2) }}ms</td>
<td>{{ stats.p95_response_time|round(2) }}ms</td>
</tr>
{% endfor %}
</table>
<h2>建议</h2>
<ul>
{% if overall.failure_rate > 5 %}
<li style="color: red;">失败率超过5%,需要立即调查并修复</li>
{% elif overall.failure_rate > 1 %}
<li style="color: orange;">失败率超过1%,建议调查原因</li>
{% else %}
<li style="color: green;">失败率在可接受范围内</li>
{% endif %}
{% if overall.p95_response_time > 1000 %}
<li style="color: red;">P95响应时间超过1秒,需要优化性能</li>
{% elif overall.p95_response_time > 500 %}
<li style="color: orange;">P95响应时间超过500ms,建议优化</li>
{% else %}
<li style="color: green;">响应时间性能良好</li>
{% endif %}
</ul>
</body>
</html>
"""
        template = Template(template_str)
        html = template.render(
            timestamp=datetime.now().isoformat(),
            overall=test_results.get("overall", {}),
            endpoints=test_results.get("endpoints", {})
        )
        # The page declares UTF-8, so the file is written with that encoding.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(html)
        print(f"性能报告已生成: {output_file}")
# Pytest集成性能测试
@pytest.mark.performance
class TestAPIPerformance:
    """Pytest wrappers that run the performance suite and assert on its metrics.

    The tests return nothing: pytest treats a non-None return value from a
    test function as a mistake.
    """

    @pytest.fixture(scope="class")
    def perf_tester(self):
        """Provide one ``PerformanceTestSuite`` shared by all tests of the class."""
        return PerformanceTestSuite("http://localhost:8000")

    def test_single_endpoint_performance(self, perf_tester):
        """Each listed endpoint must answer quickly and almost never fail."""
        endpoints = [
            ("/api/public/data", "GET"),
            ("/api/users", "GET"),
            ("/api/items", "GET"),
        ]
        for endpoint, method in endpoints:
            result = perf_tester.run_single_endpoint_test(endpoint, method)
            assert result is not None, f"{endpoint}没有统计数据"
            assert result["avg_response_time"] < 200, f"{endpoint}响应时间过长"
            # failure_rate comes from Locust's fail_ratio (a 0-1 fraction),
            # so a 1% limit is 0.01.
            assert result["failure_rate"] < 0.01, f"{endpoint}失败率过高"

    @pytest.mark.slow
    def test_load_performance(self, perf_tester):
        """Under 100 users the failure rate and response times must stay within limits."""
        result = perf_tester.run_load_test(user_count=100, duration=60)

        overall = result["overall"]
        assert overall["failure_rate"] < 5, f"失败率过高: {overall['failure_rate']}%"
        assert overall["avg_response_time"] < 500, f"平均响应时间过长: {overall['avg_response_time']}ms"
        assert overall["p95_response_time"] < 1000, f"P95响应时间过长: {overall['p95_response_time']}ms"

        perf_tester.generate_performance_report(result)

    @pytest.mark.stress
    def test_stress_performance(self, perf_tester):
        """At the highest load level the service must still be usable."""
        results = perf_tester.run_stress_test(
            ramp_up_users=50,
            max_users=500,
            duration=300
        )
        assert results, "压力测试未返回任何结果"

        # The last entry is the highest load level that was run.
        max_load_result = results[-1]
        assert max_load_result["failure_rate"] < 10, "高负载下失败率过高"
        assert max_load_result["avg_response_time"] < 1000, "高负载下响应时间过长"

    @pytest.mark.endurance
    def test_endurance_performance(self, perf_tester):
        """Over two hours the failure rate must stay low and memory must not balloon."""
        results = perf_tester.run_endurance_test(
            user_count=50,
            duration=7200  # 2 hours
        )
        assert results, "耐力测试未返回任何统计数据"

        # Derive the failure rate (percent) from the totals in each sample.
        failure_rates = [
            r["total_failures"] / r["total_requests"] * 100 if r.get("total_requests") else 0
            for r in results
        ]
        avg_failure_rate = statistics.mean(failure_rates)
        assert avg_failure_rate < 2, "耐力测试期间平均失败率过高"

        # Memory growth between the first and last sample hints at a leak.
        memory_usage = [r.get("memory_usage", 0) for r in results]
        if len(memory_usage) > 1 and memory_usage[0] > 0:
            memory_growth = (memory_usage[-1] - memory_usage[0]) / memory_usage[0]
            assert memory_growth < 0.5, "可能存在内存泄漏"
# 基准测试
@pytest.mark.benchmark
class TestBenchmark:
    """Benchmark tests built on ``PerformanceTestSuite`` load runs."""

    @pytest.mark.parametrize("concurrent_users", [1, 10, 50, 100])
    def test_concurrent_users_benchmark(self, concurrent_users):
        """Run a 30-second load test and enforce a mean response-time budget.

        The budget for ``avg_response_time`` is 100 for up to 10 users, 300
        for up to 50 users and 500 above that. Returns the ``"overall"``
        section of the load-test result.
        """
        suite = PerformanceTestSuite()
        overall = suite.run_load_test(
            user_count=concurrent_users,
            duration=30,
        )["overall"]

        print(f"\n并发用户数: {concurrent_users}")
        print(f"平均RPS: {overall['requests_per_second']:.2f}")
        print(f"平均响应时间: {overall['avg_response_time']:.2f}ms")
        print(f"P95响应时间: {overall['p95_response_time']:.2f}ms")

        # Pick the budget that applies to this concurrency level.
        if concurrent_users > 50:
            budget = 500
        elif concurrent_users > 10:
            budget = 300
        else:
            budget = 100
        assert overall["avg_response_time"] < budget
        return overall

    def test_comparison_benchmark(self):
        """Load-test three deployments and chart their overall metrics.

        Each target gets a 60-second run with 50 users. Returns a dict that
        maps the configuration name to its ``"overall"`` result section.
        """
        targets = (
            ("默认配置", "http://localhost:8000"),
            ("优化配置", "http://localhost:8001"),
            ("生产配置", "http://localhost:8002"),
        )
        results = {}
        for label, base_url in targets:
            run = PerformanceTestSuite(base_url).run_load_test(user_count=50, duration=60)
            results[label] = run["overall"]

        # Render the side-by-side comparison chart.
        self._generate_comparison_report(results)
        return results

    def _generate_comparison_report(self, results):
        """Save a 2x2 bar-chart comparison to ``performance_comparison.png``.

        ``results`` maps a configuration name to a metrics dict that provides
        ``avg_response_time``, ``p95_response_time``, ``requests_per_second``
        and ``failure_rate``.
        """
        import matplotlib.pyplot as plt

        _figure, axes = plt.subplots(2, 2, figsize=(12, 10))
        config_names = list(results.keys())

        # (axes cell, metric key, chart title, y-axis label), drawn in this order.
        panels = (
            (axes[0, 0], "avg_response_time", "平均响应时间对比", "毫秒"),
            (axes[0, 1], "p95_response_time", "P95响应时间对比", "毫秒"),
            (axes[1, 0], "requests_per_second", "吞吐量对比", "RPS"),
            (axes[1, 1], "failure_rate", "失败率对比", "百分比"),
        )
        for axis, metric, title, unit in panels:
            axis.bar(config_names, [entry[metric] for entry in results.values()])
            axis.set_title(title)
            axis.set_ylabel(unit)

        plt.tight_layout()
        plt.savefig("performance_comparison.png")
        plt.close()
        print("对比报告已生成: performance_comparison.png")
7. 测试覆盖率与CI集成
测试覆盖率分析
python
# .coveragerc - coverage.py configuration file
# Continuation lines of multi-line values (omit, exclude_lines) must be
# indented; otherwise the INI parser reads them as separate keys.
[run]
# Source package to measure
source = app
# Files to leave out of measurement
omit =
    app/migrations/*
    app/tests/*
    */__pycache__/*
    */site-packages/*
# Measure branch coverage in addition to line coverage
branch = True
# Parallel mode: each process writes its own suffixed data file
parallel = True
# Data file location
data_file = .coverage

[report]
# Show the line numbers that were not covered
show_missing = True
# Regexes for source lines to exclude from the report
exclude_lines =
    # Explicit opt-out marker
    pragma: no cover
    # String-representation helpers
    def __repr__
    def __str__
    # Script entry-point guard
    if __name__ == .__main__.:
    # Abstract methods
    @abstractmethod
    # Debugging code
    import pdb
    pdb.set_trace
    # Logging statements
    logger\.(debug|info|warning|error|critical)
# Minimum total coverage; the report command fails below this value
fail_under = 80
# Report output format (NOTE(review): needs coverage.py 7.0+ - confirm pinned version)
format = markdown

[html]
# HTML report directory
directory = coverage_html
# Report title
title = FastAPI Test Coverage Report
# Show recorded measurement contexts in the HTML report
show_contexts = True

[json]
# JSON report file
output = coverage.json

[xml]
# XML report file (used for CI integration)
output = coverage.xml
# pytest-cov命令行使用
"""
# 基本用法
pytest --cov=app tests/
# 包含分支覆盖率
pytest --cov=app --cov-branch tests/
# 生成HTML报告
pytest --cov=app --cov-report=html tests/
# 生成多种格式报告
pytest --cov=app --cov-report=term --cov-report=html --cov-report=xml tests/
# 设置覆盖率阈值
pytest --cov=app --cov-fail-under=80 tests/
# 并行运行
pytest --cov=app -n auto tests/
"""
# 覆盖率分析工具
import coverage
import json
from pathlib import Path
class CoverageAnalyzer:
    """Run the test suite under coverage.py and report on the collected data.

    The reporting methods load the data file given to the constructor, so
    they can also be used on data produced by an earlier run. If the
    project's coverage configuration enables parallel mode, the per-process
    data files must be combined before the reporting methods are used.
    """

    # Badge colours as hex values, ordered by minimum percentage. A "#"
    # followed by a colour name (e.g. "#green") is not a valid SVG paint.
    _BADGE_COLORS = (
        (90, "#4c1"),      # brightgreen
        (80, "#97ca00"),   # green
        (70, "#a4a61d"),   # yellowgreen
        (60, "#dfb317"),   # yellow
        (50, "#fe7d37"),   # orange
    )
    _BADGE_FALLBACK_COLOR = "#e05d44"  # red

    def __init__(self, source_dir="app", coverage_file=".coverage"):
        """Create the analyzer.

        Args:
            source_dir: Package directory whose coverage is measured.
            coverage_file: Path of the coverage data file to write and read.
        """
        self.source_dir = Path(source_dir)
        self.coverage_file = coverage_file
        self.cov = coverage.Coverage(
            source=[str(self.source_dir)],
            data_file=coverage_file,
        )

    def _total_percent(self):
        """Return the total coverage percentage of the loaded data as a float."""
        import io

        # Coverage.report() returns the total percentage; its text table is
        # sent to a throwaway buffer so nothing is printed here.
        return self.cov.report(file=io.StringIO())

    def run_tests_with_coverage(self):
        """Run pytest under ``coverage run`` and return pytest's exit code.

        The tests execute in a child process, so the measurement has to
        happen in that process; a Coverage object started in this process
        would not trace it. The data file location is passed through the
        ``COVERAGE_FILE`` environment variable.
        """
        import os
        import subprocess
        import sys

        env = dict(os.environ, COVERAGE_FILE=str(self.coverage_file))
        result = subprocess.run(
            [
                sys.executable, "-m", "coverage", "run",
                f"--source={self.source_dir}",
                "-m", "pytest",
                "tests/",
                "-v",
                "--tb=short",
            ],
            env=env,
        )
        return result.returncode

    def generate_reports(self):
        """Print the text report and write HTML, JSON, XML and LCOV reports."""
        # The data was written by the child process, so read it from disk.
        self.cov.load()

        print("\n" + "="*60)
        print("测试覆盖率报告")
        print("="*60)
        self.cov.report()
        # HTML report
        self.cov.html_report(directory="coverage_html")
        print(f"\nHTML报告: file://{Path('coverage_html').absolute()}/index.html")
        # JSON report
        self.cov.json_report(outfile="coverage.json")
        # XML report (for CI)
        self.cov.xml_report(outfile="coverage.xml")
        # LCOV report (for Codecov and similar services)
        self.cov.lcov_report(outfile="coverage.lcov")

    def analyze_coverage_data(self):
        """Summarise coverage per file and list files below 80%.

        Returns:
            A dict with three keys. ``"total"`` is an object with
            ``num_files``, ``num_statements``, ``covered_lines`` and
            ``percent_covered`` attributes. ``"files"`` maps each measured
            file to its ``total_lines``, ``covered_lines`` and
            ``missing_lines``. ``"low_coverage"`` lists the files under 80%.
        """
        from types import SimpleNamespace

        self.cov.load()

        # Coverage.analysis2() yields the executable and the missing line
        # numbers of one file; the data object alone only knows executed lines.
        file_stats = {}
        for filename in sorted(self.cov.get_data().measured_files()):
            _, statements, _, missing, _ = self.cov.analysis2(filename)
            file_stats[filename] = {
                "total_lines": len(statements),
                "covered_lines": len(statements) - len(missing),
                "missing_lines": list(missing),
            }

        total_stats = SimpleNamespace(
            num_files=len(file_stats),
            num_statements=sum(s["total_lines"] for s in file_stats.values()),
            covered_lines=sum(s["covered_lines"] for s in file_stats.values()),
            percent_covered=self._total_percent(),
        )
        print(f"\n总体覆盖率:")
        print(f" 文件数: {total_stats.num_files}")
        print(f" 总行数: {total_stats.num_statements}")
        print(f" 覆盖行数: {total_stats.covered_lines}")
        print(f" 覆盖率: {total_stats.percent_covered:.1f}%")

        # Collect the files whose line coverage is below 80%.
        low_coverage_files = []
        for filename, stats in file_stats.items():
            if stats["total_lines"] > 0:
                coverage_pct = stats["covered_lines"] / stats["total_lines"] * 100
            else:
                # A file without executable statements has nothing left to cover.
                coverage_pct = 100.0
            if coverage_pct < 80:
                low_coverage_files.append({
                    "file": filename,
                    "coverage": coverage_pct,
                    "missing_lines": stats["missing_lines"][:10],  # first 10 only
                })

        if low_coverage_files:
            print(f"\n需要改进的文件(覆盖率<80%):")
            for file_info in sorted(low_coverage_files, key=lambda x: x["coverage"]):
                print(f" {file_info['file']}: {file_info['coverage']:.1f}%")
                if file_info["missing_lines"]:
                    print(f" 未覆盖行号: {file_info['missing_lines'][:5]}...")

        return {
            "total": total_stats,
            "files": file_stats,
            "low_coverage": low_coverage_files,
        }

    def generate_coverage_badge(self, output_file="coverage_badge.svg"):
        """Write an SVG badge showing the total coverage to ``output_file``."""
        self.cov.load()
        coverage_percent = self._total_percent()

        # First threshold that the percentage reaches decides the colour.
        color = next(
            (hex_code for floor, hex_code in self._BADGE_COLORS if coverage_percent >= floor),
            self._BADGE_FALLBACK_COLOR,
        )

        badge_svg = f'''<svg xmlns="http://www.w3.org/2000/svg" width="120" height="20">
  <linearGradient id="b" x2="0" y2="100%">
    <stop offset="0" stop-color="#bbb" stop-opacity=".1"/>
    <stop offset="1" stop-opacity=".1"/>
  </linearGradient>
  <mask id="a">
    <rect width="120" height="20" rx="3" fill="#fff"/>
  </mask>
  <g mask="url(#a)">
    <path fill="#555" d="M0 0h60v20H0z"/>
    <path fill="{color}" d="M60 0h60v20H60z"/>
    <path fill="url(#b)" d="M0 0h120v20H0z"/>
  </g>
  <g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
    <text x="30" y="15" fill="#010101" fill-opacity=".3">coverage</text>
    <text x="30" y="14">coverage</text>
    <text x="90" y="15" fill="#010101" fill-opacity=".3">{coverage_percent:.1f}%</text>
    <text x="90" y="14">{coverage_percent:.1f}%</text>
  </g>
</svg>'''
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(badge_svg)
        print(f"覆盖率徽章已生成: {output_file}")

    def check_coverage_threshold(self, min_coverage=80):
        """Return True when total coverage is at least ``min_coverage`` percent."""
        self.cov.load()
        coverage_percent = self._total_percent()
        if coverage_percent < min_coverage:
            print(f"错误: 覆盖率 {coverage_percent:.1f}% 低于阈值 {min_coverage}%")
            return False
        print(f"通过: 覆盖率 {coverage_percent:.1f}% 达到阈值 {min_coverage}%")
        return True
# 使用示例
if __name__ == "__main__":
    # The script ends with sys.exit(), but sys is not imported at module
    # level (only inside run_tests_with_coverage), so import it here.
    import sys

    analyzer = CoverageAnalyzer()
    # Run the tests and collect coverage data
    exit_code = analyzer.run_tests_with_coverage()
    # Write the reports
    analyzer.generate_reports()
    # Analyse the collected data
    analysis = analyzer.analyze_coverage_data()
    # Write the badge
    analyzer.generate_coverage_badge()
    # Check the threshold; a failing test run must fail the script as well,
    # even when the coverage percentage itself is high enough.
    coverage_ok = analyzer.check_coverage_threshold(80)
    success = coverage_ok and exit_code == 0
    sys.exit(0 if success else 1)
CI/CD集成
yaml
# .github/workflows/ci.yml
# Pipeline: test (matrix) -> performance + security -> build -> deploy to staging.
name: CI Pipeline

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    name: Test and Coverage
    runs-on: ubuntu-latest
    services:
      postgres:
        image: postgres:13
        env:
          POSTGRES_USER: test_user
          POSTGRES_PASSWORD: test_password
          POSTGRES_DB: test_db
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432
      redis:
        image: redis:7-alpine
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379
    env:
      DATABASE_URL: postgresql://test_user:test_password@localhost:5432/test_db
      REDIS_URL: redis://localhost:6379/0
      SECRET_KEY: test_secret_key
      ENVIRONMENT: testing
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-test.txt
      - name: Run linting
        run: |
          # Code style checks
          black --check app tests
          flake8 app tests
          mypy app
      - name: Run unit tests
        # --junitxml produces the test-results.xml uploaded as an artifact below
        run: |
          pytest tests/unit/ -v --cov=app --cov-report=term-missing --junitxml=test-results.xml
      - name: Run integration tests
        run: |
          pytest tests/integration/ -v --cov=app --cov-append
      - name: Run e2e tests
        run: |
          pytest tests/e2e/ -v --cov=app --cov-append
      - name: Generate coverage report
        # Build the reports from the data accumulated by the three runs above
        # instead of running the whole suite a fourth time.
        run: |
          coverage xml -o coverage.xml
          coverage html -d coverage_html
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
      - name: Upload test artifacts
        if: always()
        # v3 of upload-artifact has been retired by GitHub; v4 is required
        uses: actions/upload-artifact@v4
        with:
          name: test-reports-${{ matrix.python-version }}
          path: |
            coverage_html/
            test-results.xml
          retention-days: 30

  performance:
    name: Performance Test
    runs-on: ubuntu-latest
    needs: test
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install locust
      - name: Start application
        run: |
          uvicorn app.main:app --host 0.0.0.0 --port 8000 &
          sleep 10  # wait for the application to start
      - name: Run performance tests
        run: |
          locust -f tests/performance/locustfile.py \
            --headless \
            -u 100 \
            -r 10 \
            -t 1m \
            --host=http://localhost:8000 \
            --csv=performance_results
      - name: Upload performance results
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: performance-results
          path: performance_results*.csv
          retention-days: 30

  security:
    name: Security Scan
    runs-on: ubuntu-latest
    needs: test
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Run bandit security scan
        run: |
          pip install bandit
          bandit -r app -f json -o bandit-report.json
      - name: Run safety dependency check
        run: |
          pip install safety
          safety check --json > safety-report.json
      - name: Upload security reports
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: security-reports
          path: |
            bandit-report.json
            safety-report.json

  build:
    name: Build Docker Image
    runs-on: ubuntu-latest
    needs: [test, performance, security]
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Login to DockerHub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
      - name: Build and push
        uses: docker/build-push-action@v4
        with:
          context: .
          push: true
          tags: |
            ${{ secrets.DOCKER_USERNAME }}/fastapi-app:latest
            ${{ secrets.DOCKER_USERNAME }}/fastapi-app:${{ github.sha }}
          cache-from: type=registry,ref=${{ secrets.DOCKER_USERNAME }}/fastapi-app:buildcache
          cache-to: type=registry,ref=${{ secrets.DOCKER_USERNAME }}/fastapi-app:buildcache,mode=max

  deploy:
    name: Deploy to Staging
    runs-on: ubuntu-latest
    needs: build
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
    environment: staging
    steps:
      # Each job starts on a fresh runner: the smoke tests need the repository
      # and the test dependencies before pytest can run.
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.11"
      - name: Deploy to Kubernetes
        uses: appleboy/ssh-action@v0.1.5
        with:
          host: ${{ secrets.STAGING_HOST }}
          username: ${{ secrets.STAGING_USERNAME }}
          key: ${{ secrets.STAGING_SSH_KEY }}
          script: |
            kubectl set image deployment/fastapi-app \
              fastapi-app=${{ secrets.DOCKER_USERNAME }}/fastapi-app:${{ github.sha }}
      - name: Install test dependencies
        run: |
          pip install -r requirements.txt
          pip install -r requirements-test.txt
      - name: Run smoke tests
        # --fail makes curl exit non-zero on an HTTP error status
        run: |
          curl --fail --retry 5 --retry-delay 10 \
            https://staging.example.com/health
          pytest tests/smoke/ -v
# Jenkinsfile (声明式流水线)
// Declarative pipeline: setup -> lint -> unit/integration tests -> image build,
// push and staging deployment (main branch only) -> smoke tests.
pipeline {
    agent any

    environment {
        DOCKER_REGISTRY = 'registry.example.com'
        APP_NAME = 'fastapi-app'
        PYTHON_VERSION = '3.11'
    }

    stages {
        stage('Checkout') {
            steps {
                checkout scm
            }
        }

        stage('Setup') {
            steps {
                // "." instead of "source": the sh step runs a POSIX shell,
                // where "source" is not guaranteed to exist.
                sh '''
                    python -m venv venv
                    . venv/bin/activate
                    pip install --upgrade pip
                    pip install -r requirements.txt
                    pip install -r requirements-test.txt
                '''
            }
        }

        stage('Lint') {
            steps {
                sh '''
                    . venv/bin/activate
                    black --check app tests
                    flake8 app tests
                    mypy app
                '''
            }
        }

        stage('Unit Tests') {
            steps {
                sh '''
                    . venv/bin/activate
                    pytest tests/unit/ -v \
                        --junitxml=test-results-unit.xml \
                        --cov=app \
                        --cov-report=xml:coverage-unit.xml
                '''
            }
            post {
                always {
                    junit 'test-results-unit.xml'
                    cobertura coberturaReportFile: 'coverage-unit.xml'
                }
            }
        }

        stage('Integration Tests') {
            steps {
                sh '''
                    . venv/bin/activate
                    docker-compose -f docker-compose.test.yml up -d
                    sleep 30
                    pytest tests/integration/ -v \
                        --junitxml=test-results-integration.xml \
                        --cov=app --cov-append \
                        --cov-report=xml:coverage-integration.xml
                '''
            }
            post {
                always {
                    // Tear the compose stack down here so it is also removed
                    // when pytest fails and aborts the script above.
                    sh 'docker-compose -f docker-compose.test.yml down'
                    junit 'test-results-integration.xml'
                    cobertura coberturaReportFile: 'coverage-integration.xml'
                }
            }
        }

        stage('Build Docker Image') {
            when {
                branch 'main'
            }
            steps {
                script {
                    docker.build("${DOCKER_REGISTRY}/${APP_NAME}:${env.BUILD_ID}")
                }
            }
        }

        stage('Push Docker Image') {
            when {
                branch 'main'
            }
            steps {
                script {
                    docker.withRegistry("https://${DOCKER_REGISTRY}", 'docker-credentials') {
                        docker.image("${DOCKER_REGISTRY}/${APP_NAME}:${env.BUILD_ID}").push()
                        docker.image("${DOCKER_REGISTRY}/${APP_NAME}:${env.BUILD_ID}").push('latest')
                    }
                }
            }
        }

        stage('Deploy to Staging') {
            when {
                branch 'main'
            }
            steps {
                // The script is single-quoted, so Groovy does not interpolate
                // it; the shell expands the variables. BUILD_ID is therefore
                // read as a shell environment variable, not as env.BUILD_ID.
                sh '''
                    kubectl config use-context staging
                    kubectl set image deployment/${APP_NAME} \
                        ${APP_NAME}=${DOCKER_REGISTRY}/${APP_NAME}:${BUILD_ID}
                    kubectl rollout status deployment/${APP_NAME}
                '''
            }
        }

        stage('Smoke Tests') {
            when {
                branch 'main'
            }
            steps {
                sh '''
                    . venv/bin/activate
                    pytest tests/smoke/ -v \
                        --junitxml=test-results-smoke.xml
                '''
            }
            post {
                always {
                    junit 'test-results-smoke.xml'
                }
            }
        }
    }

    post {
        always {
            // The integration run appends to the unit run's .coverage data
            // (--cov-append), so the data is already merged. "coverage combine"
            // merges data files and cannot read the XML reports.
            sh '''
                . venv/bin/activate
                coverage report
                coverage html
            '''
            // Clean up
            sh 'docker system prune -f'
        }
        success {
            // Send the success notification
            emailext (
                subject: "构建成功: ${env.JOB_NAME} #${env.BUILD_NUMBER}",
                body: "构建 ${env.BUILD_URL} 成功完成。",
                to: 'team@example.com'
            )
        }
        failure {
            // Send the failure notification
            emailext (
                subject: "构建失败: ${env.JOB_NAME} #${env.BUILD_NUMBER}",
                body: "构建 ${env.BUILD_URL} 失败。",
                to: 'team@example.com',
                attachLog: true
            )
        }
    }
}
# GitLab CI配置
# .gitlab-ci.yml
stages:
  - test
  - build
  - deploy

# Default image for jobs that do not set their own; the global before_script
# below needs Python and pip.
image: python:3.11

variables:
  DOCKER_IMAGE: registry.gitlab.com/$CI_PROJECT_PATH
  DATABASE_URL: postgresql://postgres:password@postgres:5432/test_db
  # Read by the postgres service container; values must match DATABASE_URL.
  POSTGRES_DB: test_db
  POSTGRES_PASSWORD: password

services:
  - postgres:13-alpine
  - redis:7-alpine

cache:
  paths:
    - venv/
  key: $CI_COMMIT_REF_SLUG

before_script:
  - python --version
  - pip install virtualenv
  - virtualenv venv
  - source venv/bin/activate
  - pip install -r requirements.txt
  - pip install -r requirements-test.txt

unit_tests:
  stage: test
  script:
    # --junitxml writes the JUnit file declared under reports; the HTML
    # directory is set explicitly so it matches the artifact path.
    - pytest tests/unit/ -v --junitxml=test-results.xml --cov=app --cov-report=xml --cov-report=html:htmlcov
  artifacts:
    paths:
      - coverage.xml
      - htmlcov/
    reports:
      junit: test-results.xml
      # Replaces the removed "cobertura:" report keyword
      coverage_report:
        coverage_format: cobertura
        path: coverage.xml

integration_tests:
  stage: test
  script:
    - pytest tests/integration/ -v --junitxml=test-results-integration.xml
  artifacts:
    paths:
      - test-results-integration.xml
    reports:
      junit: test-results-integration.xml

performance_tests:
  stage: test
  script:
    - pip install locust
    # Start the application that locust targets on localhost:8000
    - uvicorn app.main:app --host 0.0.0.0 --port 8000 &
    - sleep 10
    # --csv=locust makes locust write locust_stats.csv
    - locust -f tests/performance/locustfile.py --headless -u 10 -r 1 -t 30s --host=http://localhost:8000 --csv=locust
  artifacts:
    paths:
      - locust_stats.csv

build_image:
  stage: build
  image: docker:20.10.16
  services:
    - docker:20.10.16-dind
  variables:
    DOCKER_TLS_CERTDIR: "/certs"
  # Overrides the global before_script (the docker image has no Python) and
  # logs in to the registry so the pushes below are authorised.
  before_script:
    - echo "$CI_REGISTRY_PASSWORD" | docker login -u "$CI_REGISTRY_USER" --password-stdin "$CI_REGISTRY"
  script:
    - docker build -t $DOCKER_IMAGE:$CI_COMMIT_SHA -t $DOCKER_IMAGE:latest .
    - docker push $DOCKER_IMAGE:$CI_COMMIT_SHA
    - docker push $DOCKER_IMAGE:latest
  only:
    - main

deploy_staging:
  stage: deploy
  image: alpine:latest
  # Overrides the global before_script; alpine has no Python
  before_script: []
  script:
    - apk add --no-cache curl
    # --fail makes the job fail when the webhook answers with an HTTP error
    - |
      curl --fail -X POST \
        -H "Content-Type: application/json" \
        -d '{"image":"'"$DOCKER_IMAGE:$CI_COMMIT_SHA"'"}' \
        $STAGING_DEPLOY_WEBHOOK
  environment:
    name: staging
    url: https://staging.example.com
  only:
    - main
# 质量门禁配置
# sonar-project.properties
sonar.projectKey=fastapi_app
sonar.projectName=FastAPI Application
sonar.projectVersion=1.0
sonar.sources=app
sonar.tests=tests
sonar.test.inclusions=tests/**/*.py
sonar.exclusions=**/migrations/**,**/__pycache__/**,**/*.pyc
sonar.python.coverage.reportPaths=coverage.xml
sonar.python.xunit.reportPath=test-results.xml
sonar.python.pylint.reportPath=pylint-report.txt
sonar.sourceEncoding=UTF-8
sonar.host.url=https://sonar.example.com
sonar.login=${SONAR_TOKEN}
# 代码质量检查脚本
# scripts/quality_check.py
#!/usr/bin/env python3
"""
代码质量检查脚本
"""
import subprocess
import sys
import json
from pathlib import Path
class QualityChecker:
    """Run the project's quality tools and collect their results.

    ``self.results`` holds one entry per category (``linting``,
    ``type_checking``, ``security``, ``tests``, ``coverage``). Each entry has
    a ``passed`` flag plus either the captured tool output or the coverage
    percentage.
    """

    # Raised by subprocess.run when a tool is missing or cannot be started.
    _RUN_ERRORS = (OSError, subprocess.SubprocessError)

    def __init__(self):
        self.results = {
            "linting": {"passed": False, "output": ""},
            "type_checking": {"passed": False, "output": ""},
            "security": {"passed": False, "output": ""},
            "tests": {"passed": False, "output": ""},
            "coverage": {"passed": False, "coverage": 0},
        }

    @staticmethod
    def _run(command):
        """Run ``command`` (an argument list) and capture its text output."""
        return subprocess.run(command, capture_output=True, text=True)

    def run_black(self):
        """Run ``black --check``; return True when no file needs reformatting."""
        try:
            result = self._run(["black", "--check", "app", "tests"])
        except self._RUN_ERRORS as e:
            print(f"Black检查失败: {e}")
            return False
        passed = result.returncode == 0
        self.results["linting"]["passed"] = passed
        # Black writes its messages to stderr, so keep both streams.
        self.results["linting"]["output"] = result.stdout + result.stderr
        return passed

    def run_flake8(self):
        """Run flake8; return True when flake8 itself reports no violations.

        The stored ``linting`` flag is AND-ed with the flag left by
        ``run_black``, and the output is appended to the existing output.
        """
        try:
            result = self._run(["flake8", "app", "tests"])
        except self._RUN_ERRORS as e:
            print(f"Flake8检查失败: {e}")
            return False
        passed = result.returncode == 0
        self.results["linting"]["passed"] = self.results["linting"]["passed"] and passed
        self.results["linting"]["output"] += "\n" + result.stdout
        return passed

    def run_mypy(self):
        """Run mypy on ``app``; return True when it exits with status 0."""
        try:
            result = self._run(["mypy", "app"])
        except self._RUN_ERRORS as e:
            print(f"Mypy检查失败: {e}")
            return False
        passed = result.returncode == 0
        self.results["type_checking"]["passed"] = passed
        self.results["type_checking"]["output"] = result.stdout
        return passed

    def run_bandit(self):
        """Run Bandit; return True when it reports no HIGH severity issue."""
        try:
            result = self._run(["bandit", "-r", "app", "-f", "json"])
        except self._RUN_ERRORS as e:
            print(f"Bandit检查失败: {e}")
            return False
        self.results["security"]["output"] = result.stdout

        # Bandit exits non-zero when it finds issues, so the JSON report is
        # parsed regardless of the exit status.
        try:
            totals = json.loads(result.stdout).get("metrics", {}).get("_totals", {})
            passed = totals.get("SEVERITY.HIGH", 0) == 0
        except (json.JSONDecodeError, AttributeError):
            # An unreadable report must not count as a clean scan.
            print("Bandit检查失败: 无法解析JSON输出")
            passed = False
        self.results["security"]["passed"] = passed
        return passed

    def run_tests(self):
        """Run pytest with a JUnit report; return True when all tests pass."""
        try:
            result = self._run(["pytest", "tests/", "-v", "--junitxml=test-results.xml"])
        except self._RUN_ERRORS as e:
            print(f"测试运行失败: {e}")
            return False
        passed = result.returncode == 0
        self.results["tests"]["passed"] = passed
        self.results["tests"]["output"] = result.stdout
        return passed

    def check_coverage(self):
        """Run pytest with coverage; return True when the total is at least 80%.

        The ``tests`` flag is additionally AND-ed with this run's exit status.
        """
        import re

        try:
            result = self._run(["pytest", "--cov=app", "--cov-report=term-missing", "tests/"])
        except self._RUN_ERRORS as e:
            print(f"覆盖率检查失败: {e}")
            return False

        # The TOTAL row has two count columns, or four when branch coverage
        # is enabled, so accept any number of them before the percentage.
        coverage_match = re.search(
            r'^TOTAL\s+(?:\d+\s+)+(\d+)(?:\.\d+)?%', result.stdout, re.MULTILINE
        )
        if coverage_match:
            coverage = int(coverage_match.group(1))
            self.results["coverage"]["coverage"] = coverage
            self.results["coverage"]["passed"] = coverage >= 80
        self.results["tests"]["passed"] = self.results["tests"]["passed"] and result.returncode == 0
        return self.results["coverage"]["passed"]

    def generate_report(self):
        """Write ``quality-report.json``, print a summary, return the verdict.

        Returns:
            True only when every category in ``self.results`` has passed.
        """
        # Imported here so the method also works when this module is imported
        # rather than executed as a script.
        from datetime import datetime

        report = {
            "timestamp": datetime.now().isoformat(),
            "results": self.results,
            "summary": {
                "passed_all": all([
                    self.results["linting"]["passed"],
                    self.results["type_checking"]["passed"],
                    self.results["security"]["passed"],
                    self.results["tests"]["passed"],
                    self.results["coverage"]["passed"],
                ]),
                "coverage": self.results["coverage"]["coverage"],
            },
        }
        # Save the report
        with open("quality-report.json", "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)

        # Print the summary
        print("\n" + "="*60)
        print("代码质量检查报告")
        print("="*60)
        for check, result in self.results.items():
            status = "✓" if result["passed"] else "✗"
            print(f"{status} {check.upper()}")
            if check == "coverage":
                print(f" 覆盖率: {result['coverage']}%")
        print("\n详细报告: quality-report.json")
        return report["summary"]["passed_all"]

    def run_all_checks(self):
        """Run every check in order and return the result of the final report.

        A failing check is reported but does not stop the remaining checks.
        """
        print("开始代码质量检查...")
        checks = [
            ("代码格式化", self.run_black),
            ("代码风格", self.run_flake8),
            ("类型检查", self.run_mypy),
            ("安全扫描", self.run_bandit),
            ("单元测试", self.run_tests),
            ("覆盖率检查", self.check_coverage),
        ]
        for name, check_func in checks:
            print(f"\n正在执行: {name}")
            if not check_func():
                print(f" {name} 失败")
        return self.generate_report()
# Script entry point: run every quality check when executed directly.
if __name__ == "__main__":
    # datetime is used by QualityChecker.generate_report() for the report timestamp.
    from datetime import datetime
    checker = QualityChecker()
    # True only when every check category passed.
    success = checker.run_all_checks()
    sys.exit(0 if success else 1)
总结
构建完整的FastAPI测试体系需要系统化的方法和工具链。通过本章的学习,你应该已经掌握以下关键要点和最佳实践。
关键要点
测试策略
- 遵循测试金字塔原则
- 合理分配单元、集成、E2E测试比例
- 根据业务需求调整测试策略
测试工具
- 熟练掌握pytest框架
- 有效测试异步代码
- 使用Mock和Patch模拟外部依赖
数据库测试
- 管理测试数据库事务
- 使用工厂模式创建测试数据
- 测试数据完整性和一致性
性能测试
- 使用Locust进行负载测试
- 分析性能瓶颈
- 建立性能基准
质量保证
- 监控测试覆盖率
- 集成CI/CD流水线
- 实施代码质量门禁
最佳实践
测试命名规范
python
# 好的命名
test_user_can_login_with_valid_credentials()
test_should_raise_error_when_email_is_invalid()
# 不好的命名
test1()
test_login()
测试隔离
- 每个测试独立运行
- 测试之间不依赖顺序
- 清理测试数据
测试数据管理
- 使用fixture创建测试数据
- 避免硬编码测试数据
- 清理测试环境
测试报告
- 生成清晰的测试报告
- 监控测试趋势
- 及时修复失败的测试
工具推荐
- 测试框架: pytest + pytest-asyncio
- Mock库: unittest.mock
- 覆盖率: pytest-cov + coverage.py
- 性能测试: Locust
- CI/CD: GitHub Actions, GitLab CI, Jenkins
- 代码质量: Black, Flake8, Mypy, Bandit
- 监控: SonarQube, Codecov
故障排除
测试运行缓慢
- 优化数据库查询
- 使用测试数据库
- 并行运行测试
测试不稳定
- 检查时间相关的测试
- 避免竞争条件
- 增加适当的等待时间(优先使用带超时的条件轮询,而不是固定时长的sleep)
Mock过于复杂
- 重构被测试代码
- 使用依赖注入
- 考虑集成测试替代
扩展学习
最后提醒:测试不是一次性的工作,而是持续的过程。建立良好的测试文化,让每个团队成员都重视测试。记住:好的测试不仅能发现bug,更能提升代码质量,加快开发速度,增强团队信心。
测试就像保险,平时可能感觉不到它的价值,但关键时刻它能拯救你的项目。投资在测试上的每一分钟,都会在未来带来十倍的回报。